Edit on GitHub

sqlglot.executor.python

  1import ast
  2import collections
  3import itertools
  4import math
  5
  6from sqlglot import exp, generator, planner, tokens
  7from sqlglot.dialects.dialect import Dialect, inline_array_sql
  8from sqlglot.errors import ExecuteError
  9from sqlglot.executor.context import Context
 10from sqlglot.executor.env import ENV
 11from sqlglot.executor.table import RowReader, Table
 12from sqlglot.helper import csv_reader, subclasses
 13
 14
 15class PythonExecutor:
 16    def __init__(self, env=None, tables=None):
 17        self.generator = Python().generator(identify=True, comments=False)
 18        self.env = {**ENV, **(env or {})}
 19        self.tables = tables or {}
 20
 21    def execute(self, plan):
 22        running = set()
 23        finished = set()
 24        queue = set(plan.leaves)
 25        contexts = {}
 26
 27        while queue:
 28            node = queue.pop()
 29            try:
 30                context = self.context(
 31                    {
 32                        name: table
 33                        for dep in node.dependencies
 34                        for name, table in contexts[dep].tables.items()
 35                    }
 36                )
 37                running.add(node)
 38
 39                if isinstance(node, planner.Scan):
 40                    contexts[node] = self.scan(node, context)
 41                elif isinstance(node, planner.Aggregate):
 42                    contexts[node] = self.aggregate(node, context)
 43                elif isinstance(node, planner.Join):
 44                    contexts[node] = self.join(node, context)
 45                elif isinstance(node, planner.Sort):
 46                    contexts[node] = self.sort(node, context)
 47                elif isinstance(node, planner.SetOperation):
 48                    contexts[node] = self.set_operation(node, context)
 49                else:
 50                    raise NotImplementedError
 51
 52                running.remove(node)
 53                finished.add(node)
 54
 55                for dep in node.dependents:
 56                    if dep not in running and all(d in contexts for d in dep.dependencies):
 57                        queue.add(dep)
 58
 59                for dep in node.dependencies:
 60                    if all(d in finished for d in dep.dependents):
 61                        contexts.pop(dep)
 62            except Exception as e:
 63                raise ExecuteError(f"Step '{node.id}' failed: {e}") from e
 64
 65        root = plan.root
 66        return contexts[root].tables[root.name]
 67
 68    def generate(self, expression):
 69        """Convert a SQL expression into literal Python code and compile it into bytecode."""
 70        if not expression:
 71            return None
 72
 73        sql = self.generator.generate(expression)
 74        return compile(sql, sql, "eval", optimize=2)
 75
 76    def generate_tuple(self, expressions):
 77        """Convert an array of SQL expressions into tuple of Python byte code."""
 78        if not expressions:
 79            return tuple()
 80        return tuple(self.generate(expression) for expression in expressions)
 81
 82    def context(self, tables):
 83        return Context(tables, env=self.env)
 84
 85    def table(self, expressions):
 86        return Table(
 87            expression.alias_or_name if isinstance(expression, exp.Expression) else expression
 88            for expression in expressions
 89        )
 90
 91    def scan(self, step, context):
 92        source = step.source
 93
 94        if source and isinstance(source, exp.Expression):
 95            source = source.name or source.alias
 96
 97        condition = self.generate(step.condition)
 98        projections = self.generate_tuple(step.projections)
 99
