diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-15 16:46:17 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-15 16:46:17 +0000 |
commit | 28cc22419e32a65fea2d1678400265b8cabc3aff (patch) | |
tree | ff9ac1991fd48490b21ef6aa9015a347a165e2d9 /sqlglot/executor | |
parent | Initial commit. (diff) | |
download | sqlglot-28cc22419e32a65fea2d1678400265b8cabc3aff.tar.xz sqlglot-28cc22419e32a65fea2d1678400265b8cabc3aff.zip |
Adding upstream version 6.0.4.upstream/6.0.4
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/executor')
-rw-r--r-- | sqlglot/executor/__init__.py | 39 | ||||
-rw-r--r-- | sqlglot/executor/context.py | 68 | ||||
-rw-r--r-- | sqlglot/executor/env.py | 32 | ||||
-rw-r--r-- | sqlglot/executor/python.py | 360 | ||||
-rw-r--r-- | sqlglot/executor/table.py | 81 |
5 files changed, 580 insertions, 0 deletions
import logging
import time

from sqlglot import parse_one
from sqlglot.executor.python import PythonExecutor
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan

logger = logging.getLogger("sqlglot")


def execute(sql, schema, read=None):
    """
    Run a sql query against data.

    Args:
        sql (str): a sql statement
        schema (dict|sqlglot.optimizer.Schema): database schema.
            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
            the following forms:
                1. {table: {col: type}}
                2. {db: {table: {col: type}}}
                3. {catalog: {db: {table: {col: type}}}}
        read (str): the SQL dialect to apply during parsing
            (eg. "spark", "hive", "presto", "mysql").
    Returns:
        sqlglot.executor.Table: Simple columnar data structure.
    """
    expression = parse_one(sql, read=read)

    # Optimize the parsed expression against the schema before planning.
    start = time.time()
    expression = optimize(expression, schema)
    logger.debug("Optimization finished: %f", time.time() - start)
    logger.debug("Optimized SQL: %s", expression.sql(pretty=True))

    # Build a physical plan from the optimized tree and run it.
    plan = Plan(expression)
    logger.debug("Logical Plan: %s", plan)

    start = time.time()
    result = PythonExecutor().execute(plan)
    logger.debug("Query finished: %f", time.time() - start)
    return result
class Context:
    """
    Execution context for sql expressions.

    Context is used to hold relevant data tables which can then be queried on with eval.

    References to columns can either be scalar or vectors. When set_row is used, column references
    evaluate to scalars while set_range evaluates to vectors. This allows convenient and efficient
    evaluation of aggregation functions.
    """

    def __init__(self, tables, env=None):
        """
        Args
            tables (dict): table_name -> Table, representing the scope of the current execution context
            env (Optional[dict]): dictionary of functions within the execution context
        """
        self.tables = tables
        self.range_readers = {}
        self.row_readers = {}

        # Pre-build both reader maps; env["scope"] decides which one column
        # references resolve against (row readers by default).
        for name, table in tables.items():
            self.range_readers[name] = table.range_reader
            self.row_readers[name] = table.reader

        self.env = {**(env or {}), "scope": self.row_readers}

    def eval(self, code):
        # Evaluate compiled python byte code against the sandboxed ENV plus this scope.
        return eval(code, ENV, self.env)

    def eval_tuple(self, codes):
        return tuple(map(self.eval, codes))

    def __iter__(self):
        # Iterating the context iterates its first table.
        return self.table_iter(list(self.tables)[0])

    def table_iter(self, table):
        # Row-by-row iteration uses scalar (row) readers.
        self.env["scope"] = self.row_readers

        for reader in self.tables[table]:
            yield reader, self

    def sort(self, table, key):
        target = self.tables[table]

        def sort_key(row):
            # Point the shared reader at the candidate row so key code can read it.
            target.reader.row = row
            return self.eval_tuple(key)

        target.rows.sort(key=sort_key)

    def set_row(self, table, row):
        self.row_readers[table].row = row
        self.env["scope"] = self.row_readers

    def set_index(self, table, index):
        self.row_readers[table].row = self.tables[table].rows[index]
        self.env["scope"] = self.row_readers

    def set_range(self, table, start, end):
        # Switch to vectorized (range) readers, e.g. while evaluating aggregations.
        self.range_readers[table].range = range(start, end)
        self.env["scope"] = self.range_readers

    def __getitem__(self, table):
        return self.env["scope"][table]

    def __contains__(self, table):
        return table in self.tables
