summaryrefslogtreecommitdiffstats
path: root/net/mpls/af_mpls.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/mpls/af_mpls.c')
-rw-r--r--net/mpls/af_mpls.c210
1 files changed, 139 insertions, 71 deletions
diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c
index 06ffafde70da..5928d22ba9c8 100644
--- a/net/mpls/af_mpls.c
+++ b/net/mpls/af_mpls.c
@@ -24,6 +24,9 @@
#include <net/nexthop.h>
#include "internal.h"
+/* max memory we will use for mpls_route */
+#define MAX_MPLS_ROUTE_MEM 4096
+
/* Maximum number of labels to look ahead at when selecting a path of
* a multipath route
*/
@@ -60,10 +63,7 @@ EXPORT_SYMBOL_GPL(mpls_output_possible);
static u8 *__mpls_nh_via(struct mpls_route *rt, struct mpls_nh *nh)
{
- u8 *nh0_via = PTR_ALIGN((u8 *)&rt->rt_nh[rt->rt_nhn], VIA_ALEN_ALIGN);
- int nh_index = nh - rt->rt_nh;
-
- return nh0_via + rt->rt_max_alen * nh_index;
+ return (u8 *)nh + rt->rt_via_offset;
}
static const u8 *mpls_nh_via(const struct mpls_route *rt,
@@ -189,21 +189,32 @@ static u32 mpls_multipath_hash(struct mpls_route *rt, struct sk_buff *skb)
return hash;
}
+static struct mpls_nh *mpls_get_nexthop(struct mpls_route *rt, u8 index)
+{
+ return (struct mpls_nh *)((u8 *)rt->rt_nh + index * rt->rt_nh_size);
+}
+
+/* number of alive nexthops (rt->rt_nhn_alive) and the flags for
+ * a next hop (nh->nh_flags) are modified by netdev event handlers.
+ * Since those fields can change at any moment, use READ_ONCE to
+ * access both.
+ */
static struct mpls_nh *mpls_select_multipath(struct mpls_route *rt,
struct sk_buff *skb)
{
- int alive = ACCESS_ONCE(rt->rt_nhn_alive);
u32 hash = 0;
int nh_index = 0;
int n = 0;
+ u8 alive;
/* No need to look further into packet if there's only
* one path
*/
if (rt->rt_nhn == 1)
- goto out;
+ return rt->rt_nh;
- if (alive <= 0)
+ alive = READ_ONCE(rt->rt_nhn_alive);
+ if (alive == 0)
return NULL;
hash = mpls_multipath_hash(rt, skb);
@@ -211,7 +222,9 @@ static struct mpls_nh *mpls_select_multipath(struct mpls_route *rt,
if (alive == rt->rt_nhn)
goto out;
for_nexthops(rt) {
- if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
+ unsigned int nh_flags = READ_ONCE(nh->nh_flags);
+
+ if (nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
continue;
if (n == nh_index)
return nh;
@@ -219,7 +232,7 @@ static struct mpls_nh *mpls_select_multipath(struct mpls_route *rt,
} endfor_nexthops(rt);
out:
- return &rt->rt_nh[nh_index];
+ return mpls_get_nexthop(rt, nh_index);
}
static bool mpls_egress(struct net *net, struct mpls_route *rt,
@@ -458,20 +471,27 @@ struct mpls_route_config {
int rc_mp_len;
};
-static struct mpls_route *mpls_rt_alloc(int num_nh, u8 max_alen)
+/* all nexthops within a route have the same size based on max
+ * number of labels and max via length for a hop
+ */
+static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels)
{
- u8 max_alen_aligned = ALIGN(max_alen, VIA_ALEN_ALIGN);
+ u8 nh_size = MPLS_NH_SIZE(max_labels, max_alen);
struct mpls_route *rt;
+ size_t size;
- rt = kzalloc(ALIGN(sizeof(*rt) + num_nh * sizeof(*rt->rt_nh),
- VIA_ALEN_ALIGN) +
- num_nh * max_alen_aligned,
- GFP_KERNEL);
- if (rt) {
- rt->rt_nhn = num_nh;
- rt->rt_nhn_alive = num_nh;
- rt->rt_max_alen = max_alen_aligned;
- }
+ size = sizeof(*rt) + num_nh * nh_size;
+ if (size > MAX_MPLS_ROUTE_MEM)
+ return ERR_PTR(-EINVAL);
+
+ rt = kzalloc(size, GFP_KERNEL);
+ if (!rt)
+ return ERR_PTR(-ENOMEM);
+
+ rt->rt_nhn = num_nh;
+ rt->rt_nhn_alive = num_nh;
+ rt->rt_nh_size = nh_size;
+ rt->rt_via_offset = MPLS_NH_VIA_OFF(max_labels);
return rt;
}
@@ -676,9 +696,6 @@ static int mpls_nh_build_from_cfg(struct mpls_route_config *cfg,
return -ENOMEM;
err = -EINVAL;
- /* Ensure only a supported number of labels are present */
- if (cfg->rc_output_labels > MAX_NEW_LABELS)
- goto errout;
nh->nh_labels = cfg->rc_output_labels;
for (i = 0; i < nh->nh_labels; i++)
@@ -703,7 +720,7 @@ errout:
static int mpls_nh_build(struct net *net, struct mpls_route *rt,
struct mpls_nh *nh, int oif, struct nlattr *via,
- struct nlattr *newdst)
+ struct nlattr *newdst, u8 max_labels)
{
int err = -ENOMEM;
@@ -711,7 +728,7 @@ static int mpls_nh_build(struct net *net, struct mpls_route *rt,
goto errout;
if (newdst) {
- err = nla_get_labels(newdst, MAX_NEW_LABELS,
+ err = nla_get_labels(newdst, max_labels,
&nh->nh_labels, nh->nh_label);
if (err)
goto errout;
@@ -736,22 +753,20 @@ errout:
return err;
}
-static int mpls_count_nexthops(struct rtnexthop *rtnh, int len,
- u8 cfg_via_alen, u8 *max_via_alen)
+static u8 mpls_count_nexthops(struct rtnexthop *rtnh, int len,
+ u8 cfg_via_alen, u8 *max_via_alen,
+ u8 *max_labels)
{
- int nhs = 0;
int remaining = len;
-
- if (!rtnh) {
- *max_via_alen = cfg_via_alen;
- return 1;
- }
+ u8 nhs = 0;
*max_via_alen = 0;
+ *max_labels = 0;
while (rtnh_ok(rtnh, remaining)) {
struct nlattr *nla, *attrs = rtnh_attrs(rtnh);
int attrlen;
+ u8 n_labels = 0;
attrlen = rtnh_attrlen(rtnh);
nla = nla_find(attrs, attrlen, RTA_VIA);
@@ -765,7 +780,20 @@ static int mpls_count_nexthops(struct rtnexthop *rtnh, int len,
via_alen);
}
+ nla = nla_find(attrs, attrlen, RTA_NEWDST);
+ if (nla &&
+ nla_get_labels(nla, MAX_NEW_LABELS, &n_labels, NULL) != 0)
+ return 0;
+
+ *max_labels = max_t(u8, *max_labels, n_labels);
+
+ /* number of nexthops is tracked by a u8.
+ * Check for overflow.
+ */
+ if (nhs == 255)
+ return 0;
nhs++;
+
rtnh = rtnh_next(rtnh, &remaining);
}
@@ -774,13 +802,13 @@ static int mpls_count_nexthops(struct rtnexthop *rtnh, int len,
}
static int mpls_nh_build_multi(struct mpls_route_config *cfg,
- struct mpls_route *rt)
+ struct mpls_route *rt, u8 max_labels)
{
struct rtnexthop *rtnh = cfg->rc_mp;
struct nlattr *nla_via, *nla_newdst;
int remaining = cfg->rc_mp_len;
- int nhs = 0;
int err = 0;
+ u8 nhs = 0;
change_nexthops(rt) {
int attrlen;
@@ -807,7 +835,8 @@ static int mpls_nh_build_multi(struct mpls_route_config *cfg,
}
err = mpls_nh_build(cfg->rc_nlinfo.nl_net, rt, nh,
- rtnh->rtnh_ifindex, nla_via, nla_newdst);
+ rtnh->rtnh_ifindex, nla_via, nla_newdst,
+ max_labels);
if (err)
goto errout;
@@ -834,7 +863,8 @@ static int mpls_route_add(struct mpls_route_config *cfg)
int err = -EINVAL;
u8 max_via_alen;
unsigned index;
- int nhs;
+ u8 max_labels;
+ u8 nhs;
index = cfg->rc_label;
@@ -872,22 +902,32 @@ static int mpls_route_add(struct mpls_route_config *cfg)
goto errout;
err = -EINVAL;
- nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len,
- cfg->rc_via_alen, &max_via_alen);
+ if (cfg->rc_mp) {
+ nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len,
+ cfg->rc_via_alen, &max_via_alen,
+ &max_labels);
+ } else {
+ max_via_alen = cfg->rc_via_alen;
+ max_labels = cfg->rc_output_labels;
+ nhs = 1;
+ }
+
if (nhs == 0)
goto errout;
err = -ENOMEM;
- rt = mpls_rt_alloc(nhs, max_via_alen);
- if (!rt)
+ rt = mpls_rt_alloc(nhs, max_via_alen, max_labels);
+ if (IS_ERR(rt)) {
+ err = PTR_ERR(rt);
goto errout;
+ }
rt->rt_protocol = cfg->rc_protocol;
rt->rt_payload_type = cfg->rc_payload_type;
rt->rt_ttl_propagate = cfg->rc_ttl_propagate;
if (cfg->rc_mp)
- err = mpls_nh_build_multi(cfg, rt);
+ err = mpls_nh_build_multi(cfg, rt, max_labels);
else
err = mpls_nh_build_from_cfg(cfg, rt);
if (err)
@@ -1302,8 +1342,7 @@ static void mpls_ifdown(struct net_device *dev, int event)
{
struct mpls_route __rcu **platform_label;
struct net *net = dev_net(dev);
- unsigned int nh_flags = RTNH_F_DEAD | RTNH_F_LINKDOWN;
- unsigned int alive, deleted;
+ u8 alive, deleted;
unsigned index;
platform_label = rtnl_dereference(net->mpls.platform_label);
@@ -1316,22 +1355,27 @@ static void mpls_ifdown(struct net_device *dev, int event)
alive = 0;
deleted = 0;
change_nexthops(rt) {
+ unsigned int nh_flags = nh->nh_flags;
+
if (rtnl_dereference(nh->nh_dev) != dev)
goto next;
switch (event) {
case NETDEV_DOWN:
case NETDEV_UNREGISTER:
- nh->nh_flags |= RTNH_F_DEAD;
+ nh_flags |= RTNH_F_DEAD;
/* fall through */
case NETDEV_CHANGE:
- nh->nh_flags |= RTNH_F_LINKDOWN;
+ nh_flags |= RTNH_F_LINKDOWN;
break;
}
if (event == NETDEV_UNREGISTER)
RCU_INIT_POINTER(nh->nh_dev, NULL);
+
+ if (nh->nh_flags != nh_flags)
+ WRITE_ONCE(nh->nh_flags, nh_flags);
next:
- if (!(nh->nh_flags & nh_flags))
+ if (!(nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN)))
alive++;
if (!rtnl_dereference(nh->nh_dev))
deleted++;
@@ -1345,12 +1389,12 @@ next:
}
}
-static void mpls_ifup(struct net_device *dev, unsigned int nh_flags)
+static void mpls_ifup(struct net_device *dev, unsigned int flags)
{
struct mpls_route __rcu **platform_label;
struct net *net = dev_net(dev);
unsigned index;
- int alive;
+ u8 alive;
platform_label = rtnl_dereference(net->mpls.platform_label);
for (index = 0; index < net->mpls.platform_labels; index++) {
@@ -1361,20 +1405,22 @@ static void mpls_ifup(struct net_device *dev, unsigned int nh_flags)
alive = 0;
change_nexthops(rt) {
+ unsigned int nh_flags = nh->nh_flags;
struct net_device *nh_dev =
rtnl_dereference(nh->nh_dev);
- if (!(nh->nh_flags & nh_flags)) {
+ if (!(nh_flags & flags)) {
alive++;
continue;
}
if (nh_dev != dev)
continue;
alive++;
- nh->nh_flags &= ~nh_flags;
+ nh_flags &= ~flags;
+ WRITE_ONCE(nh->nh_flags, flags);
} endfor_nexthops(rt);
- ACCESS_ONCE(rt->rt_nhn_alive) = alive;
+ WRITE_ONCE(rt->rt_nhn_alive, alive);
}
}
@@ -1495,16 +1541,18 @@ int nla_put_labels(struct sk_buff *skb, int attrtype,
EXPORT_SYMBOL_GPL(nla_put_labels);
int nla_get_labels(const struct nlattr *nla,
- u32 max_labels, u8 *labels, u32 label[])
+ u8 max_labels, u8 *labels, u32 label[])
{
unsigned len = nla_len(nla);
- unsigned nla_labels;
struct mpls_shim_hdr *nla_label;
+ u8 nla_labels;
bool bos;
int i;
- /* len needs to be an even multiple of 4 (the label size) */
- if (len & 3)
+ /* len needs to be an even multiple of 4 (the label size). Number
+ * of labels is a u8 so check for overflow.
+ */
+ if (len & 3 || len / 4 > 255)
return -EINVAL;
/* Limit the number of new labels allowed */
@@ -1512,6 +1560,10 @@ int nla_get_labels(const struct nlattr *nla,
if (nla_labels > max_labels)
return -EINVAL;
+ /* when label == NULL, caller wants number of labels */
+ if (!label)
+ goto out;
+
nla_label = nla_data(nla);
bos = true;
for (i = nla_labels - 1; i >= 0; i--, bos = false) {
@@ -1535,6 +1587,7 @@ int nla_get_labels(const struct nlattr *nla,
label[i] = dec.label;
}
+out:
*labels = nla_labels;
return 0;
}
@@ -1596,7 +1649,6 @@ static int rtm_to_route_config(struct sk_buff *skb, struct nlmsghdr *nlh,
err = -EINVAL;
rtm = nlmsg_data(nlh);
- memset(cfg, 0, sizeof(*cfg));
if (rtm->rtm_family != AF_MPLS)
goto errout;
@@ -1695,27 +1747,43 @@ errout:
static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh)
{
- struct mpls_route_config cfg;
+ struct mpls_route_config *cfg;
int err;
- err = rtm_to_route_config(skb, nlh, &cfg);
+ cfg = kzalloc(sizeof(*cfg), GFP_KERNEL);
+ if (!cfg)
+ return -ENOMEM;
+
+ err = rtm_to_route_config(skb, nlh, cfg);
if (err < 0)
- return err;
+ goto out;
+
+ err = mpls_route_del(cfg);
+out:
+ kfree(cfg);
- return mpls_route_del(&cfg);
+ return err;
}
static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh)
{
- struct mpls_route_config cfg;
+ struct mpls_route_config *cfg;
int err;
- err = rtm_to_route_config(skb, nlh, &cfg);
+ cfg = kzalloc(sizeof(*cfg), GFP_KERNEL);
+ if (!cfg)
+ return -ENOMEM;
+
+ err = rtm_to_route_config(skb, nlh, cfg);
if (err < 0)
- return err;
+ goto out;
- return mpls_route_add(&cfg);
+ err = mpls_route_add(cfg);
+out:
+ kfree(cfg);
+
+ return err;
}
static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
@@ -1772,8 +1840,8 @@ static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
} else {
struct rtnexthop *rtnh;
struct nlattr *mp;
- int dead = 0;
- int linkdown = 0;
+ u8 linkdown = 0;
+ u8 dead = 0;
mp = nla_nest_start(skb, RTA_MULTIPATH);
if (!mp)
@@ -1944,8 +2012,8 @@ static int resize_platform_label_table(struct net *net, size_t limit)
/* In case the predefined labels need to be populated */
if (limit > MPLS_LABEL_IPV4NULL) {
struct net_device *lo = net->loopback_dev;
- rt0 = mpls_rt_alloc(1, lo->addr_len);
- if (!rt0)
+ rt0 = mpls_rt_alloc(1, lo->addr_len, 0);
+ if (IS_ERR(rt0))
goto nort0;
RCU_INIT_POINTER(rt0->rt_nh->nh_dev, lo);
rt0->rt_protocol = RTPROT_KERNEL;
@@ -1958,8 +2026,8 @@ static int resize_platform_label_table(struct net *net, size_t limit)
}
if (limit > MPLS_LABEL_IPV6NULL) {
struct net_device *lo = net->loopback_dev;
- rt2 = mpls_rt_alloc(1, lo->addr_len);
- if (!rt2)
+ rt2 = mpls_rt_alloc(1, lo->addr_len, 0);
+ if (IS_ERR(rt2))
goto nort2;
RCU_INIT_POINTER(rt2->rt_nh->nh_dev, lo);
rt2->rt_protocol = RTPROT_KERNEL;