100        if source is None:
101            context, table_iter = self.static()
102        elif source in context:
103            if not projections and not condition:
104                return self.context({step.name: context.tables[source]})
105            table_iter = context.table_iter(source)
106        elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
107            table_iter = self.scan_csv(step)
108            context = next(table_iter)
109        else:
110            context, table_iter = self.scan_table(step)
111
112        if projections:
113            sink = self.table(step.projections)
114        else:
115            sink = self.table(context.columns)
116
117        for reader in table_iter:
118            if len(sink) >= step.limit:
119                break
120
121            if condition and not context.eval(condition):
122                continue
123
124            if projections:
125                sink.append(context.eval_tuple(projections))
126            else:
127                sink.append(reader.row)
128
129        return self.context({step.name: sink})
130
131    def static(self):
132        return self.context({}), [RowReader(())]
133
134    def scan_table(self, step):
135        table = self.tables.find(step.source)
136        context = self.context({step.source.alias_or_name: table})
137        return context, iter(table)
138
139    def scan_csv(self, step):
140        alias = step.source.alias
141        source = step.source.this
142
143        with csv_reader(source) as reader:
144            columns = next(reader)
145            table = Table(columns)
146            context = self.context({alias: table})
147            yield context
148            types = []
149
150            for row in reader:
151                if not types:
152                    for v in row:
153                        try:
154                            types.append(type(ast.literal_eval(v)))
155                        except (ValueError, SyntaxError):
156                            types.append(str)
157                context.set_row(tuple(t(v) for t, v in zip(types, row)))
158                yield context.table.reader
159
160    def join(self, step, context):
161        source = step.name
162
163        source_table = context.tables[source]
164        source_context = self.context({source: source_table})
165        column_ranges = {source: range(0, len(source_table.columns))}
166
167        for name, join in step.joins.items():
168            table = context.tables[name]
169            start = max(r.stop for r in column_ranges.values())
170            column_ranges[name] = range(start, len(table.columns) + start)
171            join_context = self.context({name: table})
172
173            if join.get("source_key"):
174                table = self.hash_join(join, source_context, join_context)
175            else:
176                table = self.nested_loop_join(join, source_context, join_context)
177
178            source_context = self.context(
179                {
180                    name: Table(table.columns, table.rows, column_range)
181                    for name, column_range in column_ranges.items()
182                }
183            )
184            condition = self.generate(join["condition"])
185            if condition:
186                source_context.filter(condition)
187
188        condition = self.generate(step.condition)
189        projections = self.generate_tuple(step.projections)
190
191        if not condition and not projections:
192            return source_context
193
194        sink = self.table(step.projections if projections else source_context.columns)
195
196        for reader, ctx in source_context:
197            if condition and not ctx.eval(condition):
198                continue
199
200            if projections:
201                sink.append(ctx.eval_tuple(projections))
202            else:
203                sink.append(reader.row)
204
205            if len(sink) >= step.limit:
206                break
207
208        if projections:
209            return self.context({step.name: sink})
210        else:
211            return self.context(
212                {
213                    name: Table(table.columns, sink.rows, table.column_range)
214                    for name, table in source_context.tables.items()
215                }
216            )
217
218    def nested_loop_join(self, _join, source_context, join_context):
219        table = Table(source_context.columns + join_context.columns)
220
221        for reader_a, _ in source_context:
222            for reader_b, _ in join_context:
223                table.append(reader_a.row + reader_b.row)
224
225        return table
226
227    def hash_join(self, join, source_context, join_context):
228        source_key = self.generate_tuple(join["source_key"])
229        join_key = self.generate_tuple(join["join_key"])
230        left = join.get("side") == "LEFT"
231        right = join.get("side") == "RIGHT"
232
233        results = collections.defaultdict(lambda: ([], []))
234
235        for reader, ctx in source_context:
236            results[ctx.eval_tuple(source_key)][0].append(reader.row)
237        for reader, ctx in join_context:
238            results[ctx.eval_tuple(join_key)][1].append(reader.row)
239
240        table = Table(source_context.columns + join_context.columns)
241        nulls = [(None,) * len(join_context.columns if left else source_context.columns)]
242
243        for a_group, b_group in results.values():
244            if left:
245                b_group = b_group or nulls
246            elif right:
247                a_group = a_group or nulls
248
249            for a_row, b_row in itertools.product(a_group, b_group):
250                table.append(a_row + b_row)
251
252        return table
253
254    def aggregate(self, step, context):
255        group_by = self.generate_tuple(step.group.values())
256        aggregations = self.generate_tuple(step.aggregations)
257        operands = self.generate_tuple(step.operands)
258
259        if operands:
260            operand_table = Table(self.table(step.operands).columns)
261
262            for reader, ctx in context:
263                operand_table.append(ctx.eval_tuple(operands))
264
265            for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)):
266                context.table.rows[i] = a + b
267
268            width = len(context.columns)
269            context.add_columns(*operand_table.columns)
270
271            operand_table = Table(
272                context.columns,
273                context.table.rows,
274                range(width, width + len(operand_table.columns)),
275            )
276
277            context = self.context(
278                {
279                    None: operand_table,
280                    **context.tables,
281                }
282            )
283
284        context.sort(group_by)
285
286        group = None
287        start = 0
288        end = 1
289        length = len(context.table)
290        table = self.table(list(step.group) + step.aggregations)
291        condition = self.generate(step.condition)
292
293        def add_row():
294            if not condition or context.eval(condition):
295                table.append(group + context.eval_tuple(aggregations))
296
297        if length:
298            for i in range(length):
299                context.set_index(i)
300                key = context.eval_tuple(group_by)
301                group = key if group is None else group
302                end += 1
303                if key != group:
304                    context.set_range(start, end - 2)
305                    add_row()
306                    group = key
307                    start = end - 2
308                if len(table.rows) >= step.limit:
309                    break
310                if i == length - 1:
311                    context.set_range(start, end - 1)
312                    add_row()
313        elif step.limit > 0 and not group_by:
314            context.set_range(0, 0)
315            table.append(context.eval_tuple(aggregations))
316
317        context = self.context({step.name: table, **{name: table for name in context.tables}})
318
319        if step.projections:
320            return self.scan(step, context)
321        return context
322
323    def sort(self, step, context):
324        projections = self.generate_tuple(step.projections)
325        projection_columns = [p.alias_or_name for p in step.projections]
326        all_columns = list(context.columns) + projection_columns
327        sink = self.table(all_columns)
328        for reader, ctx in context:
329            sink.append(reader.row + ctx.eval_tuple(projections))
330
331        sort_ctx = self.context(
332            {
333                None: sink,
334                **{table: sink for table in context.tables},
335            }
336        )
337        sort_ctx.sort(self.generate_tuple(step.key))
338
339        if not math.isinf(step.limit):
340            sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit]
341
342        output = Table(
343            projection_columns,
344            rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows],
345        )
346        return self.context({step.name: output})
347
348    def set_operation(self, step, context):
349        left = context.tables[step.left]
350        right = context.tables[step.right]
351
352        sink = self.table(left.columns)
353
354        if issubclass(step.op, exp.Intersect):
355            sink.rows = list(set(left.rows).intersection(set(right.rows)))
356        elif issubclass(step.op, exp.Except):
357            sink.rows = list(set(left.rows).difference(set(right.rows)))
358        elif issubclass(step.op, exp.Union) and step.distinct:
359            sink.rows = list(set(left.rows).union(set(right.rows)))
360        else:
361            sink.rows = left.rows + right.rows
362
363        return self.context({step.name: sink})
364
365
366def _ordered_py(self, expression):
367    this = self.sql(expression, "this")
368    desc = "True" if expression.args.get("desc") else "False"
369    nulls_first = "True" if expression.args.get("nulls_first") else "False"
370    return f"ORDERED({this}, {desc}, {nulls_first})"
371
372
373def _rename(self, e):
374    try:
375        if "expressions" in e.args:
376            this = self.sql(e, "this")
377            this = f"{this}, " if this else ""
378            return f"{e.key.upper()}({this}{self.expressions(e)})"
379        return self.func(e.key, *e.args.values())
380    except Exception as ex:
381        raise Exception(f"Could not rename {repr(e)}") from ex
382
383
384def _case_sql(self, expression):
385    this = self.sql(expression, "this")
386    chain = self.sql(expression, "default") or "None"
387
388    for e in reversed(expression.args["ifs"]):
389        true = self.sql(e, "true")
390        condition = self.sql(e, "this")
391        condition = f"{this} = ({condition})" if this else condition
392        chain = f"{true} if {condition} else ({chain})"
393
394    return chain
395
396
def _lambda_sql(self, e: exp.Lambda) -> str:
    """Render a SQL lambda as a Python ``lambda`` expression.

    Identifiers in the body that match a parameter name (case-insensitively) are
    rewritten to ``exp.Var`` nodes. They then render as bare Python names instead of
    column lookups.
    """
    param_names = {param.name.lower() for param in e.expressions}

    def _as_var(node):
        if isinstance(node, exp.Identifier) and node.name.lower() in param_names:
            return exp.Var(this=node.name)
        return node

    rewritten = e.transform(_as_var)
    params = self.expressions(rewritten, flat=True)
    body = self.sql(rewritten, "this")
    return f"lambda {params}: {body}"
