summaryrefslogtreecommitdiffstats
path: root/src/optimize.c
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/optimize.c1406
1 files changed, 1406 insertions, 0 deletions
diff --git a/src/optimize.c b/src/optimize.c
new file mode 100644
index 0000000..27e0ffe
--- /dev/null
+++ b/src/optimize.c
@@ -0,0 +1,1406 @@
+/*
+ * Copyright (c) 2021 Pablo Neira Ayuso <pablo@netfilter.org>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2 (or any
+ * later) as published by the Free Software Foundation.
+ */
+
+/* Funded through the NGI0 PET Fund established by NLnet (https://nlnet.nl)
+ * with support from the European Commission's Next Generation Internet
+ * programme.
+ */
+
+#include <nft.h>
+
+#include <errno.h>
+#include <inttypes.h>
+#include <nftables.h>
+#include <parser.h>
+#include <expression.h>
+#include <statement.h>
+#include <utils.h>
+#include <erec.h>
+#include <linux/netfilter.h>
+
+#define MAX_STMTS 32
+
+struct optimize_ctx {
+ struct stmt *stmt[MAX_STMTS];
+ uint32_t num_stmts;
+
+ struct stmt ***stmt_matrix;
+ struct rule **rule;
+ uint32_t num_rules;
+};
+
+static bool __expr_cmp(const struct expr *expr_a, const struct expr *expr_b)
+{
+ if (expr_a->etype != expr_b->etype)
+ return false;
+
+ switch (expr_a->etype) {
+ case EXPR_PAYLOAD:
+ if (expr_a->payload.base != expr_b->payload.base)
+ return false;
+ if (expr_a->payload.offset != expr_b->payload.offset)
+ return false;
+ if (expr_a->payload.desc != expr_b->payload.desc)
+ return false;
+ if (expr_a->payload.inner_desc != expr_b->payload.inner_desc)
+ return false;
+ if (expr_a->payload.tmpl != expr_b->payload.tmpl)
+ return false;
+ break;
+ case EXPR_EXTHDR:
+ if (expr_a->exthdr.desc != expr_b->exthdr.desc)
+ return false;
+ if (expr_a->exthdr.tmpl != expr_b->exthdr.tmpl)
+ return false;
+ break;
+ case EXPR_META:
+ if (expr_a->meta.key != expr_b->meta.key)
+ return false;
+ if (expr_a->meta.base != expr_b->meta.base)
+ return false;
+ break;
+ case EXPR_CT:
+ if (expr_a->ct.key != expr_b->ct.key)
+ return false;
+ if (expr_a->ct.base != expr_b->ct.base)
+ return false;
+ if (expr_a->ct.direction != expr_b->ct.direction)
+ return false;
+ if (expr_a->ct.nfproto != expr_b->ct.nfproto)
+ return false;
+ break;
+ case EXPR_RT:
+ if (expr_a->rt.key != expr_b->rt.key)
+ return false;
+ break;
+ case EXPR_SOCKET:
+ if (expr_a->socket.key != expr_b->socket.key)
+ return false;
+ if (expr_a->socket.level != expr_b->socket.level)
+ return false;
+ break;
+ case EXPR_OSF:
+ if (expr_a->osf.ttl != expr_b->osf.ttl)
+ return false;
+ if (expr_a->osf.flags != expr_b->osf.flags)
+ return false;
+ break;
+ case EXPR_XFRM:
+ if (expr_a->xfrm.key != expr_b->xfrm.key)
+ return false;
+ if (expr_a->xfrm.direction != expr_b->xfrm.direction)
+ return false;
+ break;
+ case EXPR_FIB:
+ if (expr_a->fib.flags != expr_b->fib.flags)
+ return false;
+ if (expr_a->fib.result != expr_b->fib.result)
+ return false;
+ break;
+ case EXPR_NUMGEN:
+ if (expr_a->numgen.type != expr_b->numgen.type)
+ return false;
+ if (expr_a->numgen.mod != expr_b->numgen.mod)
+ return false;
+ if (expr_a->numgen.offset != expr_b->numgen.offset)
+ return false;
+ break;
+ case EXPR_HASH:
+ if (expr_a->hash.mod != expr_b->hash.mod)
+ return false;
+ if (expr_a->hash.seed_set != expr_b->hash.seed_set)
+ return false;
+ if (expr_a->hash.seed != expr_b->hash.seed)
+ return false;
+ if (expr_a->hash.offset != expr_b->hash.offset)
+ return false;
+ if (expr_a->hash.type != expr_b->hash.type)
+ return false;
+ break;
+ case EXPR_BINOP:
+ return __expr_cmp(expr_a->left, expr_b->left);
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+static bool stmt_expr_supported(const struct expr *expr)
+{
+ switch (expr->right->etype) {
+ case EXPR_SYMBOL:
+ case EXPR_RANGE:
+ case EXPR_PREFIX:
+ case EXPR_SET:
+ case EXPR_LIST:
+ case EXPR_VALUE:
+ return true;
+ default:
+ break;
+ }
+
+ return false;
+}
+
+static bool expr_symbol_set(const struct expr *expr)
+{
+ return expr->right->etype == EXPR_SYMBOL &&
+ expr->right->symtype == SYMBOL_SET;
+}
+
+static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b,
+ bool fully_compare)
+{
+ struct expr *expr_a, *expr_b;
+
+ if (stmt_a->ops->type != stmt_b->ops->type)
+ return false;
+
+ switch (stmt_a->ops->type) {
+ case STMT_EXPRESSION:
+ expr_a = stmt_a->expr;
+ expr_b = stmt_b->expr;
+
+ if (expr_a->op != expr_b->op)
+ return false;
+ if (expr_a->op != OP_IMPLICIT && expr_a->op != OP_EQ)
+ return false;
+
+ if (fully_compare) {
+ if (!stmt_expr_supported(expr_a) ||
+ !stmt_expr_supported(expr_b))
+ return false;
+
+ if (expr_symbol_set(expr_a) ||
+ expr_symbol_set(expr_b))
+ return false;
+ }
+
+ return __expr_cmp(expr_a->left, expr_b->left);
+ case STMT_COUNTER:
+ case STMT_NOTRACK:
+ break;
+ case STMT_VERDICT:
+ if (!fully_compare)
+ break;
+
+ expr_a = stmt_a->expr;
+ expr_b = stmt_b->expr;
+
+ if (expr_a->etype != expr_b->etype)
+ return false;
+
+ if (expr_a->etype == EXPR_MAP &&
+ expr_b->etype == EXPR_MAP)
+ return __expr_cmp(expr_a->map, expr_b->map);
+ break;
+ case STMT_LOG:
+ if (stmt_a->log.snaplen != stmt_b->log.snaplen ||
+ stmt_a->log.group != stmt_b->log.group ||
+ stmt_a->log.qthreshold != stmt_b->log.qthreshold ||
+ stmt_a->log.level != stmt_b->log.level ||
+ stmt_a->log.logflags != stmt_b->log.logflags ||
+ stmt_a->log.flags != stmt_b->log.flags)
+ return false;
+
+ if (!!stmt_a->log.prefix ^ !!stmt_b->log.prefix)
+ return false;
+
+ if (!stmt_a->log.prefix)
+ return true;
+
+ if (stmt_a->log.prefix->etype != EXPR_VALUE ||
+ stmt_b->log.prefix->etype != EXPR_VALUE ||
+ mpz_cmp(stmt_a->log.prefix->value, stmt_b->log.prefix->value))
+ return false;
+ break;
+ case STMT_REJECT:
+ if (stmt_a->reject.family != stmt_b->reject.family ||
+ stmt_a->reject.type != stmt_b->reject.type ||
+ stmt_a->reject.icmp_code != stmt_b->reject.icmp_code)
+ return false;
+
+ if (!!stmt_a->reject.expr ^ !!stmt_b->reject.expr)
+ return false;
+
+ if (!stmt_a->reject.expr)
+ return true;
+
+ if (__expr_cmp(stmt_a->reject.expr, stmt_b->reject.expr))
+ return false;
+ break;
+ case STMT_NAT:
+ if (stmt_a->nat.type != stmt_b->nat.type ||
+ stmt_a->nat.flags != stmt_b->nat.flags ||
+ stmt_a->nat.family != stmt_b->nat.family ||
+ stmt_a->nat.type_flags != stmt_b->nat.type_flags)
+ return false;
+
+ switch (stmt_a->nat.type) {
+ case NFT_NAT_SNAT:
+ case NFT_NAT_DNAT:
+ if ((stmt_a->nat.addr &&
+ stmt_a->nat.addr->etype != EXPR_SYMBOL &&
+ stmt_a->nat.addr->etype != EXPR_RANGE) ||
+ (stmt_b->nat.addr &&
+ stmt_b->nat.addr->etype != EXPR_SYMBOL &&
+ stmt_b->nat.addr->etype != EXPR_RANGE) ||
+ (stmt_a->nat.proto &&
+ stmt_a->nat.proto->etype != EXPR_SYMBOL &&
+ stmt_a->nat.proto->etype != EXPR_RANGE) ||
+ (stmt_b->nat.proto &&
+ stmt_b->nat.proto->etype != EXPR_SYMBOL &&
+ stmt_b->nat.proto->etype != EXPR_RANGE))
+ return false;
+ break;
+ case NFT_NAT_MASQ:
+ break;
+ case NFT_NAT_REDIR:
+ if ((stmt_a->nat.proto &&
+ stmt_a->nat.proto->etype != EXPR_SYMBOL &&
+ stmt_a->nat.proto->etype != EXPR_RANGE) ||
+ (stmt_b->nat.proto &&
+ stmt_b->nat.proto->etype != EXPR_SYMBOL &&
+ stmt_b->nat.proto->etype != EXPR_RANGE))
+ return false;
+
+ /* it should be possible to infer implicit redirections
+ * such as:
+ *
+ * tcp dport 1234 redirect
+ * tcp dport 3456 redirect to :7890
+ * merge:
+ * redirect to tcp dport map { 1234 : 1234, 3456 : 7890 }
+ *
+ * currently not implemented.
+ */
+ if (fully_compare &&
+ stmt_a->nat.type == NFT_NAT_REDIR &&
+ stmt_b->nat.type == NFT_NAT_REDIR &&
+ (!!stmt_a->nat.proto ^ !!stmt_b->nat.proto))
+ return false;
+
+ break;
+ default:
+ assert(0);
+ }
+
+ return true;
+ default:
+ /* ... Merging anything else is yet unsupported. */
+ return false;
+ }
+
+ return true;
+}
+
+static bool expr_verdict_eq(const struct expr *expr_a, const struct expr *expr_b)
+{
+ if (expr_a->verdict != expr_b->verdict)
+ return false;
+ if (expr_a->chain && expr_b->chain) {
+ if (expr_a->chain->etype != expr_b->chain->etype)
+ return false;
+ if (expr_a->chain->etype == EXPR_VALUE &&
+ strcmp(expr_a->chain->identifier, expr_b->chain->identifier))
+ return false;
+ } else if (expr_a->chain || expr_b->chain) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool stmt_verdict_eq(const struct stmt *stmt_a, const struct stmt *stmt_b)
+{
+ struct expr *expr_a, *expr_b;
+
+ assert (stmt_a->ops->type == STMT_VERDICT);
+
+ expr_a = stmt_a->expr;
+ expr_b = stmt_b->expr;
+ if (expr_a->etype == EXPR_VERDICT &&
+ expr_b->etype == EXPR_VERDICT)
+ return expr_verdict_eq(expr_a, expr_b);
+
+ if (expr_a->etype == EXPR_MAP &&
+ expr_b->etype == EXPR_MAP)
+ return __expr_cmp(expr_a->map, expr_b->map);
+
+ return false;
+}
+
+static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt)
+{
+ bool unsupported_exists = false;
+ uint32_t i;
+
+ for (i = 0; i < ctx->num_stmts; i++) {
+ if (ctx->stmt[i]->ops->type == STMT_INVALID)
+ unsupported_exists = true;
+
+ if (__stmt_type_eq(stmt, ctx->stmt[i], false))
+ return true;
+ }
+
+ switch (stmt->ops->type) {
+ case STMT_EXPRESSION:
+ case STMT_VERDICT:
+ case STMT_COUNTER:
+ case STMT_NOTRACK:
+ case STMT_LOG:
+ case STMT_NAT:
+ case STMT_REJECT:
+ break;
+ default:
+ /* add unsupported statement only once to statement matrix. */
+ if (unsupported_exists)
+ return true;
+ break;
+ }
+
+ return false;
+}
+
+static struct stmt_ops unsupported_stmt_ops = {
+ .type = STMT_INVALID,
+ .name = "unsupported",
+};
+
+static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
+{
+ struct stmt *stmt, *clone;
+
+ list_for_each_entry(stmt, &rule->stmts, list) {
+ if (stmt_type_find(ctx, stmt))
+ continue;
+
+ /* No refcounter available in statement objects, clone it to
+ * to store in the array of selectors.
+ */
+ clone = stmt_alloc(&internal_location, stmt->ops);
+ switch (stmt->ops->type) {
+ case STMT_EXPRESSION:
+ if (stmt->expr->op != OP_IMPLICIT &&
+ stmt->expr->op != OP_EQ) {
+ clone->ops = &unsupported_stmt_ops;
+ break;
+ }
+ if (stmt->expr->left->etype == EXPR_CONCAT) {
+ clone->ops = &unsupported_stmt_ops;
+ break;
+ }
+ /* fall-through */
+ case STMT_VERDICT:
+ clone->expr = expr_get(stmt->expr);
+ break;
+ case STMT_COUNTER:
+ case STMT_NOTRACK:
+ break;
+ case STMT_LOG:
+ memcpy(&clone->log, &stmt->log, sizeof(clone->log));
+ if (stmt->log.prefix)
+ clone->log.prefix = expr_get(stmt->log.prefix);
+ break;
+ case STMT_NAT:
+ if ((stmt->nat.addr &&
+ stmt->nat.addr->etype == EXPR_MAP) ||
+ (stmt->nat.proto &&
+ stmt->nat.proto->etype == EXPR_MAP)) {
+ clone->ops = &unsupported_stmt_ops;
+ break;
+ }
+ clone->nat.type = stmt->nat.type;
+ clone->nat.family = stmt->nat.family;
+ if (stmt->nat.addr)
+ clone->nat.addr = expr_clone(stmt->nat.addr);
+ if (stmt->nat.proto)
+ clone->nat.proto = expr_clone(stmt->nat.proto);
+ clone->nat.flags = stmt->nat.flags;
+ clone->nat.type_flags = stmt->nat.type_flags;
+ break;
+ case STMT_REJECT:
+ if (stmt->reject.expr)
+ clone->reject.expr = expr_get(stmt->reject.expr);
+ clone->reject.type = stmt->reject.type;
+ clone->reject.icmp_code = stmt->reject.icmp_code;
+ clone->reject.family = stmt->reject.family;
+ break;
+ default:
+ clone->ops = &unsupported_stmt_ops;
+ break;
+ }
+
+ ctx->stmt[ctx->num_stmts++] = clone;
+ if (ctx->num_stmts >= MAX_STMTS)
+ return -1;
+ }
+
+ return 0;
+}
+
+static int unsupported_in_stmt_matrix(const struct optimize_ctx *ctx)
+{
+ uint32_t i;
+
+ for (i = 0; i < ctx->num_stmts; i++) {
+ if (ctx->stmt[i]->ops->type == STMT_INVALID)
+ return i;
+ }
+ /* this should not happen. */
+ return -1;
+}
+
+static int cmd_stmt_find_in_stmt_matrix(struct optimize_ctx *ctx, struct stmt *stmt)
+{
+ uint32_t i;
+
+ for (i = 0; i < ctx->num_stmts; i++) {
+ if (__stmt_type_eq(stmt, ctx->stmt[i], false))
+ return i;
+ }
+
+ return -1;
+}
+
+static struct stmt unsupported_stmt = {
+ .ops = &unsupported_stmt_ops,
+};
+
+static void rule_build_stmt_matrix_stmts(struct optimize_ctx *ctx,
+ struct rule *rule, uint32_t *i)
+{
+ struct stmt *stmt;
+ int k;
+
+ list_for_each_entry(stmt, &rule->stmts, list) {
+ k = cmd_stmt_find_in_stmt_matrix(ctx, stmt);
+ if (k < 0) {
+ k = unsupported_in_stmt_matrix(ctx);
+ assert(k >= 0);
+ ctx->stmt_matrix[*i][k] = &unsupported_stmt;
+ continue;
+ }
+ ctx->stmt_matrix[*i][k] = stmt;
+ }
+ ctx->rule[(*i)++] = rule;
+}
+
+static int stmt_verdict_find(const struct optimize_ctx *ctx)
+{
+ uint32_t i;
+
+ for (i = 0; i < ctx->num_stmts; i++) {
+ if (ctx->stmt[i]->ops->type != STMT_VERDICT)
+ continue;
+
+ return i;
+ }
+
+ return -1;
+}
+
+struct merge {
+ /* interval of rules to be merged */
+ uint32_t rule_from;
+ uint32_t num_rules;
+ /* statements to be merged (index relative to statement matrix) */
+ uint32_t stmt[MAX_STMTS];
+ uint32_t num_stmts;
+};
+
+static void merge_expr_stmts(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to,
+ const struct merge *merge,
+ struct stmt *stmt_a)
+{
+ struct expr *expr_a, *expr_b, *set, *elem;
+ struct stmt *stmt_b;
+ uint32_t i;
+
+ set = set_expr_alloc(&internal_location, NULL);
+ set->set_flags |= NFT_SET_ANONYMOUS;
+
+ expr_a = stmt_a->expr->right;
+ elem = set_elem_expr_alloc(&internal_location, expr_get(expr_a));
+ compound_expr_add(set, elem);
+
+ for (i = from + 1; i <= to; i++) {
+ stmt_b = ctx->stmt_matrix[i][merge->stmt[0]];
+ expr_b = stmt_b->expr->right;
+ elem = set_elem_expr_alloc(&internal_location, expr_get(expr_b));
+ compound_expr_add(set, elem);
+ }
+
+ expr_free(stmt_a->expr->right);
+ stmt_a->expr->right = set;
+}
+
+static void merge_vmap(const struct optimize_ctx *ctx,
+ struct stmt *stmt_a, const struct stmt *stmt_b)
+{
+ struct expr *mappings, *mapping, *expr;
+
+ mappings = stmt_b->expr->mappings;
+ list_for_each_entry(expr, &mappings->expressions, list) {
+ mapping = expr_clone(expr);
+ compound_expr_add(stmt_a->expr->mappings, mapping);
+ }
+}
+
+static void merge_verdict_stmts(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to,
+ const struct merge *merge,
+ struct stmt *stmt_a)
+{
+ struct stmt *stmt_b;
+ uint32_t i;
+
+ for (i = from + 1; i <= to; i++) {
+ stmt_b = ctx->stmt_matrix[i][merge->stmt[0]];
+ switch (stmt_b->ops->type) {
+ case STMT_VERDICT:
+ switch (stmt_b->expr->etype) {
+ case EXPR_MAP:
+ merge_vmap(ctx, stmt_a, stmt_b);
+ break;
+ default:
+ assert(0);
+ }
+ break;
+ default:
+ assert(0);
+ break;
+ }
+ }
+}
+
+static void merge_stmts(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to, const struct merge *merge)
+{
+ struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]];
+
+ switch (stmt_a->ops->type) {
+ case STMT_EXPRESSION:
+ merge_expr_stmts(ctx, from, to, merge, stmt_a);
+ break;
+ case STMT_VERDICT:
+ merge_verdict_stmts(ctx, from, to, merge, stmt_a);
+ break;
+ default:
+ assert(0);
+ }
+}
+
+static void __merge_concat(const struct optimize_ctx *ctx, uint32_t i,
+ const struct merge *merge, struct list_head *concat_list)
+{
+ struct expr *concat, *next, *expr, *concat_clone, *clone;
+ LIST_HEAD(pending_list);
+ struct stmt *stmt_a;
+ uint32_t k;
+
+ concat = concat_expr_alloc(&internal_location);
+ list_add(&concat->list, concat_list);
+
+ for (k = 0; k < merge->num_stmts; k++) {
+ list_for_each_entry_safe(concat, next, concat_list, list) {
+ stmt_a = ctx->stmt_matrix[i][merge->stmt[k]];
+ switch (stmt_a->expr->right->etype) {
+ case EXPR_SET:
+ list_for_each_entry(expr, &stmt_a->expr->right->expressions, list) {
+ concat_clone = expr_clone(concat);
+ clone = expr_clone(expr->key);
+ compound_expr_add(concat_clone, clone);
+ list_add_tail(&concat_clone->list, &pending_list);
+ }
+ list_del(&concat->list);
+ expr_free(concat);
+ break;
+ case EXPR_SYMBOL:
+ case EXPR_VALUE:
+ case EXPR_PREFIX:
+ case EXPR_RANGE:
+ clone = expr_clone(stmt_a->expr->right);
+ compound_expr_add(concat, clone);
+ break;
+ default:
+ assert(0);
+ break;
+ }
+ }
+ list_splice_init(&pending_list, concat_list);
+ }
+}
+
+static void __merge_concat_stmts(const struct optimize_ctx *ctx, uint32_t i,
+ const struct merge *merge, struct expr *set)
+{
+ struct expr *concat, *next, *elem;
+ LIST_HEAD(concat_list);
+
+ __merge_concat(ctx, i, merge, &concat_list);
+
+ list_for_each_entry_safe(concat, next, &concat_list, list) {
+ list_del(&concat->list);
+ elem = set_elem_expr_alloc(&internal_location, concat);
+ compound_expr_add(set, elem);
+ }
+}
+
+static void merge_concat_stmts(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to,
+ const struct merge *merge)
+{
+ struct stmt *stmt, *stmt_a;
+ struct expr *concat, *set;
+ uint32_t i, k;
+
+ stmt = ctx->stmt_matrix[from][merge->stmt[0]];
+ /* build concatenation of selectors, eg. ifname . ip daddr . tcp dport */
+ concat = concat_expr_alloc(&internal_location);
+
+ for (k = 0; k < merge->num_stmts; k++) {
+ stmt_a = ctx->stmt_matrix[from][merge->stmt[k]];
+ compound_expr_add(concat, expr_get(stmt_a->expr->left));
+ }
+ expr_free(stmt->expr->left);
+ stmt->expr->left = concat;
+
+ /* build set data contenation, eg. { eth0 . 1.1.1.1 . 22 } */
+ set = set_expr_alloc(&internal_location, NULL);
+ set->set_flags |= NFT_SET_ANONYMOUS;
+
+ for (i = from; i <= to; i++)
+ __merge_concat_stmts(ctx, i, merge, set);
+
+ expr_free(stmt->expr->right);
+ stmt->expr->right = set;
+
+ for (k = 1; k < merge->num_stmts; k++) {
+ stmt_a = ctx->stmt_matrix[from][merge->stmt[k]];
+ list_del(&stmt_a->list);
+ stmt_free(stmt_a);
+ }
+}
+
+static void build_verdict_map(struct expr *expr, struct stmt *verdict,
+ struct expr *set, struct stmt *counter)
+{
+ struct expr *item, *elem, *mapping;
+
+ switch (expr->etype) {
+ case EXPR_LIST:
+ list_for_each_entry(item, &expr->expressions, list) {
+ elem = set_elem_expr_alloc(&internal_location, expr_get(item));
+ if (counter)
+ list_add_tail(&counter->list, &elem->stmt_list);
+
+ mapping = mapping_expr_alloc(&internal_location, elem,
+ expr_get(verdict->expr));
+ compound_expr_add(set, mapping);
+ }
+ break;
+ case EXPR_SET:
+ list_for_each_entry(item, &expr->expressions, list) {
+ elem = set_elem_expr_alloc(&internal_location, expr_get(item->key));
+ if (counter)
+ list_add_tail(&counter->list, &elem->stmt_list);
+
+ mapping = mapping_expr_alloc(&internal_location, elem,
+ expr_get(verdict->expr));
+ compound_expr_add(set, mapping);
+ }
+ break;
+ case EXPR_PREFIX:
+ case EXPR_RANGE:
+ case EXPR_VALUE:
+ case EXPR_SYMBOL:
+ case EXPR_CONCAT:
+ elem = set_elem_expr_alloc(&internal_location, expr_get(expr));
+ if (counter)
+ list_add_tail(&counter->list, &elem->stmt_list);
+
+ mapping = mapping_expr_alloc(&internal_location, elem,
+ expr_get(verdict->expr));
+ compound_expr_add(set, mapping);
+ break;
+ default:
+ assert(0);
+ break;
+ }
+}
+
+static void remove_counter(const struct optimize_ctx *ctx, uint32_t from)
+{
+ struct stmt *stmt;
+ uint32_t i;
+
+ /* remove counter statement */
+ for (i = 0; i < ctx->num_stmts; i++) {
+ stmt = ctx->stmt_matrix[from][i];
+ if (!stmt)
+ continue;
+
+ if (stmt->ops->type == STMT_COUNTER) {
+ list_del(&stmt->list);
+ stmt_free(stmt);
+ }
+ }
+}
+
+static struct stmt *zap_counter(const struct optimize_ctx *ctx, uint32_t from)
+{
+ struct stmt *stmt;
+ uint32_t i;
+
+ /* remove counter statement */
+ for (i = 0; i < ctx->num_stmts; i++) {
+ stmt = ctx->stmt_matrix[from][i];
+ if (!stmt)
+ continue;
+
+ if (stmt->ops->type == STMT_COUNTER) {
+ list_del(&stmt->list);
+ return stmt;
+ }
+ }
+
+ return NULL;
+}
+
+static void merge_stmts_vmap(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to,
+ const struct merge *merge)
+{
+ struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]];
+ struct stmt *stmt_b, *verdict_a, *verdict_b, *stmt;
+ struct expr *expr_a, *expr_b, *expr, *left, *set;
+ struct stmt *counter;
+ uint32_t i;
+ int k;
+
+ k = stmt_verdict_find(ctx);
+ assert(k >= 0);
+
+ set = set_expr_alloc(&internal_location, NULL);
+ set->set_flags |= NFT_SET_ANONYMOUS;
+
+ expr_a = stmt_a->expr->right;
+ verdict_a = ctx->stmt_matrix[from][k];
+ counter = zap_counter(ctx, from);
+ build_verdict_map(expr_a, verdict_a, set, counter);
+
+ for (i = from + 1; i <= to; i++) {
+ stmt_b = ctx->stmt_matrix[i][merge->stmt[0]];
+ expr_b = stmt_b->expr->right;
+ verdict_b = ctx->stmt_matrix[i][k];
+ counter = zap_counter(ctx, i);
+ build_verdict_map(expr_b, verdict_b, set, counter);
+ }
+
+ left = expr_get(stmt_a->expr->left);
+ expr = map_expr_alloc(&internal_location, left, set);
+ stmt = verdict_stmt_alloc(&internal_location, expr);
+
+ list_add(&stmt->list, &stmt_a->list);
+ list_del(&stmt_a->list);
+ stmt_free(stmt_a);
+ list_del(&verdict_a->list);
+ stmt_free(verdict_a);
+}
+
+static void __merge_concat_stmts_vmap(const struct optimize_ctx *ctx,
+ uint32_t i, const struct merge *merge,
+ struct expr *set, struct stmt *verdict)
+{
+ struct expr *concat, *next, *elem, *mapping;
+ LIST_HEAD(concat_list);
+ struct stmt *counter;
+
+ counter = zap_counter(ctx, i);
+ __merge_concat(ctx, i, merge, &concat_list);
+
+ list_for_each_entry_safe(concat, next, &concat_list, list) {
+ list_del(&concat->list);
+ elem = set_elem_expr_alloc(&internal_location, concat);
+ if (counter)
+ list_add_tail(&counter->list, &elem->stmt_list);
+
+ mapping = mapping_expr_alloc(&internal_location, elem,
+ expr_get(verdict->expr));
+ compound_expr_add(set, mapping);
+ }
+}
+
+static void merge_concat_stmts_vmap(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to,
+ const struct merge *merge)
+{
+ struct stmt *orig_stmt = ctx->stmt_matrix[from][merge->stmt[0]];
+ struct stmt *stmt, *stmt_a, *verdict;
+ struct expr *concat_a, *expr, *set;
+ uint32_t i;
+ int k;
+
+ k = stmt_verdict_find(ctx);
+ assert(k >= 0);
+
+ /* build concatenation of selectors, eg. ifname . ip daddr . tcp dport */
+ concat_a = concat_expr_alloc(&internal_location);
+ for (i = 0; i < merge->num_stmts; i++) {
+ stmt_a = ctx->stmt_matrix[from][merge->stmt[i]];
+ compound_expr_add(concat_a, expr_get(stmt_a->expr->left));
+ }
+
+ /* build set data contenation, eg. { eth0 . 1.1.1.1 . 22 : accept } */
+ set = set_expr_alloc(&internal_location, NULL);
+ set->set_flags |= NFT_SET_ANONYMOUS;
+
+ for (i = from; i <= to; i++) {
+ verdict = ctx->stmt_matrix[i][k];
+ __merge_concat_stmts_vmap(ctx, i, merge, set, verdict);
+ }
+
+ expr = map_expr_alloc(&internal_location, concat_a, set);
+ stmt = verdict_stmt_alloc(&internal_location, expr);
+
+ list_add(&stmt->list, &orig_stmt->list);
+ list_del(&orig_stmt->list);
+ stmt_free(orig_stmt);
+
+ for (i = 1; i < merge->num_stmts; i++) {
+ stmt_a = ctx->stmt_matrix[from][merge->stmt[i]];
+ list_del(&stmt_a->list);
+ stmt_free(stmt_a);
+ }
+
+ verdict = ctx->stmt_matrix[from][k];
+ list_del(&verdict->list);
+ stmt_free(verdict);
+}
+
+static bool stmt_verdict_cmp(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to)
+{
+ struct stmt *stmt_a, *stmt_b;
+ uint32_t i;
+ int k;
+
+ k = stmt_verdict_find(ctx);
+ if (k < 0)
+ return true;
+
+ for (i = from; i + 1 <= to; i++) {
+ stmt_a = ctx->stmt_matrix[i][k];
+ stmt_b = ctx->stmt_matrix[i + 1][k];
+ if (!stmt_a && !stmt_b)
+ continue;
+ if (!stmt_a || !stmt_b)
+ return false;
+ if (!stmt_verdict_eq(stmt_a, stmt_b))
+ return false;
+ }
+
+ return true;
+}
+
+static int stmt_nat_type(const struct optimize_ctx *ctx, int from,
+ enum nft_nat_etypes *nat_type)
+{
+ uint32_t j;
+
+ for (j = 0; j < ctx->num_stmts; j++) {
+ if (!ctx->stmt_matrix[from][j])
+ continue;
+
+ if (ctx->stmt_matrix[from][j]->ops->type == STMT_NAT) {
+ *nat_type = ctx->stmt_matrix[from][j]->nat.type;
+ return 0;
+ }
+ }
+
+ return -1;
+}
+
+static int stmt_nat_find(const struct optimize_ctx *ctx, int from)
+{
+ enum nft_nat_etypes nat_type;
+ uint32_t i;
+
+ if (stmt_nat_type(ctx, from, &nat_type) < 0)
+ return -1;
+
+ for (i = 0; i < ctx->num_stmts; i++) {
+ if (ctx->stmt[i]->ops->type != STMT_NAT ||
+ ctx->stmt[i]->nat.type != nat_type)
+ continue;
+
+ return i;
+ }
+
+ return -1;
+}
+
+static struct expr *stmt_nat_expr(struct stmt *nat_stmt)
+{
+ struct expr *nat_expr;
+
+ assert(nat_stmt->ops->type == STMT_NAT);
+
+ if (nat_stmt->nat.proto) {
+ if (nat_stmt->nat.addr) {
+ nat_expr = concat_expr_alloc(&internal_location);
+ compound_expr_add(nat_expr, expr_get(nat_stmt->nat.addr));
+ compound_expr_add(nat_expr, expr_get(nat_stmt->nat.proto));
+ } else {
+ nat_expr = expr_get(nat_stmt->nat.proto);
+ }
+ expr_free(nat_stmt->nat.proto);
+ nat_stmt->nat.proto = NULL;
+ } else {
+ nat_expr = expr_get(nat_stmt->nat.addr);
+ }
+
+ assert(nat_expr);
+
+ return nat_expr;
+}
+
+static void merge_nat(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to,
+ const struct merge *merge)
+{
+ struct expr *expr, *set, *elem, *nat_expr, *mapping, *left;
+ int k, family = NFPROTO_UNSPEC;
+ struct stmt *stmt, *nat_stmt;
+ uint32_t i;
+
+ k = stmt_nat_find(ctx, from);
+ assert(k >= 0);
+
+ set = set_expr_alloc(&internal_location, NULL);
+ set->set_flags |= NFT_SET_ANONYMOUS;
+
+ for (i = from; i <= to; i++) {
+ stmt = ctx->stmt_matrix[i][merge->stmt[0]];
+ expr = stmt->expr->right;
+
+ nat_stmt = ctx->stmt_matrix[i][k];
+ nat_expr = stmt_nat_expr(nat_stmt);
+
+ elem = set_elem_expr_alloc(&internal_location, expr_get(expr));
+ mapping = mapping_expr_alloc(&internal_location, elem, nat_expr);
+ compound_expr_add(set, mapping);
+ }
+
+ stmt = ctx->stmt_matrix[from][merge->stmt[0]];
+ left = expr_get(stmt->expr->left);
+ if (left->etype == EXPR_PAYLOAD) {
+ if (left->payload.desc == &proto_ip)
+ family = NFPROTO_IPV4;
+ else if (left->payload.desc == &proto_ip6)
+ family = NFPROTO_IPV6;
+ }
+ expr = map_expr_alloc(&internal_location, left, set);
+
+ nat_stmt = ctx->stmt_matrix[from][k];
+ if (nat_stmt->nat.family == NFPROTO_UNSPEC)
+ nat_stmt->nat.family = family;
+
+ expr_free(nat_stmt->nat.addr);
+ if (nat_stmt->nat.type == NFT_NAT_REDIR)
+ nat_stmt->nat.proto = expr;
+ else
+ nat_stmt->nat.addr = expr;
+
+ remove_counter(ctx, from);
+ list_del(&stmt->list);
+ stmt_free(stmt);
+}
+
+static void merge_concat_nat(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to,
+ const struct merge *merge)
+{
+ struct expr *expr, *set, *elem, *nat_expr, *mapping, *left, *concat;
+ int k, family = NFPROTO_UNSPEC;
+ struct stmt *stmt, *nat_stmt;
+ uint32_t i, j;
+
+ k = stmt_nat_find(ctx, from);
+ assert(k >= 0);
+
+ set = set_expr_alloc(&internal_location, NULL);
+ set->set_flags |= NFT_SET_ANONYMOUS;
+
+ for (i = from; i <= to; i++) {
+
+ concat = concat_expr_alloc(&internal_location);
+ for (j = 0; j < merge->num_stmts; j++) {
+ stmt = ctx->stmt_matrix[i][merge->stmt[j]];
+ expr = stmt->expr->right;
+ compound_expr_add(concat, expr_get(expr));
+ }
+
+ nat_stmt = ctx->stmt_matrix[i][k];
+ nat_expr = stmt_nat_expr(nat_stmt);
+
+ elem = set_elem_expr_alloc(&internal_location, concat);
+ mapping = mapping_expr_alloc(&internal_location, elem, nat_expr);
+ compound_expr_add(set, mapping);
+ }
+
+ concat = concat_expr_alloc(&internal_location);
+ for (j = 0; j < merge->num_stmts; j++) {
+ stmt = ctx->stmt_matrix[from][merge->stmt[j]];
+ left = stmt->expr->left;
+ if (left->etype == EXPR_PAYLOAD) {
+ if (left->payload.desc == &proto_ip)
+ family = NFPROTO_IPV4;
+ else if (left->payload.desc == &proto_ip6)
+ family = NFPROTO_IPV6;
+ }
+ compound_expr_add(concat, expr_get(left));
+ }
+ expr = map_expr_alloc(&internal_location, concat, set);
+
+ nat_stmt = ctx->stmt_matrix[from][k];
+ if (nat_stmt->nat.family == NFPROTO_UNSPEC)
+ nat_stmt->nat.family = family;
+
+ expr_free(nat_stmt->nat.addr);
+ nat_stmt->nat.addr = expr;
+
+ remove_counter(ctx, from);
+ for (j = 0; j < merge->num_stmts; j++) {
+ stmt = ctx->stmt_matrix[from][merge->stmt[j]];
+ list_del(&stmt->list);
+ stmt_free(stmt);
+ }
+}
+
+static void rule_optimize_print(struct output_ctx *octx,
+ const struct rule *rule)
+{
+ const struct location *loc = &rule->location;
+ const struct input_descriptor *indesc = loc->indesc;
+ const char *line = "";
+ char buf[1024];
+
+ switch (indesc->type) {
+ case INDESC_BUFFER:
+ case INDESC_CLI:
+ line = indesc->data;
+ *strchrnul(line, '\n') = '\0';
+ break;
+ case INDESC_STDIN:
+ line = indesc->data;
+ line += loc->line_offset;
+ *strchrnul(line, '\n') = '\0';
+ break;
+ case INDESC_FILE:
+ line = line_location(indesc, loc, buf, sizeof(buf));
+ break;
+ case INDESC_INTERNAL:
+ case INDESC_NETLINK:
+ break;
+ default:
+ BUG("invalid input descriptor type %u\n", indesc->type);
+ }
+
+ print_location(octx->error_fp, indesc, loc);
+ fprintf(octx->error_fp, "%s\n", line);
+}
+
+enum {
+ MERGE_BY_VERDICT,
+ MERGE_BY_NAT_MAP,
+ MERGE_BY_NAT,
+};
+
+static uint32_t merge_stmt_type(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to)
+{
+ const struct stmt *stmt;
+ uint32_t i, j;
+
+ for (i = from; i <= to; i++) {
+ for (j = 0; j < ctx->num_stmts; j++) {
+ stmt = ctx->stmt_matrix[i][j];
+ if (!stmt)
+ continue;
+ if (stmt->ops->type == STMT_NAT) {
+ if ((stmt->nat.type == NFT_NAT_REDIR &&
+ !stmt->nat.proto) ||
+ stmt->nat.type == NFT_NAT_MASQ)
+ return MERGE_BY_NAT;
+
+ return MERGE_BY_NAT_MAP;
+ }
+ }
+ }
+
+ /* merge by verdict, even if no verdict is specified. */
+ return MERGE_BY_VERDICT;
+}
+
+static void merge_rules(const struct optimize_ctx *ctx,
+ uint32_t from, uint32_t to,
+ const struct merge *merge,
+ struct output_ctx *octx)
+{
+ uint32_t merge_type;
+ bool same_verdict;
+ uint32_t i;
+
+ merge_type = merge_stmt_type(ctx, from, to);
+
+ switch (merge_type) {
+ case MERGE_BY_VERDICT:
+ same_verdict = stmt_verdict_cmp(ctx, from, to);
+ if (merge->num_stmts > 1) {
+ if (same_verdict)
+ merge_concat_stmts(ctx, from, to, merge);
+ else
+ merge_concat_stmts_vmap(ctx, from, to, merge);
+ } else {
+ if (same_verdict)
+ merge_stmts(ctx, from, to, merge);
+ else
+ merge_stmts_vmap(ctx, from, to, merge);
+ }
+ break;
+ case MERGE_BY_NAT_MAP:
+ if (merge->num_stmts > 1)
+ merge_concat_nat(ctx, from, to, merge);
+ else
+ merge_nat(ctx, from, to, merge);
+ break;
+ case MERGE_BY_NAT:
+ if (merge->num_stmts > 1)
+ merge_concat_stmts(ctx, from, to, merge);
+ else
+ merge_stmts(ctx, from, to, merge);
+ break;
+ default:
+ assert(0);
+ }
+
+ if (ctx->rule[from]->comment) {
+ xfree(ctx->rule[from]->comment);
+ ctx->rule[from]->comment = NULL;
+ }
+
+ octx->flags |= NFT_CTX_OUTPUT_STATELESS;
+
+ fprintf(octx->error_fp, "Merging:\n");
+ rule_optimize_print(octx, ctx->rule[from]);
+
+ for (i = from + 1; i <= to; i++) {
+ rule_optimize_print(octx, ctx->rule[i]);
+ list_del(&ctx->rule[i]->list);
+ rule_free(ctx->rule[i]);
+ }
+
+ fprintf(octx->error_fp, "into:\n\t");
+ rule_print(ctx->rule[from], octx);
+ fprintf(octx->error_fp, "\n");
+
+ octx->flags &= ~NFT_CTX_OUTPUT_STATELESS;
+}
+
+static bool stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b)
+{
+ if (!stmt_a && !stmt_b)
+ return true;
+ else if (!stmt_a)
+ return false;
+ else if (!stmt_b)
+ return false;
+
+ return __stmt_type_eq(stmt_a, stmt_b, true);
+}
+
+static bool stmt_is_mergeable(const struct stmt *stmt)
+{
+ if (!stmt)
+ return false;
+
+ switch (stmt->ops->type) {
+ case STMT_VERDICT:
+ if (stmt->expr->etype == EXPR_MAP)
+ return true;
+ break;
+ case STMT_EXPRESSION:
+ case STMT_NAT:
+ return true;
+ default:
+ break;
+ }
+
+ return false;
+}
+
+static bool rules_eq(const struct optimize_ctx *ctx, int i, int j)
+{
+ uint32_t k, mergeable = 0;
+
+ for (k = 0; k < ctx->num_stmts; k++) {
+ if (stmt_is_mergeable(ctx->stmt_matrix[i][k]))
+ mergeable++;
+
+ if (!stmt_type_eq(ctx->stmt_matrix[i][k], ctx->stmt_matrix[j][k]))
+ return false;
+ }
+
+ if (mergeable == 0)
+ return false;
+
+ return true;
+}
+
+static int chain_optimize(struct nft_ctx *nft, struct list_head *rules)
+{
+ struct optimize_ctx *ctx;
+ uint32_t num_merges = 0;
+ struct merge *merge;
+ uint32_t i, j, m, k;
+ struct rule *rule;
+ int ret;
+
+ ctx = xzalloc(sizeof(*ctx));
+
+ /* Step 1: collect statements in rules */
+ list_for_each_entry(rule, rules, list) {
+ ret = rule_collect_stmts(ctx, rule);
+ if (ret < 0)
+ goto err;
+
+ ctx->num_rules++;
+ }
+
+ ctx->rule = xzalloc(sizeof(*ctx->rule) * ctx->num_rules);
+ ctx->stmt_matrix = xzalloc(sizeof(*ctx->stmt_matrix) * ctx->num_rules);
+ for (i = 0; i < ctx->num_rules; i++)
+ ctx->stmt_matrix[i] = xzalloc_array(MAX_STMTS,
+ sizeof(**ctx->stmt_matrix));
+
+ merge = xzalloc(sizeof(*merge) * ctx->num_rules);
+
+ /* Step 2: Build matrix of statements */
+ i = 0;
+ list_for_each_entry(rule, rules, list)
+ rule_build_stmt_matrix_stmts(ctx, rule, &i);
+
+ /* Step 3: Look for common selectors for possible rule mergers */
+ for (i = 0; i < ctx->num_rules; i++) {
+ for (j = i + 1; j < ctx->num_rules; j++) {
+ if (!rules_eq(ctx, i, j)) {
+ if (merge[num_merges].num_rules > 0)
+ num_merges++;
+
+ i = j - 1;
+ break;
+ }
+ if (merge[num_merges].num_rules > 0) {
+ merge[num_merges].num_rules++;
+ } else {
+ merge[num_merges].rule_from = i;
+ merge[num_merges].num_rules = 2;
+ }
+ }
+ if (j == ctx->num_rules && merge[num_merges].num_rules > 0) {
+ num_merges++;
+ break;
+ }
+ }
+
+ /* Step 4: Infer how to merge the candidate rules */
+ for (k = 0; k < num_merges; k++) {
+ i = merge[k].rule_from;
+
+ for (m = 0; m < ctx->num_stmts; m++) {
+ if (!ctx->stmt_matrix[i][m])
+ continue;
+ switch (ctx->stmt_matrix[i][m]->ops->type) {
+ case STMT_EXPRESSION:
+ merge[k].stmt[merge[k].num_stmts++] = m;
+ break;
+ case STMT_VERDICT:
+ if (ctx->stmt_matrix[i][m]->expr->etype == EXPR_MAP)
+ merge[k].stmt[merge[k].num_stmts++] = m;
+ break;
+ default:
+ break;
+ }
+ }
+
+ j = merge[k].num_rules - 1;
+ merge_rules(ctx, i, i + j, &merge[k], &nft->output);
+ }
+ ret = 0;
+ for (i = 0; i < ctx->num_rules; i++)
+ xfree(ctx->stmt_matrix[i]);
+
+ xfree(ctx->stmt_matrix);
+ xfree(merge);
+err:
+ for (i = 0; i < ctx->num_stmts; i++)
+ stmt_free(ctx->stmt[i]);
+
+ xfree(ctx->rule);
+ xfree(ctx);
+
+ return ret;
+}
+
+static int cmd_optimize(struct nft_ctx *nft, struct cmd *cmd)
+{
+ struct table *table;
+ struct chain *chain;
+ int ret = 0;
+
+ switch (cmd->obj) {
+ case CMD_OBJ_TABLE:
+ table = cmd->table;
+ if (!table)
+ break;
+
+ list_for_each_entry(chain, &table->chains, list) {
+ if (chain->flags & CHAIN_F_HW_OFFLOAD)
+ continue;
+
+ chain_optimize(nft, &chain->rules);
+ }
+ break;
+ default:
+ break;
+ }
+
+ return ret;
+}
+
+int nft_optimize(struct nft_ctx *nft, struct list_head *cmds)
+{
+ struct cmd *cmd;
+ int ret = 0;
+
+ list_for_each_entry(cmd, cmds, list) {
+ switch (cmd->op) {
+ case CMD_ADD:
+ ret = cmd_optimize(nft, cmd);
+ break;
+ default:
+ break;
+ }
+ }
+
+ return ret;
+}