import datetime
import re
import statistics


class reverse_key:
    """Comparison wrapper that inverts ordering; used to implement descending sort keys."""

    def __init__(self, obj):
        self.obj = obj

    def __eq__(self, other):
        return self.obj == other.obj

    def __lt__(self, other):
        # Operands deliberately reversed: self < other iff the wrapped
        # values compare the opposite way, producing a descending order.
        return other.obj < self.obj


# Restricted environment handed to eval'd generated code.
# __builtins__ is emptied so expressions only see what is listed here.
ENV = {
    "__builtins__": {},
    "datetime": datetime,
    "locals": locals,
    "re": re,
    "float": float,
    "int": int,
    "str": str,
    "desc": reverse_key,
    "SUM": sum,
    # statistics.fmean (3.8+) is the fast float mean; fall back on older pythons.
    "AVG": getattr(statistics, "fmean", statistics.mean),
    "COUNT": lambda acc: sum(1 for e in acc if e is not None),
    "MAX": max,
    "MIN": min,
    "POW": pow,
}
class PythonExecutor:
    """Executes a sqlglot Plan over in-memory / csv-backed tables using generated python code."""

    def __init__(self, env=None):
        self.generator = Python().generator(identify=True)
        self.env = {**ENV, **(env or {})}

    def execute(self, plan):
        """Run every plan step bottom-up and return the root step's result table."""
        running = set()
        finished = set()
        queue = set(plan.leaves)
        contexts = {}

        while queue:
            node = queue.pop()

            # Gather the tables produced by this node's dependencies.
            source_tables = {}
            for dep in node.dependencies:
                source_tables.update(contexts[dep].tables)
            context = self.context(source_tables)

            running.add(node)

            if isinstance(node, planner.Scan):
                contexts[node] = self.scan(node, context)
            elif isinstance(node, planner.Aggregate):
                contexts[node] = self.aggregate(node, context)
            elif isinstance(node, planner.Join):
                contexts[node] = self.join(node, context)
            elif isinstance(node, planner.Sort):
                contexts[node] = self.sort(node, context)
            else:
                raise NotImplementedError

            running.remove(node)
            finished.add(node)

            # Schedule dependents whose inputs are all available.
            for dep in node.dependents:
                if dep not in running and all(d in contexts for d in dep.dependencies):
                    queue.add(dep)

            # Drop the context of any dependency all of whose dependents finished.
            for dep in node.dependencies:
                if all(d in finished for d in dep.dependents):
                    contexts.pop(dep)

        root = plan.root
        return contexts[root].tables[root.name]

    def generate(self, expression):
        """Convert a SQL expression into literal Python code and compile it into bytecode."""
        if not expression:
            return None

        sql = self.generator.generate(expression)
        # The generated source doubles as the "filename" shown in tracebacks.
        return compile(sql, sql, "eval", optimize=2)

    def generate_tuple(self, expressions):
        """Compile a sequence of SQL expressions into a tuple of python byte code."""
        if not expressions:
            return tuple()
        return tuple(self.generate(expression) for expression in expressions)

    def context(self, tables):
        return Context(tables, env=self.env)

    def table(self, expressions):
        # Build an empty result table whose columns are the expression aliases.
        return Table(expression.alias_or_name for expression in expressions)

    def scan(self, step, context):
        """Read rows from a source table (or csv), applying condition/projections/limit."""
        if hasattr(step, "source"):
            source = step.source

            if isinstance(source, exp.Expression):
                source = source.this.name or source.alias
        else:
            source = step.name

        condition = self.generate(step.condition)
        projections = self.generate_tuple(step.projections)

        if source in context:
            # Fast path: nothing to compute, just re-alias the existing table.
            if not projections and not condition:
                return self.context({step.name: context.tables[source]})
            rows = context.table_iter(source)
        else:
            rows = self.scan_csv(step)

        if projections:
            sink = self.table(step.projections)
        elif source in context:
            sink = Table(context[source].columns)
        else:
            # Csv columns are only known once the reader yields its first row.
            sink = None

        for reader, ctx in rows:
            if sink is None:
                sink = Table(ctx[source].columns)

            if condition and not ctx.eval(condition):
                continue

            if projections:
                sink.append(ctx.eval_tuple(projections))
            else:
                sink.append(reader.row)

            if len(sink) >= step.limit:
                break

        return self.context({step.name: sink})
    def scan_csv(self, step):
        """Stream rows out of the csv file referenced by ``step.source``.

        Yields (reader, context) pairs like Context.table_iter. Column types
        are inferred from the first data row via ast.literal_eval and then
        applied to every subsequent row.
        """
        source = step.source
        alias = source.alias

        with csv_reader(source.this) as reader:
            # The first csv row is the header.
            columns = next(reader)
            table = Table(columns)
            context = self.context({alias: table})
            types = []

            for row in reader:
                if not types:
                    # Infer one python type per column from the first data row;
                    # anything that doesn't literal_eval stays a string.
                    for v in row:
                        try:
                            types.append(type(ast.literal_eval(v)))
                        except (ValueError, SyntaxError):
                            types.append(str)
                context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
                yield context[alias], context

    def join(self, step, context):
        """Execute each join of a Join step in turn, then apply the step's
        projections/condition via scan."""
        source = step.name

        join_context = self.context({source: context.tables[source]})

        def merge_context(ctx, table):
            # create a new context where all existing tables are mapped to a new one
            return self.context({name: table for name in ctx.tables})

        for name, join in step.joins.items():
            join_context = self.context(
                {**join_context.tables, name: context.tables[name]}
            )

            # Equi-joins (with source/join keys) can use the hash-join strategy;
            # everything else falls back to a nested loop (cross product).
            if join.get("source_key"):
                table = self.hash_join(join, source, name, join_context)
            else:
                table = self.nested_loop_join(join, source, name, join_context)

            join_context = merge_context(join_context, table)

        # apply projections or conditions
        context = self.scan(step, join_context)

        # use the scan context since it returns a single table
        # otherwise there are no projections so all other tables are still in scope
        if step.projections:
            return context

        return merge_context(join_context, context.tables[source])

    def nested_loop_join(self, _join, a, b, context):
        """Cartesian-product join of tables ``a`` and ``b`` (no join keys required)."""
        table = Table(context.tables[a].columns + context.tables[b].columns)

        for reader_a, _ in context.table_iter(a):
            for reader_b, _ in context.table_iter(b):
                table.append(reader_a.row + reader_b.row)

        return table
    def hash_join(self, join, a, b, context):
        """Equi-join ``a`` and ``b`` by bucketing both sides on their key tuples."""
        a_key = self.generate_tuple(join["source_key"])
        b_key = self.generate_tuple(join["join_key"])

        # key tuple -> ([matching rows from a], [matching rows from b])
        results = collections.defaultdict(lambda: ([], []))

        for reader, ctx in context.table_iter(a):
            results[ctx.eval_tuple(a_key)][0].append(reader.row)
        for reader, ctx in context.table_iter(b):
            results[ctx.eval_tuple(b_key)][1].append(reader.row)

        table = Table(context.tables[a].columns + context.tables[b].columns)
        # Emit the cross product of matching buckets (inner-join semantics:
        # buckets with an empty side produce no rows).
        for a_group, b_group in results.values():
            for a_row, b_row in itertools.product(a_group, b_group):
                table.append(a_row + b_row)

        return table

    def sort_merge_join(self, join, a, b, context):
        """Equi-join by sorting both sides on their keys and merging matching runs.

        NOTE(review): not referenced by ``join`` above — appears to be an
        alternative strategy kept for future use; confirm before removing.
        """
        a_key = self.generate_tuple(join["source_key"])
        b_key = self.generate_tuple(join["join_key"])

        context.sort(a, a_key)
        context.sort(b, b_key)

        a_i = 0
        b_i = 0
        a_n = len(context.tables[a])
        b_n = len(context.tables[b])

        table = Table(context.tables[a].columns + context.tables[b].columns)

        def get_key(source, key, i):
            # Position the reader at row i, then evaluate the key code against it.
            context.set_index(source, i)
            return context.eval_tuple(key)

        while a_i < a_n and b_i < b_n:
            key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))

            a_group = []

            # Collect the contiguous run of rows in ``a`` sharing this key.
            while a_i < a_n and key == get_key(a, a_key, a_i):
                a_group.append(context[a].row)
                a_i += 1

            b_group = []

            while b_i < b_n and key == get_key(b, b_key, b_i):
                b_group.append(context[b].row)
                b_i += 1

            for a_row, b_row in itertools.product(a_group, b_group):
                table.append(a_row + b_row)

        return table