407
408
def _in_sql(self, e):
    """Render ``x IN (a, b, ...)`` as a Python membership test against a tuple literal.

    The trailing comma ensures a single candidate still forms a tuple. Without it,
    ``x in (1)`` tests membership in the int ``1`` and raises ``TypeError``, whereas
    ``x in (1,)`` is the intended test.
    """
    this = self.sql(e, "this")
    candidates = self.expressions(e, flat=True)
    return f"{this} in ({candidates},)" if candidates else f"{this} in ()"


class Python(Dialect):
    """Dialect whose generator emits Python source for ``PythonExecutor`` to evaluate."""

    class Tokenizer(tokens.Tokenizer):
        STRING_ESCAPES = ["\\"]

    class Generator(generator.Generator):
        # Map expression types to the Python snippet each should render as. Binary
        # operators and functions become upper-cased calls via _rename; those names
        # are presumably supplied by the executor's ENV.
        TRANSFORMS = {
            **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
            **{klass: _rename for klass in exp.ALL_FUNCTIONS},
            exp.Case: _case_sql,
            exp.Alias: lambda self, e: self.sql(e.this),
            exp.Array: inline_array_sql,
            exp.And: lambda self, e: self.binary(e, "and"),
            exp.Between: _rename,
            exp.Boolean: lambda self, e: "True" if e.this else "False",
            exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
            # Columns are looked up in `scope` by table name (None when unqualified).
            exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
            exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
            exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
            exp.In: _in_sql,
            exp.Is: lambda self, e: self.binary(e, "is"),
            exp.Lambda: _lambda_sql,
            # Parenthesize so the negation always applies to the whole operand.
            exp.Not: lambda self, e: f"not ({self.sql(e.this)})",
            exp.Null: lambda *_: "None",
            exp.Or: lambda self, e: self.binary(e, "or"),
            exp.Ordered: _ordered_py,
            exp.Star: lambda *_: "1",
        }
class PythonExecutor:
 16class PythonExecutor:
 17    def __init__(self, env=None, tables=None):
 18        self.generator = Python().generator(identify=True, comments=False)
 19        self.env = {**ENV, **(env or {})}
 20        self.tables = tables or {}
 21
 22    def execute(self, plan):
 23        running = set()
 24        finished = set()
 25        queue = set(plan.leaves)
 26        contexts = {}
 27
 28        while queue:
 29            node = queue.pop()
 30            try:
 31                context = self.context(
 32                    {
 33                        name: table
 34                        for dep in node.dependencies
 35                        for name, table in contexts[dep].tables.items()
 36                    }
 37                )
 38                running.add(node)
 39
 40                if isinstance(node, planner.Scan):
 41                    contexts[node] = self.scan(node, context)
 42                elif isinstance(node, planner.Aggregate):
 43                    contexts[node] = self.aggregate(node, context)
 44                elif isinstance(node, planner.Join):
 45                    contexts[node] = self.join(node, context)
 46                elif isinstance(node, planner.Sort):
 47                    contexts[node] = self.sort(node, context)
 48                elif isinstance(node, planner.SetOperation):
 49                    contexts[node] = self.set_operation(node, context)
 50                else:
 51                    raise NotImplementedError
 52
 53                running.remove(node)
 54                finished.add(node)
 55
 56                for dep in node.dependents:
 57                    if dep not in running and all(d in contexts for d in dep.dependencies):
 58                        queue.add(dep)
 59
 60                for dep in node.dependencies:
 61                    if all(d in finished for d in dep.dependents):
 62                        contexts.pop(dep)
 63            except Exception as e:
 64                raise ExecuteError(f"Step '{node.id}' failed: {e}") from e
 65
 66        root = plan.root
 67        return contexts[root].tables[root.name]
 68
 69    def generate(self, expression):
 70        """Convert a SQL expression into literal Python code and compile it into bytecode."""
 71        if not expression:
 72            return None
 73
 74        sql = self.generator.generate(expression)
 75        return compile(sql, sql, "eval", optimize=2)
 76
 77    def generate_tuple(self, expressions):
 78        """Convert an array of SQL expressions into tuple of Python byte code."""
 79        if not expressions:
 80            return tuple()
 81        return tuple(self.generate(expression) for expression in expressions)
 82
 83    def context(self, tables):
 84        return Context(tables, env=self.env)
 85
 86    def table(self, expressions):
 87        return Table(
 88            expression.alias_or_name if isinstance(expression, exp.Expression) else expression
 89            for expression in expressions
 90        )
 91
 92    def scan(self, step, context):
 93        source = step.source
 94
 95        if source and isinstance(source, exp.Expression):
 96            source = source.name or source.alias
 97
 98        condition = self.generate(step.condition)
 99        projections = self.generate_tuple(step.projections)
