summary | refs | log | tree | commit | diff | stats
path: root/benchmarks
diff options
context:
space:
mode:
Diffstat (limited to 'benchmarks')
-rw-r--r-- benchmarks/__init__.py 0
-rw-r--r-- benchmarks/bench.py 29
-rw-r--r-- benchmarks/helpers.py 28
-rw-r--r-- benchmarks/optimize.py 70
4 files changed, 107 insertions, 20 deletions
diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/benchmarks/__init__.py
diff --git a/benchmarks/bench.py b/benchmarks/bench.py
index 917266d..51872db 100644
--- a/benchmarks/bench.py
+++ b/benchmarks/bench.py
@@ -1,5 +1,7 @@
import collections.abc
+from benchmarks.helpers import ascii_table
+
# moz_sql_parser 3.10 compatibility
collections.Iterable = collections.abc.Iterable
import timeit
@@ -188,11 +190,6 @@ def sqlfluff_parse(sql):
sqlfluff.parse(sql)
-def border(columns):
- columns = " | ".join(columns)
- return f"| {columns} |"
-
-
def diff(row, column):
if column == "Query":
return ""
@@ -223,19 +220,11 @@ for name, sql in {"tpch": tpch, "short": short, "long": long, "crazy": crazy}.it
print(e)
row[lib] = "error"
-columns = ["Query"] + libs
-widths = {column: max(len(column), 15) for column in columns}
-
-lines = [border(column.rjust(width) for column, width in widths.items())]
-lines.append(border(str("-" * width) for width in widths.values()))
-
-for i, row in enumerate(table):
- lines.append(
- border(
- (str(row[column])[0:7] + diff(row, column)).rjust(width)[0:width]
- for column, width in widths.items()
- )
+print(
+ ascii_table(
+ [
+ # Pass the "Query" label through untouched; every other cell is the value truncated
+ # to 7 characters followed by whatever diff() returns for that column. The check is
+ # on the key: comparing the value to "Query" would never match a query name.
+ {k: v if k == "Query" else str(row[k])[0:7] + diff(row, k) for k, v in row.items()}
+ for row in table
+ ]
)
-
-for line in lines:
- print(line)
+)
diff --git a/benchmarks/helpers.py b/benchmarks/helpers.py
new file mode 100644
index 0000000..bfb9821
--- /dev/null
+++ b/benchmarks/helpers.py
@@ -0,0 +1,28 @@
+import typing as t
+
+
+def border(columns: t.Iterable[str]) -> str:
+    """Render one table line: cells separated by " | " and wrapped in outer pipes."""
+    return "| " + " | ".join(columns) + " |"
+
+
+def ascii_table(table: list[dict[str, t.Any]]) -> str:
+    """
+    Render a list of row dicts as a plain-text table.
+
+    Columns are the union of all row keys, in first-seen order. Each column is at least
+    15 characters wide (wider if its header is longer); cells are right-aligned and cut
+    off at the column width.
+
+    Args:
+        table: The rows to render, each mapping a column name to a cell value. Values are
+            converted with `str`. A row that lacks a column gets an empty cell.
+
+    Returns:
+        The header line, a dashed separator line and one line per row, joined by newlines.
+    """
+    # dict.fromkeys de-duplicates the keys while keeping first-seen order.
+    columns = dict.fromkeys(key for row in table for key in row)
+
+    widths = {column: max(len(column), 15) for column in columns}
+
+    lines = [
+        border(column.rjust(width) for column, width in widths.items()),
+        border("-" * width for width in widths.values()),
+    ]
+
+    for row in table:
+        lines.append(
+            border(
+                # Rows need not share the same keys, so a missing cell renders as blank
+                # instead of raising KeyError.
+                str(row.get(column, "")).rjust(width)[0:width]
+                for column, width in widths.items()
+            )
+        )
+
+    return "\n".join(lines)
diff --git a/benchmarks/optimize.py b/benchmarks/optimize.py
new file mode 100644
index 0000000..95a6821
--- /dev/null
+++ b/benchmarks/optimize.py
@@ -0,0 +1,70 @@
+import typing as t
+from argparse import ArgumentParser
+
+from benchmarks.helpers import ascii_table
+from sqlglot.optimizer import optimize
+from sqlglot import parse_one
+from tests.helpers import load_sql_fixture_pairs, TPCH_SCHEMA, TPCDS_SCHEMA
+from timeit import Timer
+import sys
+
+# Deeply nested conditions (such as the long OR chains built by gen_condition below)
+# currently require a lot of recursion, so raise the interpreter's recursion limit.
+sys.setrecursionlimit(10000)
+
+
+def gen_condition(n):
+    """Parse a condition made of `n` terms of the form `a = i AND b = i`, chained with OR."""
+    terms = [f"a = {i} AND b = {i}" for i in range(n)]
+    return parse_one(" OR ".join(terms))
+
+
+# Maps a benchmark name to a zero-argument factory. Each factory returns a tuple of
+# (expressions to optimize, schema handed to the optimizer, number of timing repeats).
+# The factories are only called for the benchmarks that are actually run.
+BENCHMARKS = {
+    "tpch": lambda: (
+        [parse_one(sql) for _, sql, _ in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")],
+        TPCH_SCHEMA,
+        3,
+    ),
+    "tpcds": lambda: (
+        [parse_one(sql) for _, sql, _ in load_sql_fixture_pairs("optimizer/tpc-ds/tpc-ds.sql")],
+        TPCDS_SCHEMA,
+        3,
+    ),
+    "condition_10": lambda: (
+        [gen_condition(10)],
+        {},
+        10,
+    ),
+    "condition_100": lambda: (
+        [gen_condition(100)],
+        {},
+        10,
+    ),
+    "condition_1000": lambda: (
+        [gen_condition(1000)],
+        {},
+        3,
+    ),
+}
+
+
+def bench() -> list[dict[str, t.Any]]:
+    """
+    Time the optimizer on the benchmarks selected on the command line.
+
+    `-b/--benchmark` may be given several times to pick entries of BENCHMARKS; when it
+    is omitted, every benchmark runs.
+
+    Returns:
+        One row per benchmark with its name under "Benchmark" and the fastest of its
+        repeated runs, rounded to 4 decimal places, under "Duration (s)".
+    """
+    arg_parser = ArgumentParser()
+    arg_parser.add_argument("-b", "--benchmark", choices=BENCHMARKS, action="append")
+    selected = list(arg_parser.parse_args().benchmark or BENCHMARKS)
+
+    results = []
+    for name in selected:
+        expressions, schema, repeats = BENCHMARKS[name]()
+
+        def optimize_all():
+            for expression in expressions:
+                optimize(expression, schema)
+
+        # Each repeat runs the whole expression list exactly once; keep the best time.
+        durations = Timer(optimize_all).repeat(repeat=repeats, number=1)
+        results.append({"Benchmark": name, "Duration (s)": round(min(durations), 4)})
+
+    return results
+
+
+# Script entry point: run the benchmarks chosen on the command line and print the result table.
+if __name__ == "__main__":
+ print(ascii_table(bench()))