diff options
Diffstat (limited to 'net/mptcp/pm_userspace.c')
-rw-r--r-- | net/mptcp/pm_userspace.c | 115 |
1 files changed, 88 insertions, 27 deletions
diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c index d042d32beb..01b3a8f2f0 100644 --- a/net/mptcp/pm_userspace.c +++ b/net/mptcp/pm_userspace.c @@ -26,7 +26,8 @@ void mptcp_free_local_addr_list(struct mptcp_sock *msk) } static int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, - struct mptcp_pm_addr_entry *entry) + struct mptcp_pm_addr_entry *entry, + bool needs_id) { DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); struct mptcp_pm_addr_entry *match = NULL; @@ -41,7 +42,7 @@ static int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, spin_lock_bh(&msk->pm.lock); list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) { addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true); - if (addr_match && entry->addr.id == 0) + if (addr_match && entry->addr.id == 0 && needs_id) entry->addr.id = e->addr.id; id_match = (e->addr.id == entry->addr.id); if (addr_match && id_match) { @@ -64,7 +65,7 @@ static int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, } *e = *entry; - if (!e->addr.id) + if (!e->addr.id && needs_id) e->addr.id = find_next_zero_bit(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1, 1); @@ -130,10 +131,21 @@ int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc) { - struct mptcp_pm_addr_entry new_entry; + struct mptcp_pm_addr_entry *entry = NULL, *e, new_entry; __be16 msk_sport = ((struct inet_sock *) inet_sk((struct sock *)msk))->inet_sport; + spin_lock_bh(&msk->pm.lock); + list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) { + if (mptcp_addresses_equal(&e->addr, skc, false)) { + entry = e; + break; + } + } + spin_unlock_bh(&msk->pm.lock); + if (entry) + return entry->addr.id; + memset(&new_entry, 0, sizeof(struct mptcp_pm_addr_entry)); new_entry.addr = *skc; new_entry.addr.id = 0; @@ -142,16 +154,17 @@ int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk, if (new_entry.addr.port == msk_sport) new_entry.addr.port = 0; - return mptcp_userspace_pm_append_new_local_addr(msk, &new_entry); + return mptcp_userspace_pm_append_new_local_addr(msk, &new_entry, true); } -int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info) +int mptcp_pm_nl_announce_doit(struct sk_buff *skb, struct genl_info *info) { struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; struct nlattr *addr = info->attrs[MPTCP_PM_ATTR_ADDR]; struct mptcp_pm_addr_entry addr_val; struct mptcp_sock *msk; int err = -EINVAL; + struct sock *sk; u32 token_val; if (!addr || !token) { @@ -167,6 +180,8 @@ int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info) return err; } + sk = (struct sock *)msk; + if (!mptcp_pm_is_userspace(msk)) { GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); goto announce_err; @@ -184,13 +199,13 @@ int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info) goto announce_err; } - err = mptcp_userspace_pm_append_new_local_addr(msk, &addr_val); + err = mptcp_userspace_pm_append_new_local_addr(msk, &addr_val, false); if (err < 0) { GENL_SET_ERR_MSG(info, "did not match address and id"); goto announce_err; } - lock_sock((struct sock *)msk); + lock_sock(sk); spin_lock_bh(&msk->pm.lock); if (mptcp_pm_alloc_anno_list(msk, &addr_val.addr)) { @@ -200,15 +215,49 @@ int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info) } spin_unlock_bh(&msk->pm.lock); - release_sock((struct sock *)msk); + release_sock(sk); err = 0; announce_err: - sock_put((struct sock *)msk); + sock_put(sk); return err; } -int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info) +static int mptcp_userspace_pm_remove_id_zero_address(struct mptcp_sock *msk, + struct genl_info *info) +{ + struct mptcp_rm_list list = { .nr = 0 }; + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + bool has_id_0 = false; + int err = -EINVAL; + + lock_sock(sk); + mptcp_for_each_subflow(msk, subflow) { + if (READ_ONCE(subflow->local_id) == 0) { + has_id_0 = true; + break; + } + } + if (!has_id_0) { + GENL_SET_ERR_MSG(info, "address with id 0 not found"); + goto remove_err; + } + + list.ids[list.nr++] = 0; + + spin_lock_bh(&msk->pm.lock); + mptcp_pm_remove_addr(msk, &list); + spin_unlock_bh(&msk->pm.lock); + + err = 0; + +remove_err: + release_sock(sk); + return err; +} + +int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info) { struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; struct nlattr *id = info->attrs[MPTCP_PM_ATTR_LOC_ID]; @@ -217,6 +266,7 @@ int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info) struct mptcp_sock *msk; LIST_HEAD(free_list); int err = -EINVAL; + struct sock *sk; u32 token_val; u8 id_val; @@ -234,12 +284,19 @@ int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info) return err; } + sk = (struct sock *)msk; + if (!mptcp_pm_is_userspace(msk)) { GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); goto remove_err; } - lock_sock((struct sock *)msk); + if (id_val == 0) { + err = mptcp_userspace_pm_remove_id_zero_address(msk, info); + goto remove_err; + } + + lock_sock(sk); list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) { if (entry->addr.id == id_val) { @@ -250,7 +307,7 @@ int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info) if (!match) { GENL_SET_ERR_MSG(info, "address with specified id not found"); - release_sock((struct sock *)msk); + release_sock(sk); goto remove_err; } @@ -258,19 +315,19 @@ int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info) mptcp_pm_remove_addrs(msk, &free_list); - release_sock((struct sock *)msk); + release_sock(sk); list_for_each_entry_safe(match, entry, &free_list, list) { - sock_kfree_s((struct sock *)msk, match, sizeof(*match)); + sock_kfree_s(sk, match, sizeof(*match)); } err = 0; remove_err: - sock_put((struct sock *)msk); + sock_put(sk); return err; } -int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info) +int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info) { struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; @@ -296,6 +353,8 @@ int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info) return err; } + sk = (struct sock *)msk; + if (!mptcp_pm_is_userspace(msk)) { GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); goto create_err; @@ -313,8 +372,6 @@ int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info) goto create_err; } - sk = (struct sock *)msk; - if (!mptcp_pm_addr_families_match(sk, &addr_l, &addr_r)) { GENL_SET_ERR_MSG(info, "families mismatch"); err = -EINVAL; @@ -322,7 +379,7 @@ int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info) } local.addr = addr_l; - err = mptcp_userspace_pm_append_new_local_addr(msk, &local); + err = mptcp_userspace_pm_append_new_local_addr(msk, &local, false); if (err < 0) { GENL_SET_ERR_MSG(info, "did not match address and id"); goto create_err; @@ -342,7 +399,7 @@ int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info) spin_unlock_bh(&msk->pm.lock); create_err: - sock_put((struct sock *)msk); + sock_put(sk); return err; } @@ -394,7 +451,7 @@ static struct sock *mptcp_nl_find_ssk(struct mptcp_sock *msk, return NULL; } -int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info) +int mptcp_pm_nl_subflow_destroy_doit(struct sk_buff *skb, struct genl_info *info) { struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; @@ -419,6 +476,8 @@ int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info) return err; } + sk = (struct sock *)msk; + if (!mptcp_pm_is_userspace(msk)) { GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); goto destroy_err; @@ -448,7 +507,6 @@ int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info) goto destroy_err; } - sk = (struct sock *)msk; lock_sock(sk); ssk = mptcp_nl_find_ssk(msk, &addr_l, &addr_r); if (ssk) { @@ -468,7 +526,7 @@ int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info) release_sock(sk); destroy_err: - sock_put((struct sock *)msk); + sock_put(sk); return err; } @@ -478,6 +536,7 @@ int mptcp_userspace_pm_set_flags(struct net *net, struct nlattr *token, { struct mptcp_sock *msk; int ret = -EINVAL; + struct sock *sk; u32 token_val; token_val = nla_get_u32(token); @@ -486,6 +545,8 @@ int mptcp_userspace_pm_set_flags(struct net *net, struct nlattr *token, if (!msk) return ret; + sk = (struct sock *)msk; + if (!mptcp_pm_is_userspace(msk)) goto set_flags_err; @@ -493,11 +554,11 @@ int mptcp_userspace_pm_set_flags(struct net *net, struct nlattr *token, rem->addr.family == AF_UNSPEC) goto set_flags_err; - lock_sock((struct sock *)msk); + lock_sock(sk); ret = mptcp_pm_nl_mp_prio_send_ack(msk, &loc->addr, &rem->addr, bkup); - release_sock((struct sock *)msk); + release_sock(sk); set_flags_err: - sock_put((struct sock *)msk); + sock_put(sk); return ret; } |