100
101        if source is None:
102            context, table_iter = self.static()
103        elif source in context:
104            if not projections and not condition:
105                return self.context({step.name: context.tables[source]})
106            table_iter = context.table_iter(source)
107        elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
108            table_iter = self.scan_csv(step)
109            context = next(table_iter)
110        else:
111            context, table_iter = self.scan_table(step)
112
113        if projections:
114            sink = self.table(step.projections)
115        else:
116            sink = self.table(context.columns)
117
118        for reader in table_iter:
119            if len(sink) >= step.limit:
120                break
121
122            if condition and not context.eval(condition):
123                continue
124
125            if projections:
126                sink.append(context.eval_tuple(projections))
127            else:
128                sink.append(reader.row)
129
130        return self.context({step.name: sink})
131
132    def static(self):
133        return self.context({}), [RowReader(())]
134
135    def scan_table(self, step):
136        table = self.tables.find(step.source)
137        context = self.context({step.source.alias_or_name: table})
138        return context, iter(table)
139
140    def scan_csv(self, step):
141        alias = step.source.alias
142        source = step.source.this
143
144        with csv_reader(source) as reader:
145            columns = next(reader)
146            table = Table(columns)
147            context = self.context({alias: table})
148            yield context
149            types = []
150
151            for row in reader:
152                if not types:
153                    for v in row:
154                        try:
155                            types.append(type(ast.literal_eval(v)))
156                        except (ValueError, SyntaxError):
157                            types.append(str)
158                context.set_row(tuple(t(v) for t, v in zip(types, row)))
159                yield context.table.reader
160
161    def join(self, step, context):
162        source = step.name
163
164        source_table = context.tables[source]
165        source_context = self.context({source: source_table})
166        column_ranges = {source: range(0, len(source_table.columns))}
167
168        for name, join in step.joins.items():
169            table = context.tables[name]
170            start = max(r.stop for r in column_ranges.values())
171            column_ranges[name] = range(start, len(table.columns) + start)
172            join_context = self.context({name: table})
173
174            if join.get("source_key"):
175                table = self.hash_join(join, source_context, join_context)
176            else:
177                table = self.nested_loop_join(join, source_context, join_context)
178
179            source_context = self.context(
180                {
181                    name: Table(table.columns, table.rows, column_range)
182                    for name, column_range in column_ranges.items()
183                }
184            )
185            condition = self.generate(join["condition"])
186            if condition:
187                source_context.filter(condition)
188
189        condition = self.generate(step.condition)
190        projections = self.generate_tuple(step.projections)
191
192        if not condition and not projections:
193            return source_context
194
195        sink = self.table(step.projections if projections else source_context.columns)
196
197        for reader, ctx in source_context:
198            if condition and not ctx.eval(condition):
199                continue
200
201            if projections:
202                sink.append(ctx.eval_tuple(projections))
203            else:
204                sink.append(reader.row)
205
206            if len(sink) >= step.limit:
207                break
208
209        if projections:
210            return self.context({step.name: sink})
211        else:
212            return self.context(
213                {
214                    name: Table(table.columns, sink.rows, table.column_range)
215                    for name, table in source_context.tables.items()
216                }
217            )
218
219    def nested_loop_join(self, _join, source_context, join_context):
220        table = Table(source_context.columns + join_context.columns)
221
222        for reader_a, _ in source_context:
223            for reader_b, _ in join_context:
224                table.append(reader_a.row + reader_b.row)
225
226        return table
227
228    def hash_join(self, join, source_context, join_context):
229        source_key = self.generate_tuple(join["source_key"])
230        join_key = self.generate_tuple(join["join_key"])
231        left = join.get("side") == "LEFT"
232        right = join.get("side") == "RIGHT"
233
234        results = collections.defaultdict(lambda: ([], []))
235
236        for reader, ctx in source_context:
237            results[ctx.eval_tuple(source_key)][0].append(reader.row)
238        for reader, ctx in join_context:
239            results[ctx.eval_tuple(join_key)][1].append(reader.row)
240
241        table = Table(source_context.columns + join_context.columns)
242        nulls = [(None,) * len(join_context.columns if left else source_context.columns)]
243
244        for a_group, b_group in results.values():
245            if left:
246                b_group = b_group or nulls
247            elif right:
248                a_group = a_group or nulls
249
250            for a_row, b_row in itertools.product(a_group, b_group):
251                table.append(a_row + b_row)
252
253        return table
254
255    def aggregate(self, step, context):
256        group_by = self.generate_tuple(step.group.values())
257        aggregations = self.generate_tuple(step.aggregations)
258        operands = self.generate_tuple(step.operands)
259
260        if operands:
261            operand_table = Table(self.table(step.operands).columns)
262
263            for reader, ctx in context:
264                operand_table.append(ctx.eval_tuple(operands))
265
266            for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)):
267                context.table.rows[i] = a + b
268
269            width = len(context.columns)
270            context.add_columns(*operand_table.columns)
271
272            operand_table = Table(
273                context.columns,
274                context.table.rows,
275                range(width, width + len(operand_table.columns)),
276            )
277
278            context = self.context(
279                {
280                    None: operand_table,
281                    **context.tables,
282                }
283            )
284
285        context.sort(group_by)
286
287        group = None
288        start = 0
289        end = 1
290        length = len(context.table)
291        table = self.table(list(step.group) + step.aggregations)
292        condition = self.generate(step.condition)
293
294        def add_row():
295            if not condition or context.eval(condition):
296                table.append(group + context.eval_tuple(aggregations))
297
298        if length:
299            for i in range(length):
300                context.set_index(i)
301                key = context.eval_tuple(group_by)
302                group = key if group is None else group
303                end += 1
304                if key != group:
305                    context.set_range(start, end - 2)
306                    add_row()
307                    group = key
308                    start = end - 2
309                if len(table.rows) >= step.limit:
310                    break
311                if i == length - 1:
312                    context.set_range(start, end - 1)
313                    add_row()
314        elif step.limit > 0 and not group_by:
315            context.set_range(0, 0)
316            table.append(context.eval_tuple(aggregations))
317
318        context = self.context({step.name: table, **{name: table for name in context.tables}})
319
320        if step.projections:
321            return self.scan(step, context)
322        return context
323
324    def sort(self, step, context):
325        projections = self.generate_tuple(step.projections)
326        projection_columns = [p.alias_or_name for p in step.projections]
327        all_columns = list(context.columns) + projection_columns
328        sink = self.table(all_columns)
329        for reader, ctx in context:
330            sink.append(reader.row + ctx.eval_tuple(projections))
331
332        sort_ctx = self.context(
333            {
334                None: sink,
335                **{table: sink for table in context.tables},
336            }
337        )
338        sort_ctx.sort(self.generate_tuple(step.key))
339
340        if not math.isinf(step.limit):
341            sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit]
342
343        output = Table(
344            projection_columns,
345            rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows],
346        )
347        return self.context({step.name: output})
348
349    def set_operation(self, step, context):
350        left = context.tables[step.left]
351        right = context.tables[step.right]
352
353        sink = self.table(left.columns)
354
355        if issubclass(step.op, exp.Intersect):
356            sink.rows = list(set(left.rows).intersection(set(right.rows)))
357        elif issubclass(step.op, exp.Except):
358            sink.rows = list(set(left.rows).difference(set(right.rows)))
359        elif issubclass(step.op, exp.Union) and step.distinct:
360            sink.rows = list(set(left.rows).union(set(right.rows)))
361        else:
362            sink.rows = left.rows + right.rows
363
364        return self.context({step.name: sink})
PythonExecutor(env=None, tables=None)
17    def __init__(self, env=None, tables=None):
18        self.generator = Python().generator(identify=True, comments=False)
19        self.env = {**ENV, **(env or {})}
20        self.tables = tables or {}
def execute(self, plan):
22    def execute(self, plan):
23        running = set()
24        finished = set()
25        queue = set(plan.leaves)
26        contexts = {}
27
28        while queue:
29            node = queue.pop()
30            try:
31                context = self.context(
32                    {
33                        name: table
34                        for dep in node.dependencies
35                        for name, table in contexts[dep].tables.items()
36                    }
37                )
38                running.add(node)
39
40                if isinstance(node, planner.Scan):
41                    contexts[node] = self.scan(node, context)
42                elif isinstance(node, planner.Aggregate):
43                    contexts[node] = self.aggregate(node, context)
44                elif isinstance(node, planner.Join):
45                    contexts[node] = self.join(node, context)
46                elif isinstance(node, planner.Sort):
47                    contexts[node] = self.sort(node, context)
48                elif isinstance(node, planner.SetOperation):
49                    contexts[node] = self.set_operation(node, context)
50                else:
51                    raise NotImplementedError
52
53                running.remove(node)
54                finished.add(node)
55
56                for dep in node.dependents:
57                    if dep not in running and all(d in contexts for d in dep.dependencies):
58                        queue.add(dep)
59
60                for dep in node.dependencies:
61                    if all(d in finished for d in dep.dependents):
62                        contexts.pop(dep)
63            except Exception as e:
64                raise ExecuteError(f"Step '{node.id}' failed: {e}") from e
65
66        root = plan.root
67        return contexts[root].tables[root.name]
def generate(self, expression):
    """Convert a SQL expression into literal Python code and compile it into bytecode.

    A falsy ``expression`` yields ``None``. Otherwise the result is an
    ``eval``-mode code object whose filename is the generated source text.
    """
    if expression:
        python_source = self.generator.generate(expression)
        return compile(python_source, python_source, "eval", optimize=2)
    return None