context.eval_tuple(aggregations)) + group = key + start = end - 2 + + return self.scan(step, self.context({source: table})) + + def sort(self, step, context): + table = list(context.tables)[0] + key = self.generate_tuple(step.key) + context.sort(table, key) + return self.scan(step, context) + + +def _cast_py(self, expression): + to = expression.args["to"].this + this = self.sql(expression, "this") + + if to == exp.DataType.Type.DATE: + return f"datetime.date.fromisoformat({this})" + if to == exp.DataType.Type.TEXT: + return f"str({this})" + raise NotImplementedError + + +def _column_py(self, expression): + table = self.sql(expression, "table") + this = self.sql(expression, "this") + return f"scope[{table}][{this}]" + + +def _interval_py(self, expression): + this = self.sql(expression, "this") + unit = expression.text("unit").upper() + if unit == "DAY": + return f"datetime.timedelta(days=float({this}))" + raise NotImplementedError + + +def _like_py(self, expression): + this = self.sql(expression, "this") + expression = self.sql(expression, "expression") + return f"""re.match({expression}.replace("_", ".").replace("%", ".*"), {this})""" + + +def _ordered_py(self, expression): + this = self.sql(expression, "this") + desc = expression.args.get("desc") + return f"desc({this})" if desc else this + + +class Python(Dialect): + class Tokenizer(Tokenizer): + ESCAPE = "\\" + + class Generator(Generator): + TRANSFORMS = { + exp.Alias: lambda self, e: self.sql(e.this), + exp.Array: inline_array_sql, + exp.And: lambda self, e: self.binary(e, "and"), + exp.Cast: _cast_py, + exp.Column: _column_py, + exp.EQ: lambda self, e: self.binary(e, "=="), + exp.Interval: _interval_py, + exp.Is: lambda self, e: self.binary(e, "is"), + exp.Like: _like_py, + exp.Not: lambda self, e: f"not {self.sql(e.this)}", + exp.Null: lambda *_: "None", + exp.Or: lambda self, e: self.binary(e, "or"), + exp.Ordered: _ordered_py, + exp.Star: lambda *_: "1", + } + + def case_sql(self, expression): + this = 
class Table:
    """A simple columnar table: a tuple of column names plus a list of row tuples."""

    def __init__(self, *columns, rows=None):
        # Accept either Table("a", "b") or Table(iterable_of_names).
        if isinstance(columns[0], str):
            self.columns = tuple(columns)
        else:
            self.columns = tuple(columns[0])
        self.rows = rows or []
        if rows:
            assert len(rows[0]) == len(self.columns)
        # Shared readers reused for every row / range access.
        self.reader = RowReader(self.columns)
        self.range_reader = RangeReader(self)

    def append(self, row):
        assert len(row) == len(self.columns)
        self.rows.append(row)

    def pop(self):
        self.rows.pop()

    @property
    def width(self):
        return len(self.columns)

    def __len__(self):
        return len(self.rows)

    def __iter__(self):
        return TableIter(self)

    def __getitem__(self, index):
        # Reposition the shared reader instead of materializing a row object.
        self.reader.row = self.rows[index]
        return self.reader

    def __repr__(self):
        # Preview: header plus at most the first 11 rows, each cell padded and
        # truncated to the width of its column header.
        widths = {column: len(column) for column in self.columns}
        lines = [" ".join(self.columns)]

        for i, row in enumerate(self):
            if i > 10:
                break

            cells = (
                str(row[column]).rjust(widths[column])[: widths[column]]
                for column in self.columns
            )
            lines.append(" ".join(cells))

        return "\n".join(lines)


class TableIter:
    """Iterator yielding the table's shared RowReader positioned at each row."""

    def __init__(self, table):
        self.table = table
        self.index = -1

    def __iter__(self):
        return self

    def __next__(self):
        self.index += 1
        if self.index >= len(self.table):
            raise StopIteration
        return self.table[self.index]


class RangeReader:
    """Vectorized reader: a column access yields values over the current row range."""

    def __init__(self, table):
        self.table = table
        self.range = range(0)

    def __len__(self):
        return len(self.range)

    def __getitem__(self, column):
        return (self.table[i][column] for i in self.range)


class RowReader:
    """Scalar reader: a column access resolves against the single current row."""

    def __init__(self, columns):
        self.columns = {column: i for i, column in enumerate(columns)}
        self.row = None

    def __getitem__(self, column):
        return self.row[self.columns[column]]