Convert a SQL expression into literal Python code and compile it into bytecode.

def generate_tuple(self, expressions):
    """Convert an array of SQL expressions into a tuple of compiled Python bytecode.

    Each element goes through ``self.generate``. An empty or ``None`` input
    yields an empty tuple.
    """
    if expressions:
        return tuple(map(self.generate, expressions))
    return ()

Convert an array of SQL expressions into a tuple of Python bytecode.

def context(self, tables):
    """Wrap ``tables`` in a ``Context`` that evaluates code against this executor's env."""
    env = self.env
    return Context(tables, env=env)
def table(self, expressions):
    """Create an empty ``Table`` whose column names are derived from ``expressions``.

    SQL expressions contribute their alias (or name). Any other item, such as a
    plain column-name string, is used unchanged.
    """

    def column_name(item):
        return item.alias_or_name if isinstance(item, exp.Expression) else item

    return Table(map(column_name, expressions))
def scan(self, step, context):
    """Execute a Scan step and return a context holding one table named ``step.name``.

    The source is resolved, in order, as:

    - no source (``self.static()``);
    - a table already present in ``context``;
    - a ``READ_CSV`` table function (``self.scan_csv``);
    - otherwise ``self.scan_table``.

    Rows are filtered by ``step.condition``, projected through
    ``step.projections`` when there are any, and capped at ``step.limit``.
    """
    source = step.source
    if source and isinstance(source, exp.Expression):
        source = source.name or source.alias

    condition = self.generate(step.condition)
    projections = self.generate_tuple(step.projections)

    if source is None:
        context, table_iter = self.static()
    elif source in context:
        # Nothing to filter or project: reuse the existing table under the step's name.
        if not (projections or condition):
            return self.context({step.name: context.tables[source]})
        table_iter = context.table_iter(source)
    elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
        table_iter = self.scan_csv(step)
        # scan_csv yields its context first, then one reader per CSV row.
        context = next(table_iter)
    else:
        context, table_iter = self.scan_table(step)

    sink = self.table(step.projections if projections else context.columns)

    for reader in table_iter:
        if len(sink) >= step.limit:
            break
        if condition and not context.eval(condition):
            continue
        sink.append(context.eval_tuple(projections) if projections else reader.row)

    return self.context({step.name: sink})
def static(self):
    """Return an empty context and a row source containing exactly one empty row.

    ``scan`` uses this when its step has no source.
    """
    context = self.context({})
    return context, [RowReader(())]
def scan_table(self, step):
    """Look up ``step.source`` in ``self.tables`` and return ``(context, row_iterator)``.

    The context exposes the table under the source's alias (or name).
    """
    source = step.source
    table = self.tables.find(source)
    return self.context({source.alias_or_name: table}), iter(table)
def scan_csv(self, step):
    """Stream a ``READ_CSV`` source as a generator.

    The first value yielded is a context wrapping an empty table whose columns come
    from the CSV header. Each later value is that table's row reader, after the
    next CSV row has been loaded with ``context.set_row``.

    Column types are inferred from the first data row with ``ast.literal_eval``.
    Values that are not Python literals, or that make ``literal_eval`` fail in any
    documented way, are typed as ``str``. The inferred types are then applied to
    every row, with two special cases:

    - An empty cell in a non-string column becomes ``None``, because e.g.
      ``int("")`` raises.
    - A ``bool`` column is parsed from its text, because ``bool("False")`` is
      ``True``.
    """
    alias = step.source.alias
    source = step.source.this

    def convert(kind, value):
        # Cast one CSV cell to the type inferred for its column.
        if kind is str:
            return value
        if value == "":
            return None
        if kind is bool:
            # literal_eval only produced bool from the literals True / False.
            return value.strip() == "True"
        return kind(value)

    with csv_reader(source) as reader:
        columns = next(reader)
        table = Table(columns)
        context = self.context({alias: table})
        yield context
        column_types = []

        for row in reader:
            if not column_types:
                for value in row:
                    try:
                        column_types.append(type(ast.literal_eval(value)))
                    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                        column_types.append(str)
            context.set_row(tuple(convert(t, v) for t, v in zip(column_types, row)))
            yield context.table.reader
def join(self, step, context):
    """Execute a Join step.

    Starting from the table named ``step.name``, each entry of ``step.joins`` is
    joined in turn. A hash join is used when the join has a ``source_key``;
    otherwise a nested-loop cross product is used. The result is then filtered by
    that join's ``condition``.

    After every join, each participating table name is exposed as a view over the
    combined rows, restricted to its own column range. Finally
    ``step.condition``, ``step.projections`` and ``step.limit`` are applied.
    """
    source = step.name
    source_table = context.tables[source]
    source_context = self.context({source: source_table})
    column_ranges = {source: range(0, len(source_table.columns))}

    for join_name, join in step.joins.items():
        right_table = context.tables[join_name]
        offset = max(r.stop for r in column_ranges.values())
        column_ranges[join_name] = range(offset, offset + len(right_table.columns))
        join_context = self.context({join_name: right_table})

        if join.get("source_key"):
            joined = self.hash_join(join, source_context, join_context)
        else:
            joined = self.nested_loop_join(join, source_context, join_context)

        # Every table name becomes a view over the combined rows, limited to its columns.
        source_context = self.context(
            {
                name: Table(joined.columns, joined.rows, column_range)
                for name, column_range in column_ranges.items()
            }
        )
        join_condition = self.generate(join["condition"])
        if join_condition:
            source_context.filter(join_condition)

    condition = self.generate(step.condition)
    projections = self.generate_tuple(step.projections)

    if not condition and not projections:
        return source_context

    sink = self.table(step.projections if projections else source_context.columns)

    for reader, ctx in source_context:
        if condition and not ctx.eval(condition):
            continue
        sink.append(ctx.eval_tuple(projections) if projections else reader.row)
        if len(sink) >= step.limit:
            break

    if projections:
        return self.context({step.name: sink})
    # No projections: keep every table as a view over the surviving rows.
    return self.context(
        {
            name: Table(view.columns, sink.rows, view.column_range)
            for name, view in source_context.tables.items()
        }
    )
def nested_loop_join(self, _join, source_context, join_context):
    """Return the cross product of two contexts as a new ``Table``.

    Every source row is paired with every join row, with source columns first.
    ``_join`` is unused, and no filtering happens here.
    """
    product = Table(source_context.columns + join_context.columns)
    for left_reader, _ in source_context:
        for right_reader, _ in join_context:
            product.append(left_reader.row + right_reader.row)
    return product
def hash_join(self, join, source_context, join_context):
    """Equi-join two contexts by hashing their key tuples.

    Source rows are keyed by ``join["source_key"]`` and join rows by
    ``join["join_key"]``; rows with equal keys are combined (source columns
    first). Outer joins pad unmatched rows with ``None``:

    - ``side == "LEFT"``: unmatched source rows are padded on the right.
    - ``side == "RIGHT"``: unmatched join rows are padded on the left.

    Following SQL semantics, a key containing NULL (``None``) never matches
    anything, not even another NULL key. Such rows therefore only show up as
    unmatched rows of an outer join.
    """
    source_key = self.generate_tuple(join["source_key"])
    join_key = self.generate_tuple(join["join_key"])
    left = join.get("side") == "LEFT"
    right = join.get("side") == "RIGHT"

    results = collections.defaultdict(lambda: ([], []))
    # Rows whose key contains NULL can never find a partner.
    unmatched_source = []
    unmatched_join = []

    for reader, ctx in source_context:
        key = ctx.eval_tuple(source_key)
        if any(part is None for part in key):
            unmatched_source.append(reader.row)
        else:
            results[key][0].append(reader.row)
    for reader, ctx in join_context:
        key = ctx.eval_tuple(join_key)
        if any(part is None for part in key):
            unmatched_join.append(reader.row)
        else:
            results[key][1].append(reader.row)

    table = Table(source_context.columns + join_context.columns)
    # Padding for the side that has no match.
    nulls = [(None,) * len(join_context.columns if left else source_context.columns)]

    for a_group, b_group in results.values():
        if left:
            b_group = b_group or nulls
        elif right:
            a_group = a_group or nulls

        for a_row, b_row in itertools.product(a_group, b_group):
            table.append(a_row + b_row)

    if left:
        for a_row in unmatched_source:
            table.append(a_row + nulls[0])
    elif right:
        for b_row in unmatched_join:
            table.append(nulls[0] + b_row)

    return table
def aggregate(self, step, context):
    """Execute an Aggregate step.

    1. Operand expressions (``step.operands``) are evaluated per row and appended
       as extra columns, exposed through an unnamed (``None``) table.
    2. Rows are sorted by the group-by keys and scanned once. Each run of equal
       keys yields one output row of ``group keys + aggregations``, kept only
       when ``step.condition`` (presumably HAVING) holds.
    3. With no input rows, no GROUP BY and a positive ``step.limit``, a single
       row of aggregates over an empty range is produced instead.
    4. If the step has projections, they are applied by running ``scan`` over
       the aggregated table.
    """
    group_by = self.generate_tuple(step.group.values())
    aggregations = self.generate_tuple(step.aggregations)
    operands = self.generate_tuple(step.operands)

    if operands:
        extra = Table(self.table(step.operands).columns)
        for _, ctx in context:
            extra.append(ctx.eval_tuple(operands))

        # Widen every source row in place with its operand values.
        for index, (base, added) in enumerate(zip(context.table.rows, extra.rows)):
            context.table.rows[index] = base + added

        width = len(context.columns)
        context.add_columns(*extra.columns)

        operand_view = Table(
            context.columns,
            context.table.rows,
            range(width, width + len(extra.columns)),
        )
        context = self.context({None: operand_view, **context.tables})

    context.sort(group_by)

    length = len(context.table)
    table = self.table(list(step.group) + step.aggregations)
    condition = self.generate(step.condition)
    group = None
    start = 0

    def add_row():
        # Emit the current group (its range was set by the caller) if the condition passes.
        if not condition or context.eval(condition):
            table.append(group + context.eval_tuple(aggregations))

    if length:
        for i in range(length):
            context.set_index(i)
            key = context.eval_tuple(group_by)
            if group is None:
                group = key
            if key != group:
                # Row i begins a new group: evaluate the previous one over set_range(start, i).
                context.set_range(start, i)
                add_row()
                group = key
                start = i
            if len(table.rows) >= step.limit:
                break
            if i == length - 1:
                # Last row: flush the final group.
                context.set_range(start, i + 1)
                add_row()
    elif step.limit > 0 and not group_by:
        context.set_range(0, 0)
        table.append(context.eval_tuple(aggregations))

    context = self.context({step.name: table, **dict.fromkeys(context.tables, table)})

    if not step.projections:
        return context
    return self.scan(step, context)
def sort(self, step, context):
    """Execute a Sort step.

    Each input row is extended with the step's projections, and the widened rows
    are sorted by ``step.key``. They are truncated to ``step.limit`` when it is
    finite. Only the projected columns are returned, in a table named
    ``step.name``.
    """
    projections = self.generate_tuple(step.projections)
    projection_columns = [p.alias_or_name for p in step.projections]
    base_width = len(context.columns)
    all_columns = list(context.columns) + projection_columns

    widened = self.table(all_columns)
    for reader, ctx in context:
        widened.append(reader.row + ctx.eval_tuple(projections))

    # Expose the widened table unnamed and under every input name, presumably so
    # sort keys can reference columns through any of them.
    sort_ctx = self.context({None: widened, **dict.fromkeys(context.tables, widened)})
    sort_ctx.sort(self.generate_tuple(step.key))

    if not math.isinf(step.limit):
        sort_ctx.table.rows = sort_ctx.table.rows[: step.limit]

    output = Table(
        projection_columns,
        rows=[row[base_width : len(all_columns)] for row in sort_ctx.table.rows],
    )
    return self.context({step.name: output})
def set_operation(self, step, context):
    """Combine the ``step.left`` and ``step.right`` tables of ``context``.

    INTERSECT, EXCEPT and UNION DISTINCT use Python set semantics: duplicates are
    removed and row order is not preserved. Anything else (e.g. UNION ALL)
    concatenates the left rows followed by the right rows. The result keeps the
    left table's column names.
    """
    left = context.tables[step.left]
    right = context.tables[step.right]
    sink = self.table(left.columns)
    op = step.op

    if issubclass(op, exp.Intersect):
        sink.rows = list(set(left.rows) & set(right.rows))
    elif issubclass(op, exp.Except):
        sink.rows = list(set(left.rows) - set(right.rows))
    elif issubclass(op, exp.Union) and step.distinct:
        sink.rows = list(set(left.rows) | set(right.rows))
    else:
        sink.rows = left.rows + right.rows

    return self.context({step.name: sink})
class Python(Dialect):
    # Dialect whose generator emits Python expressions instead of SQL; column
    # references become ``scope[table][column]`` lookups.

    class Tokenizer(tokens.Tokenizer):
        # Backslash is the escape character inside string literals.
        STRING_ESCAPES = ["\\"]

    class Generator(generator.Generator):
        # Order matters: the blanket _rename mappings come first, so the explicit
        # entries below override any of them that target the same classes.
        TRANSFORMS = {
            **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
            **{klass: _rename for klass in exp.ALL_FUNCTIONS},
            exp.Case: _case_sql,
            exp.Alias: lambda self, expression: self.sql(expression.this),
            exp.Array: inline_array_sql,
            exp.And: lambda self, expression: self.binary(expression, "and"),
            exp.Between: _rename,
            exp.Boolean: lambda self, expression: "True" if expression.this else "False",
            exp.Cast: lambda self, expression: (
                f"CAST({self.sql(expression.this)}, exp.DataType.Type.{expression.args['to']})"
            ),
            exp.Column: lambda self, expression: (
                f"scope[{self.sql(expression, 'table') or None}][{self.sql(expression.this)}]"
            ),
            exp.Distinct: lambda self, expression: f"set({self.sql(expression, 'this')})",
            exp.Extract: lambda self, expression: (
                f"EXTRACT('{expression.name.lower()}', {self.sql(expression, 'expression')})"
            ),
            exp.In: lambda self, expression: (
                f"{self.sql(expression, 'this')} in ({self.expressions(expression, flat=True)})"
            ),
            exp.Is: lambda self, expression: self.binary(expression, "is"),
            exp.Lambda: _lambda_sql,
            exp.Not: lambda self, expression: f"not {self.sql(expression.this)}",
            exp.Null: lambda *_: "None",
            exp.Or: lambda self, expression: self.binary(expression, "or"),
            exp.Ordered: _ordered_py,
            exp.Star: lambda *_: "1",
        }
class Tokenizer(tokens.Tokenizer):
    # Nested as Python.Tokenizer. Backslash is the escape character inside
    # string literals.
    STRING_ESCAPES = ["\\"]
class Generator(generator.Generator):
    # Nested as Python.Generator. Order matters: the blanket _rename mappings come
    # first, so the explicit entries below override any of them that target the
    # same classes.
    TRANSFORMS = {
        **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
        **{klass: _rename for klass in exp.ALL_FUNCTIONS},
        exp.Case: _case_sql,
        exp.Alias: lambda self, expression: self.sql(expression.this),
        exp.Array: inline_array_sql,
        exp.And: lambda self, expression: self.binary(expression, "and"),
        exp.Between: _rename,
        exp.Boolean: lambda self, expression: "True" if expression.this else "False",
        exp.Cast: lambda self, expression: (
            f"CAST({self.sql(expression.this)}, exp.DataType.Type.{expression.args['to']})"
        ),
        exp.Column: lambda self, expression: (
            f"scope[{self.sql(expression, 'table') or None}][{self.sql(expression.this)}]"
        ),
        exp.Distinct: lambda self, expression: f"set({self.sql(expression, 'this')})",
        exp.Extract: lambda self, expression: (
            f"EXTRACT('{expression.name.lower()}', {self.sql(expression, 'expression')})"
        ),
        exp.In: lambda self, expression: (
            f"{self.sql(expression, 'this')} in ({self.expressions(expression, flat=True)})"
        ),
        exp.Is: lambda self, expression: self.binary(expression, "is"),
        exp.Lambda: _lambda_sql,
        exp.Not: lambda self, expression: f"not {self.sql(expression.this)}",
        exp.Null: lambda *_: "None",
        exp.Or: lambda self, expression: self.binary(expression, "or"),
        exp.Ordered: _ordered_py,
        exp.Star: lambda *_: "1",
    }

Generator interprets the given syntax tree and produces a SQL string as an output.

Arguments:
  • time_mapping (dict): the dictionary of custom time mappings in which the key represents a python time format and the output the target time format
  • time_trie (trie): a trie of the time_mapping keys
  • pretty (bool): if set to True the returned string will be formatted. Default: False.
  • quote_start (str): specifies which starting character to use to delimit quotes. Default: '.
  • quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
  • identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
  • identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
  • identify (bool): if set to True all identifiers will be delimited by the corresponding character.
  • normalize (bool): if set to True, all identifiers will be lowercased.
  • string_escape (str): specifies a string escape character. Default: '.
  • identifier_escape (str): specifies an identifier escape character. Default: ".
  • pad (int): determines padding in a formatted string. Default: 2.
  • indent (int): determines the size of indentation in a formatted string. Default: 4.
  • unnest_column_only (bool): if True, UNNEST table aliases are treated only as column aliases.
  • normalize_functions (str): how to normalize function names: "upper", "lower", or None. Default: "upper".
  • alias_post_tablesample (bool): whether the table alias comes after TABLESAMPLE. Default: False.
  • unsupported_level (ErrorLevel): determines the generator's behavior when it encounters unsupported expressions. Default ErrorLevel.WARN.
  • null_ordering (str): Indicates the default null ordering method to use if not explicitly set. Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". Default: "nulls_are_small"
  • max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError. This is only relevant if unsupported_level is ErrorLevel.RAISE. Default: 3
  • leading_comma (bool): whether commas are leading (rather than trailing) in SELECT statements. Default: False.
  • max_text_width: The max number of characters in a segment before creating new lines in pretty mode. The default is on the smaller end because the length only represents a segment and not the true line length. Default: 80
  • comments: Whether or not to preserve comments in the output SQL code. Default: True
Inherited Members
sqlglot.generator.Generator
Generator
generate
unsupported
sep
seg
pad_comment
maybe_comment
wrap
no_identify
normalize_func
indent
sql
uncache_sql
cache_sql
characterset_sql
column_sql
columndef_sql
columnconstraint_sql
autoincrementcolumnconstraint_sql
compresscolumnconstraint_sql
generatedasidentitycolumnconstraint_sql
notnullcolumnconstraint_sql
primarykeycolumnconstraint_sql
uniquecolumnconstraint_sql
create_sql
describe_sql
prepend_ctes
with_sql
cte_sql
tablealias_sql
bitstring_sql
hexstring_sql
datatype_sql
directory_sql
delete_sql
drop_sql
except_sql
except_op
fetch_sql
filter_sql
hint_sql
index_sql
identifier_sql
national_sql
partition_sql
properties_sql
root_properties
properties
with_properties
locate_properties
property_sql
likeproperty_sql
fallbackproperty_sql
journalproperty_sql
freespaceproperty_sql
afterjournalproperty_sql
checksumproperty_sql
mergeblockratioproperty_sql
datablocksizeproperty_sql
blockcompressionproperty_sql
isolatedloadingproperty_sql
lockingproperty_sql
withdataproperty_sql
insert_sql
intersect_sql
intersect_op
introducer_sql
pseudotype_sql
rowformatdelimitedproperty_sql
table_sql
tablesample_sql
pivot_sql
tuple_sql
update_sql
values_sql
var_sql
into_sql
from_sql
group_sql
having_sql
join_sql
lambda_sql
lateral_sql
limit_sql
offset_sql
lock_sql
literal_sql
loaddata_sql
null_sql
boolean_sql
order_sql
cluster_sql
distribute_sql
sort_sql
ordered_sql
matchrecognize_sql
query_modifiers
select_sql
schema_sql
star_sql
structkwarg_sql
parameter_sql
sessionparameter_sql
placeholder_sql
subquery_sql
qualify_sql
union_sql
union_op
unnest_sql
where_sql
window_sql
partition_by_sql
window_spec_sql
withingroup_sql
between_sql
bracket_sql
all_sql
any_sql
exists_sql
case_sql
constraint_sql
extract_sql
trim_sql
concat_sql
check_sql
foreignkey_sql
primarykey_sql
unique_sql
if_sql
in_sql
in_unnest_op
interval_sql
return_sql
reference_sql
anonymous_sql
paren_sql
neg_sql
not_sql
alias_sql
aliases_sql
attimezone_sql
add_sql
and_sql
connector_sql
bitwiseand_sql
bitwiseleftshift_sql
bitwisenot_sql
bitwiseor_sql
bitwiserightshift_sql
bitwisexor_sql
cast_sql
currentdate_sql
collate_sql
command_sql
comment_sql
transaction_sql
commit_sql
rollback_sql
altercolumn_sql
renametable_sql
altertable_sql
droppartition_sql
addconstraint_sql
distinct_sql
ignorenulls_sql
respectnulls_sql
intdiv_sql
dpipe_sql
div_sql
overlaps_sql
distance_sql
dot_sql
eq_sql
escape_sql
glob_sql
gt_sql
gte_sql
ilike_sql
is_sql
like_sql
similarto_sql
lt_sql
lte_sql
mod_sql
mul_sql
neq_sql
nullsafeeq_sql
nullsafeneq_sql
or_sql
slice_sql
sub_sql
trycast_sql
use_sql
binary
function_fallback_sql
func
format_args
text_width
format_time
expressions
op_expressions
naked_property
set_operation
tag_sql
token_sql
userdefinedfunction_sql
joinhint_sql
kwarg_sql
when_sql
merge_sql