From 28cc22419e32a65fea2d1678400265b8cabc3aff Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 15 Sep 2022 18:46:17 +0200 Subject: Adding upstream version 6.0.4. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 96 + sqlglot/__main__.py | 69 + sqlglot/dialects/__init__.py | 15 + sqlglot/dialects/bigquery.py | 128 + sqlglot/dialects/clickhouse.py | 48 + sqlglot/dialects/dialect.py | 268 +++ sqlglot/dialects/duckdb.py | 156 ++ sqlglot/dialects/hive.py | 304 +++ sqlglot/dialects/mysql.py | 163 ++ sqlglot/dialects/oracle.py | 63 + sqlglot/dialects/postgres.py | 109 + sqlglot/dialects/presto.py | 216 ++ sqlglot/dialects/snowflake.py | 145 ++ sqlglot/dialects/spark.py | 106 + sqlglot/dialects/sqlite.py | 63 + sqlglot/dialects/starrocks.py | 12 + sqlglot/dialects/tableau.py | 37 + sqlglot/dialects/trino.py | 10 + sqlglot/diff.py | 314 +++ sqlglot/errors.py | 38 + sqlglot/executor/__init__.py | 39 + sqlglot/executor/context.py | 68 + sqlglot/executor/env.py | 32 + sqlglot/executor/python.py | 360 +++ sqlglot/executor/table.py | 81 + sqlglot/expressions.py | 2945 +++++++++++++++++++++++ sqlglot/generator.py | 1124 +++++++++ sqlglot/helper.py | 123 + sqlglot/optimizer/__init__.py | 2 + sqlglot/optimizer/eliminate_subqueries.py | 48 + sqlglot/optimizer/expand_multi_table_selects.py | 16 + sqlglot/optimizer/isolate_table_selects.py | 31 + sqlglot/optimizer/normalize.py | 136 ++ sqlglot/optimizer/optimize_joins.py | 75 + sqlglot/optimizer/optimizer.py | 43 + sqlglot/optimizer/pushdown_predicates.py | 176 ++ sqlglot/optimizer/pushdown_projections.py | 85 + sqlglot/optimizer/qualify_columns.py | 422 ++++ sqlglot/optimizer/qualify_tables.py | 54 + sqlglot/optimizer/quote_identities.py | 25 + sqlglot/optimizer/schema.py | 129 + sqlglot/optimizer/scope.py | 438 ++++ sqlglot/optimizer/simplify.py | 383 +++ sqlglot/optimizer/unnest_subqueries.py | 220 ++ sqlglot/parser.py | 2190 +++++++++++++++++ sqlglot/planner.py | 340 +++ sqlglot/time.py | 45 + sqlglot/tokens.py | 853 +++++++ sqlglot/transforms.py | 68 + sqlglot/trie.py | 27 + 50 files changed, 12938 insertions(+) create mode 100644 sqlglot/__init__.py create mode 100644 sqlglot/__main__.py create mode 100644 sqlglot/dialects/__init__.py create mode 100644 sqlglot/dialects/bigquery.py create mode 100644 sqlglot/dialects/clickhouse.py create mode 100644 sqlglot/dialects/dialect.py create mode 100644 sqlglot/dialects/duckdb.py create mode 100644 sqlglot/dialects/hive.py create mode 100644 sqlglot/dialects/mysql.py create mode 100644 sqlglot/dialects/oracle.py create mode 100644 sqlglot/dialects/postgres.py create mode 100644 sqlglot/dialects/presto.py create mode 100644 sqlglot/dialects/snowflake.py create mode 100644 sqlglot/dialects/spark.py create mode 100644 sqlglot/dialects/sqlite.py create mode 100644 sqlglot/dialects/starrocks.py create mode 100644 sqlglot/dialects/tableau.py create mode 100644 sqlglot/dialects/trino.py create mode 100644 sqlglot/diff.py create mode 100644 sqlglot/errors.py create mode 100644 sqlglot/executor/__init__.py create mode 100644 sqlglot/executor/context.py create mode 100644 sqlglot/executor/env.py create mode 100644 sqlglot/executor/python.py create mode 100644 sqlglot/executor/table.py create mode 100644 sqlglot/expressions.py create mode 100644 sqlglot/generator.py create mode 100644 sqlglot/helper.py create mode 100644 sqlglot/optimizer/__init__.py create mode 100644 sqlglot/optimizer/eliminate_subqueries.py create mode 100644 sqlglot/optimizer/expand_multi_table_selects.py create mode 100644 
sqlglot/optimizer/isolate_table_selects.py create mode 100644 sqlglot/optimizer/normalize.py create mode 100644 sqlglot/optimizer/optimize_joins.py create mode 100644 sqlglot/optimizer/optimizer.py create mode 100644 sqlglot/optimizer/pushdown_predicates.py create mode 100644 sqlglot/optimizer/pushdown_projections.py create mode 100644 sqlglot/optimizer/qualify_columns.py create mode 100644 sqlglot/optimizer/qualify_tables.py create mode 100644 sqlglot/optimizer/quote_identities.py create mode 100644 sqlglot/optimizer/schema.py create mode 100644 sqlglot/optimizer/scope.py create mode 100644 sqlglot/optimizer/simplify.py create mode 100644 sqlglot/optimizer/unnest_subqueries.py create mode 100644 sqlglot/parser.py create mode 100644 sqlglot/planner.py create mode 100644 sqlglot/time.py create mode 100644 sqlglot/tokens.py create mode 100644 sqlglot/transforms.py create mode 100644 sqlglot/trie.py diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py new file mode 100644 index 0000000..0007e34 --- /dev/null +++ b/sqlglot/__init__.py @@ -0,0 +1,96 @@ +from sqlglot import expressions as exp +from sqlglot.dialects import Dialect, Dialects +from sqlglot.diff import diff +from sqlglot.errors import ErrorLevel, ParseError, TokenError, UnsupportedError +from sqlglot.expressions import Expression +from sqlglot.expressions import alias_ as alias +from sqlglot.expressions import ( + and_, + column, + condition, + from_, + maybe_parse, + not_, + or_, + select, + subquery, +) +from sqlglot.expressions import table_ as table +from sqlglot.generator import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + +__version__ = "6.0.4" + +pretty = False + + +def parse(sql, read=None, **opts): + """ + Parses the given SQL string into a collection of syntax trees, one per + parsed SQL statement. + + Args: + sql (str): the SQL code string to parse. + read (str): the SQL dialect to apply during parsing + (e.g. "spark", "hive", "presto", "mysql"). + **opts: other options. + + Returns: + typing.List[Expression]: the list of parsed syntax trees. + """ + dialect = Dialect.get_or_raise(read)() + return dialect.parse(sql, **opts) + + +def parse_one(sql, read=None, into=None, **opts): + """ + Parses the given SQL string and returns a syntax tree for the first + parsed SQL statement. + + Args: + sql (str): the SQL code string to parse. + read (str): the SQL dialect to apply during parsing + (e.g. "spark", "hive", "presto", "mysql"). + into (Expression): the SQLGlot Expression to parse into. + **opts: other options. + + Returns: + Expression: the syntax tree for the first parsed statement. + """ + + dialect = Dialect.get_or_raise(read)() + + if into: + result = dialect.parse_into(into, sql, **opts) + else: + result = dialect.parse(sql, **opts) + + return result[0] if result else None + + +def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts): + """ + Parses the given SQL string using the source dialect and returns a list of SQL strings + transformed to conform to the target dialect. Each string in the returned list represents + a single transformed SQL statement. + + Args: + sql (str): the SQL code string to transpile. + read (str): the source dialect used to parse the input string + (e.g. "spark", "hive", "presto", "mysql"). + write (str): the target dialect into which the input should be transformed + (e.g. "spark", "hive", "presto", "mysql").
+ identity (bool): if set to True and the target dialect is not specified, + the source dialect will be used as both the source and the target dialect. + error_level (ErrorLevel): the desired error level of the parser. + **opts: other options. + + Returns: + typing.List[str]: the list of transpiled SQL statements / expressions. + """ + write = write or read if identity else write + return [ + Dialect.get_or_raise(write)().generate(expression, **opts) + for expression in parse(sql, read, error_level=error_level) + ] diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py new file mode 100644 index 0000000..25200c4 --- /dev/null +++ b/sqlglot/__main__.py @@ -0,0 +1,69 @@ +import argparse + +import sqlglot + +parser = argparse.ArgumentParser(description="Transpile SQL") +parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile") +parser.add_argument( + "--read", + dest="read", + type=str, + default=None, + help="Dialect to read; default is generic", +) +parser.add_argument( + "--write", + dest="write", + type=str, + default=None, + help="Dialect to write; default is generic", +) +parser.add_argument( + "--no-identify", + dest="identify", + action="store_false", + help="Don't auto-identify fields", +) +parser.add_argument( + "--no-pretty", + dest="pretty", + action="store_false", + help="Compress SQL", +) +parser.add_argument( + "--parse", + dest="parse", + action="store_true", + help="Parse and return the expression tree", +) +parser.add_argument( + "--error-level", + dest="error_level", + type=str, + default="RAISE", + help="IGNORE, WARN, RAISE (default)", +) + + +args = parser.parse_args() +error_level = sqlglot.ErrorLevel[args.error_level.upper()] + +if args.parse: + sqls = [ + repr(expression) + for expression in sqlglot.parse( + args.sql, read=args.read, error_level=error_level + ) + ] +else: + sqls = sqlglot.transpile( + args.sql, + read=args.read, + write=args.write, + identify=args.identify, + pretty=args.pretty, + error_level=error_level, + ) + +for sql in sqls: + print(sql) diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py new file mode 100644 index 0000000..5aa7d77 --- /dev/null +++ b/sqlglot/dialects/__init__.py @@ -0,0 +1,15 @@ +from sqlglot.dialects.bigquery import BigQuery +from sqlglot.dialects.clickhouse import ClickHouse +from sqlglot.dialects.dialect import Dialect, Dialects +from sqlglot.dialects.duckdb import DuckDB +from sqlglot.dialects.hive import Hive +from sqlglot.dialects.mysql import MySQL +from sqlglot.dialects.oracle import Oracle +from sqlglot.dialects.postgres import Postgres +from sqlglot.dialects.presto import Presto +from sqlglot.dialects.snowflake import Snowflake +from sqlglot.dialects.spark import Spark +from sqlglot.dialects.sqlite import SQLite +from sqlglot.dialects.starrocks import StarRocks +from sqlglot.dialects.tableau import Tableau +from sqlglot.dialects.trino import Trino diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py new file mode 100644 index 0000000..f4e87c3 --- /dev/null +++ b/sqlglot/dialects/bigquery.py @@ -0,0 +1,128 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + inline_array_sql, + no_ilike_sql, + rename_func, +) +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _date_add(expression_class): + def func(args): + interval = list_get(args, 1) + return expression_class( + this=list_get(args, 0), +
expression=interval.this, + unit=interval.args.get("unit"), + ) + + return func + + +def _date_add_sql(data_type, kind): + def func(self, expression): + this = self.sql(expression, "this") + unit = self.sql(expression, "unit") or "'day'" + expression = self.sql(expression, "expression") + return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})" + + return func + + +class BigQuery(Dialect): + unnest_column_only = True + + class Tokenizer(Tokenizer): + QUOTES = [ + (prefix + quote, quote) if prefix else quote + for quote in ["'", '"', '"""', "'''"] + for prefix in ["", "r", "R"] + ] + IDENTIFIERS = ["`"] + ESCAPE = "\\" + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, + "CURRENT_TIME": TokenType.CURRENT_TIME, + "GEOGRAPHY": TokenType.GEOGRAPHY, + "INT64": TokenType.BIGINT, + "FLOAT64": TokenType.DOUBLE, + "QUALIFY": TokenType.QUALIFY, + "UNKNOWN": TokenType.NULL, + "WINDOW": TokenType.WINDOW, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "DATE_ADD": _date_add(exp.DateAdd), + "DATETIME_ADD": _date_add(exp.DatetimeAdd), + "TIME_ADD": _date_add(exp.TimeAdd), + "TIMESTAMP_ADD": _date_add(exp.TimestampAdd), + "DATE_SUB": _date_add(exp.DateSub), + "DATETIME_SUB": _date_add(exp.DatetimeSub), + "TIME_SUB": _date_add(exp.TimeSub), + "TIMESTAMP_SUB": _date_add(exp.TimestampSub), + } + + NO_PAREN_FUNCTIONS = { + **Parser.NO_PAREN_FUNCTIONS, + TokenType.CURRENT_DATETIME: exp.CurrentDatetime, + TokenType.CURRENT_TIME: exp.CurrentTime, + } + + class Generator(Generator): + TRANSFORMS = { + exp.Array: inline_array_sql, + exp.ArraySize: rename_func("ARRAY_LENGTH"), + exp.DateAdd: _date_add_sql("DATE", "ADD"), + exp.DateSub: _date_add_sql("DATE", "SUB"), + exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), + exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), + exp.ILike: no_ilike_sql, + exp.TimeAdd: _date_add_sql("TIME", "ADD"), + exp.TimeSub: _date_add_sql("TIME", "SUB"), + exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), + exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), + exp.VariancePop: rename_func("VAR_POP"), + } + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "INT64", + exp.DataType.Type.SMALLINT: "INT64", + exp.DataType.Type.INT: "INT64", + exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.DECIMAL: "NUMERIC", + exp.DataType.Type.FLOAT: "FLOAT64", + exp.DataType.Type.DOUBLE: "FLOAT64", + exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.VARCHAR: "STRING", + exp.DataType.Type.NVARCHAR: "STRING", + } + + def in_unnest_op(self, unnest): + return self.sql(unnest) + + def union_op(self, expression): + return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" + + def except_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery") + return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" + + def intersect_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported( + "INTERSECT without DISTINCT is not supported in BigQuery" + ) + return ( + f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" + ) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py new file mode 100644 index 0000000..55dad7a --- /dev/null +++ b/sqlglot/dialects/clickhouse.py @@ -0,0 +1,48 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect, inline_array_sql +from 
sqlglot.generator import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +class ClickHouse(Dialect): + normalize_functions = None + null_ordering = "nulls_are_last" + + class Tokenizer(Tokenizer): + IDENTIFIERS = ['"', "`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "NULLABLE": TokenType.NULLABLE, + "FINAL": TokenType.FINAL, + "INT8": TokenType.TINYINT, + "INT16": TokenType.SMALLINT, + "INT32": TokenType.INT, + "INT64": TokenType.BIGINT, + "FLOAT32": TokenType.FLOAT, + "FLOAT64": TokenType.DOUBLE, + } + + class Parser(Parser): + def _parse_table(self, schema=False): + this = super()._parse_table(schema) + + if self._match(TokenType.FINAL): + this = self.expression(exp.Final, this=this) + + return this + + class Generator(Generator): + STRUCT_DELIMITER = ("(", ")") + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.NULLABLE: "Nullable", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.Array: inline_array_sql, + exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", + } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py new file mode 100644 index 0000000..8045f7a --- /dev/null +++ b/sqlglot/dialects/dialect.py @@ -0,0 +1,268 @@ +from enum import Enum + +from sqlglot import exp +from sqlglot.generator import Generator +from sqlglot.helper import csv, list_get +from sqlglot.parser import Parser +from sqlglot.time import format_time +from sqlglot.tokens import Tokenizer +from sqlglot.trie import new_trie + + +class Dialects(str, Enum): + DIALECT = "" + + BIGQUERY = "bigquery" + CLICKHOUSE = "clickhouse" + DUCKDB = "duckdb" + HIVE = "hive" + MYSQL = "mysql" + ORACLE = "oracle" + POSTGRES = "postgres" + PRESTO = "presto" + SNOWFLAKE = "snowflake" + SPARK = "spark" + SQLITE = "sqlite" + STARROCKS = "starrocks" + TABLEAU = "tableau" + TRINO = "trino" + + +class _Dialect(type): + classes = {} + + @classmethod + def __getitem__(cls, key): + return cls.classes[key] + + @classmethod + def get(cls, key, default=None): + return cls.classes.get(key, default) + + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + enum = Dialects.__members__.get(clsname.upper()) + cls.classes[enum.value if enum is not None else clsname.lower()] = klass + + klass.time_trie = new_trie(klass.time_mapping) + klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()} + klass.inverse_time_trie = new_trie(klass.inverse_time_mapping) + + klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) + klass.parser_class = getattr(klass, "Parser", Parser) + klass.generator_class = getattr(klass, "Generator", Generator) + + klass.tokenizer = klass.tokenizer_class() + klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[ + 0 + ] + klass.identifier_start, klass.identifier_end = list( + klass.tokenizer_class.IDENTIFIERS.items() + )[0] + + return klass + + +class Dialect(metaclass=_Dialect): + index_offset = 0 + unnest_column_only = False + alias_post_tablesample = False + normalize_functions = "upper" + null_ordering = "nulls_are_small" + + date_format = "'%Y-%m-%d'" + dateint_format = "'%Y%m%d'" + time_format = "'%Y-%m-%d %H:%M:%S'" + time_mapping = {} + + # autofilled + quote_start = None + quote_end = None + identifier_start = None + identifier_end = None + + time_trie = None + inverse_time_mapping = None + inverse_time_trie = None + tokenizer_class = None + parser_class = None + generator_class = None + tokenizer = None + + @classmethod + def 
get_or_raise(cls, dialect): + if not dialect: + return cls + result = cls.get(dialect) + if not result: + raise ValueError(f"Unknown dialect '{dialect}'") + return result + + @classmethod + def format_time(cls, expression): + if isinstance(expression, str): + return exp.Literal.string( + format_time( + expression[1:-1], # the time formats are quoted + cls.time_mapping, + cls.time_trie, + ) + ) + if expression and expression.is_string: + return exp.Literal.string( + format_time( + expression.this, + cls.time_mapping, + cls.time_trie, + ) + ) + return expression + + def parse(self, sql, **opts): + return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql) + + def parse_into(self, expression_type, sql, **opts): + return self.parser(**opts).parse_into( + expression_type, self.tokenizer.tokenize(sql), sql + ) + + def generate(self, expression, **opts): + return self.generator(**opts).generate(expression) + + def transpile(self, code, **opts): + return self.generate(self.parse(code), **opts) + + def parser(self, **opts): + return self.parser_class( + **{ + "index_offset": self.index_offset, + "unnest_column_only": self.unnest_column_only, + "alias_post_tablesample": self.alias_post_tablesample, + "null_ordering": self.null_ordering, + **opts, + }, + ) + + def generator(self, **opts): + return self.generator_class( + **{ + "quote_start": self.quote_start, + "quote_end": self.quote_end, + "identifier_start": self.identifier_start, + "identifier_end": self.identifier_end, + "escape": self.tokenizer_class.ESCAPE, + "index_offset": self.index_offset, + "time_mapping": self.inverse_time_mapping, + "time_trie": self.inverse_time_trie, + "unnest_column_only": self.unnest_column_only, + "alias_post_tablesample": self.alias_post_tablesample, + "normalize_functions": self.normalize_functions, + "null_ordering": self.null_ordering, + **opts, + } + ) + + +def rename_func(name): + return ( + lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})" + ) + + +def approx_count_distinct_sql(self, expression): + if expression.args.get("accuracy"): + self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") + return f"APPROX_COUNT_DISTINCT({self.sql(expression, 'this')})" + + +def if_sql(self, expression): + expressions = csv( + self.sql(expression, "this"), + self.sql(expression, "true"), + self.sql(expression, "false"), + ) + return f"IF({expressions})" + + +def arrow_json_extract_sql(self, expression): + return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}" + + +def arrow_json_extract_scalar_sql(self, expression): + return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}" + + +def inline_array_sql(self, expression): + return f"[{self.expressions(expression)}]" + + +def no_ilike_sql(self, expression): + return self.like_sql( + exp.Like( + this=exp.Lower(this=expression.this), + expression=expression.args["expression"], + ) + ) + + +def no_paren_current_date_sql(self, expression): + zone = self.sql(expression, "this") + return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" + + +def no_recursive_cte_sql(self, expression): + if expression.args.get("recursive"): + self.unsupported("Recursive CTEs are unsupported") + expression.args["recursive"] = False + return self.with_sql(expression) + + +def no_safe_divide_sql(self, expression): + n = self.sql(expression, "this") + d = self.sql(expression, "expression") + return f"IF({d} <> 0, {n} / {d}, NULL)" + + +def no_tablesample_sql(self, expression): + 
self.unsupported("TABLESAMPLE unsupported") + return self.sql(expression.this) + + +def no_trycast_sql(self, expression): + return self.cast_sql(expression) + + +def str_position_sql(self, expression): + this = self.sql(expression, "this") + substr = self.sql(expression, "substr") + position = self.sql(expression, "position") + if position: + return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" + return f"STRPOS({this}, {substr})" + + +def struct_extract_sql(self, expression): + this = self.sql(expression, "this") + struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True)) + return f"{this}.{struct_key}" + + +def format_time_lambda(exp_class, dialect, default=None): + """Helper used for time expressions. + + Args + exp_class (Class): the expression class to instantiate + dialect (string): sql dialect + default (Option[bool | str]): the default format, True being time + """ + + def _format_time(args): + return exp_class( + this=list_get(args, 0), + format=Dialect[dialect].format_time( + list_get(args, 1) + or (Dialect[dialect].time_format if default is True else default) + ), + ) + + return _format_time diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py new file mode 100644 index 0000000..d83a620 --- /dev/null +++ b/sqlglot/dialects/duckdb.py @@ -0,0 +1,156 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + approx_count_distinct_sql, + arrow_json_extract_scalar_sql, + arrow_json_extract_sql, + format_time_lambda, + no_safe_divide_sql, + no_tablesample_sql, + rename_func, + str_position_sql, +) +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _unix_to_time(self, expression): + return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))" + + +def _str_to_time_sql(self, expression): + return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" + + +def _ts_or_ds_add(self, expression): + this = self.sql(expression, "this") + e = self.sql(expression, "expression") + unit = self.sql(expression, "unit").strip("'") or "DAY" + return f"CAST({this} AS DATE) + INTERVAL {e} {unit}" + + +def _ts_or_ds_to_date_sql(self, expression): + time_format = self.format_time(expression) + if time_format and time_format not in (DuckDB.time_format, DuckDB.date_format): + return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" + return f"CAST({self.sql(expression, 'this')} AS DATE)" + + +def _date_add(self, expression): + this = self.sql(expression, "this") + e = self.sql(expression, "expression") + unit = self.sql(expression, "unit").strip("'") or "DAY" + return f"{this} + INTERVAL {e} {unit}" + + +def _array_sort_sql(self, expression): + if expression.expression: + self.unsupported("DUCKDB ARRAY_SORT does not support a comparator") + return f"ARRAY_SORT({self.sql(expression, 'this')})" + + +def _sort_array_sql(self, expression): + this = self.sql(expression, "this") + if expression.args.get("asc") == exp.FALSE: + return f"ARRAY_REVERSE_SORT({this})" + return f"ARRAY_SORT({this})" + + +def _sort_array_reverse(args): + return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE) + + +def _struct_pack_sql(self, expression): + args = [ + self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) + for e in expression.expressions + ] + return f"STRUCT_PACK({', '.join(args)})" + + +class DuckDB(Dialect): + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + 
":=": TokenType.EQ, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + "ARRAY_LENGTH": exp.ArraySize.from_arg_list, + "ARRAY_SORT": exp.SortArray.from_arg_list, + "ARRAY_REVERSE_SORT": _sort_array_reverse, + "EPOCH": exp.TimeToUnix.from_arg_list, + "EPOCH_MS": lambda args: exp.UnixToTime( + this=exp.Div( + this=list_get(args, 0), + expression=exp.Literal.number(1000), + ) + ), + "LIST_SORT": exp.SortArray.from_arg_list, + "LIST_REVERSE_SORT": _sort_array_reverse, + "LIST_VALUE": exp.Array.from_arg_list, + "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, + "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"), + "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"), + "STR_SPLIT": exp.Split.from_arg_list, + "STRING_SPLIT": exp.Split.from_arg_list, + "STRING_TO_ARRAY": exp.Split.from_arg_list, + "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, + "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, + "STRUCT_PACK": exp.Struct.from_arg_list, + "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, + "UNNEST": exp.Explode.from_arg_list, + } + + class Generator(Generator): + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.ApproxDistinct: approx_count_distinct_sql, + exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})", + exp.ArraySize: rename_func("ARRAY_LENGTH"), + exp.ArraySort: _array_sort_sql, + exp.ArraySum: rename_func("LIST_SUM"), + exp.DateAdd: _date_add, + exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", + exp.Explode: rename_func("UNNEST"), + exp.JSONExtract: arrow_json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONBExtract: arrow_json_extract_sql, + exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.RegexpLike: rename_func("REGEXP_MATCHES"), + exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), + exp.SafeDivide: no_safe_divide_sql, + exp.Split: rename_func("STR_SPLIT"), + exp.SortArray: _sort_array_sql, + exp.StrPosition: str_position_sql, + exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", + exp.StrToTime: _str_to_time_sql, + exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.Struct: _struct_pack_sql, + exp.TableSample: no_tablesample_sql, + exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", + exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToUnix: rename_func("EPOCH"), + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: _ts_or_ds_add, + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})", + exp.UnixToTime: _unix_to_time, + exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)", + } + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.VARCHAR: 
"TEXT", + exp.DataType.Type.NVARCHAR: "TEXT", + } diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py new file mode 100644 index 0000000..e3f3f39 --- /dev/null +++ b/sqlglot/dialects/hive.py @@ -0,0 +1,304 @@ +from sqlglot import exp, transforms +from sqlglot.dialects.dialect import ( + Dialect, + approx_count_distinct_sql, + format_time_lambda, + if_sql, + no_ilike_sql, + no_recursive_cte_sql, + no_safe_divide_sql, + no_trycast_sql, + rename_func, + struct_extract_sql, +) +from sqlglot.generator import Generator +from sqlglot.helper import csv, list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer + + +def _parse_map(args): + keys = [] + values = [] + for i in range(0, len(args), 2): + keys.append(args[i]) + values.append(args[i + 1]) + return HiveMap( + keys=exp.Array(expressions=keys), + values=exp.Array(expressions=values), + ) + + +def _map_sql(self, expression): + keys = expression.args["keys"] + values = expression.args["values"] + + if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): + self.unsupported("Cannot convert array columns into map use SparkSQL instead.") + return f"MAP({self.sql(keys)}, {self.sql(values)})" + + args = [] + for key, value in zip(keys.expressions, values.expressions): + args.append(self.sql(key)) + args.append(self.sql(value)) + return f"MAP({csv(*args)})" + + +def _array_sort(self, expression): + if expression.expression: + self.unsupported("Hive SORT_ARRAY does not support a comparator") + return f"SORT_ARRAY({self.sql(expression, 'this')})" + + +def _property_sql(self, expression): + key = expression.name + value = self.sql(expression, "value") + return f"'{key}' = {value}" + + +def _str_to_unix(self, expression): + return f"UNIX_TIMESTAMP({csv(self.sql(expression, 'this'), _time_format(self, expression))})" + + +def _str_to_date(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format not in (Hive.time_format, Hive.date_format): + this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" + return f"CAST({this} AS DATE)" + + +def _str_to_time(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format not in (Hive.time_format, Hive.date_format): + this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" + return f"CAST({this} AS TIMESTAMP)" + + +def _time_format(self, expression): + time_format = self.format_time(expression) + if time_format == Hive.time_format: + return None + return time_format + + +def _time_to_str(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + return f"DATE_FORMAT({this}, {time_format})" + + +def _to_date_sql(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format and time_format not in (Hive.time_format, Hive.date_format): + return f"TO_DATE({this}, {time_format})" + return f"TO_DATE({this})" + + +def _unnest_to_explode_sql(self, expression): + unnest = expression.this + if isinstance(unnest, exp.Unnest): + alias = unnest.args.get("alias") + udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode + return "".join( + self.sql( + exp.Lateral( + this=udtf(this=expression), + alias=exp.TableAlias(this=alias.this, columns=[column]), + ) + ) + for expression, column in zip( + unnest.expressions, alias.columns if alias else [] + ) + ) + return self.join_sql(expression) + + +def _index_sql(self, expression): + this = 
self.sql(expression, "this") + table = self.sql(expression, "table") + columns = self.sql(expression, "columns") + return f"{this} ON TABLE {table} {columns}" + + +class HiveMap(exp.Map): + is_var_len_args = True + + +class Hive(Dialect): + alias_post_tablesample = True + + time_mapping = { + "y": "%Y", + "Y": "%Y", + "YYYY": "%Y", + "yyyy": "%Y", + "YY": "%y", + "yy": "%y", + "MMMM": "%B", + "MMM": "%b", + "MM": "%m", + "M": "%-m", + "dd": "%d", + "d": "%-d", + "HH": "%H", + "H": "%-H", + "hh": "%I", + "h": "%-I", + "mm": "%M", + "m": "%-M", + "ss": "%S", + "s": "%-S", + "S": "%f", + } + + date_format = "'yyyy-MM-dd'" + dateint_format = "'yyyyMMdd'" + time_format = "'yyyy-MM-dd HH:mm:ss'" + + class Tokenizer(Tokenizer): + QUOTES = ["'", '"'] + IDENTIFIERS = ["`"] + ESCAPE = "\\" + ENCODE = "utf-8" + + NUMERIC_LITERALS = { + "L": "BIGINT", + "S": "SMALLINT", + "Y": "TINYINT", + "D": "DOUBLE", + "F": "FLOAT", + "BD": "DECIMAL", + } + + class Parser(Parser): + STRICT_CAST = False + + FUNCTIONS = { + **Parser.FUNCTIONS, + "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + "COLLECT_LIST": exp.ArrayAgg.from_arg_list, + "DATE_ADD": lambda args: exp.TsOrDsAdd( + this=list_get(args, 0), + expression=list_get(args, 1), + unit=exp.Literal.string("DAY"), + ), + "DATEDIFF": lambda args: exp.DateDiff( + this=exp.TsOrDsToDate(this=list_get(args, 0)), + expression=exp.TsOrDsToDate(this=list_get(args, 1)), + ), + "DATE_SUB": lambda args: exp.TsOrDsAdd( + this=list_get(args, 0), + expression=exp.Mul( + this=list_get(args, 1), + expression=exp.Literal.number(-1), + ), + unit=exp.Literal.string("DAY"), + ), + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))), + "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), + "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, + "LOCATE": lambda args: exp.StrPosition( + this=list_get(args, 1), + substr=list_get(args, 0), + position=list_get(args, 2), + ), + "LOG": ( + lambda args: exp.Log.from_arg_list(args) + if len(args) > 1 + else exp.Ln.from_arg_list(args) + ), + "MAP": _parse_map, + "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), + "PERCENTILE": exp.Quantile.from_arg_list, + "COLLECT_SET": exp.SetAgg.from_arg_list, + "SIZE": exp.ArraySize.from_arg_list, + "SPLIT": exp.RegexpSplit.from_arg_list, + "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), + "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True), + "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), + } + + class Generator(Generator): + ROOT_PROPERTIES = [ + exp.PartitionedByProperty, + exp.FileFormatProperty, + exp.SchemaCommentProperty, + exp.LocationProperty, + exp.TableFormatProperty, + ] + WITH_PROPERTIES = [exp.AnonymousProperty] + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TEXT: "STRING", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, + exp.AnonymousProperty: _property_sql, + exp.ApproxDistinct: approx_count_distinct_sql, + exp.ArrayAgg: rename_func("COLLECT_LIST"), + exp.ArraySize: rename_func("SIZE"), + exp.ArraySort: _array_sort, + exp.With: no_recursive_cte_sql, + exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateStrToDate: rename_func("TO_DATE"), + exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 
'this')}, {Hive.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", + exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}", + exp.If: if_sql, + exp.Index: _index_sql, + exp.ILike: no_ilike_sql, + exp.Join: _unnest_to_explode_sql, + exp.JSONExtract: rename_func("GET_JSON_OBJECT"), + exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), + exp.Map: _map_sql, + HiveMap: _map_sql, + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}", + exp.Quantile: rename_func("PERCENTILE"), + exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), + exp.RegexpSplit: rename_func("SPLIT"), + exp.SafeDivide: no_safe_divide_sql, + exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}", + exp.SetAgg: rename_func("COLLECT_SET"), + exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", + exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})", + exp.StrToDate: _str_to_date, + exp.StrToTime: _str_to_time, + exp.StrToUnix: _str_to_unix, + exp.StructExtract: struct_extract_sql, + exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}", + exp.TimeStrToDate: rename_func("TO_DATE"), + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeToStr: _time_to_str, + exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.TsOrDsToDate: _to_date_sql, + exp.TryCast: no_trycast_sql, + exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})", + exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), + } + + def with_properties(self, properties): + return self.properties( + properties, + prefix="TBLPROPERTIES", + ) + + def datatype_sql(self, expression): + if ( + expression.this + in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) + and not expression.expressions + ): + expression = exp.DataType.build("text") + return super().datatype_sql(expression) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py new file mode 100644 index 0000000..93800a6 --- /dev/null +++ b/sqlglot/dialects/mysql.py @@ -0,0 +1,163 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + no_ilike_sql, + no_paren_current_date_sql, + no_tablesample_sql, + no_trycast_sql, +) +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _date_trunc_sql(self, expression): + unit = expression.text("unit").lower() + + this = self.sql(expression.this) + + if unit == "day": + return f"DATE({this})" + + if unit == "week": + concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')" + date_format = "%Y %u %w" + elif unit == "month": + concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')" + date_format = "%Y %c %e" + elif unit == "quarter": + concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')" + date_format = "%Y %c %e" + elif unit == "year": + concat = f"CONCAT(YEAR({this}), ' 1 1')" + date_format = "%Y %c %e" + 
else: + self.unsupported(f"Unexpected interval unit: {unit}") + return f"DATE({this})" + + return f"STR_TO_DATE({concat}, '{date_format}')" + + +def _str_to_date(args): + date_format = MySQL.format_time(list_get(args, 1)) + return exp.StrToDate(this=list_get(args, 0), format=date_format) + + +def _str_to_date_sql(self, expression): + date_format = self.format_time(expression) + return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" + + +def _date_add(expression_class): + def func(args): + interval = list_get(args, 1) + return expression_class( + this=list_get(args, 0), + expression=interval.this, + unit=exp.Literal.string(interval.text("unit").lower()), + ) + + return func + + +def _date_add_sql(kind): + def func(self, expression): + this = self.sql(expression, "this") + unit = expression.text("unit").upper() or "DAY" + expression = self.sql(expression, "expression") + return f"DATE_{kind}({this}, INTERVAL {expression} {unit})" + + return func + + +class MySQL(Dialect): + # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions + time_mapping = { + "%M": "%B", + "%c": "%-m", + "%e": "%-d", + "%h": "%I", + "%i": "%M", + "%s": "%S", + "%S": "%S", + "%u": "%W", + } + + class Tokenizer(Tokenizer): + QUOTES = ["'", '"'] + COMMENTS = ["--", "#", ("/*", "*/")] + IDENTIFIERS = ["`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "_ARMSCII8": TokenType.INTRODUCER, + "_ASCII": TokenType.INTRODUCER, + "_BIG5": TokenType.INTRODUCER, + "_BINARY": TokenType.INTRODUCER, + "_CP1250": TokenType.INTRODUCER, + "_CP1251": TokenType.INTRODUCER, + "_CP1256": TokenType.INTRODUCER, + "_CP1257": TokenType.INTRODUCER, + "_CP850": TokenType.INTRODUCER, + "_CP852": TokenType.INTRODUCER, + "_CP866": TokenType.INTRODUCER, + "_CP932": TokenType.INTRODUCER, + "_DEC8": TokenType.INTRODUCER, + "_EUCJPMS": TokenType.INTRODUCER, + "_EUCKR": TokenType.INTRODUCER, + "_GB18030": TokenType.INTRODUCER, + "_GB2312": TokenType.INTRODUCER, + "_GBK": TokenType.INTRODUCER, + "_GEOSTD8": TokenType.INTRODUCER, + "_GREEK": TokenType.INTRODUCER, + "_HEBREW": TokenType.INTRODUCER, + "_HP8": TokenType.INTRODUCER, + "_KEYBCS2": TokenType.INTRODUCER, + "_KOI8R": TokenType.INTRODUCER, + "_KOI8U": TokenType.INTRODUCER, + "_LATIN1": TokenType.INTRODUCER, + "_LATIN2": TokenType.INTRODUCER, + "_LATIN5": TokenType.INTRODUCER, + "_LATIN7": TokenType.INTRODUCER, + "_MACCE": TokenType.INTRODUCER, + "_MACROMAN": TokenType.INTRODUCER, + "_SJIS": TokenType.INTRODUCER, + "_SWE7": TokenType.INTRODUCER, + "_TIS620": TokenType.INTRODUCER, + "_UCS2": TokenType.INTRODUCER, + "_UJIS": TokenType.INTRODUCER, + "_UTF8": TokenType.INTRODUCER, + "_UTF16": TokenType.INTRODUCER, + "_UTF16LE": TokenType.INTRODUCER, + "_UTF32": TokenType.INTRODUCER, + "_UTF8MB3": TokenType.INTRODUCER, + "_UTF8MB4": TokenType.INTRODUCER, + } + + class Parser(Parser): + STRICT_CAST = False + + FUNCTIONS = { + **Parser.FUNCTIONS, + "DATE_ADD": _date_add(exp.DateAdd), + "DATE_SUB": _date_add(exp.DateSub), + "STR_TO_DATE": _str_to_date, + } + + class Generator(Generator): + NULL_ORDERING_SUPPORTED = False + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.CurrentDate: no_paren_current_date_sql, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.ILike: no_ilike_sql, + exp.TableSample: no_tablesample_sql, + exp.TryCast: no_trycast_sql, + exp.DateAdd: _date_add_sql("ADD"), + exp.DateSub: _date_add_sql("SUB"), + exp.DateTrunc: _date_trunc_sql, + exp.StrToDate: _str_to_date_sql, + exp.StrToTime: _str_to_date_sql, + } diff --git a/sqlglot/dialects/oracle.py
b/sqlglot/dialects/oracle.py new file mode 100644 index 0000000..9c8b6f2 --- /dev/null +++ b/sqlglot/dialects/oracle.py @@ -0,0 +1,63 @@ +from sqlglot import exp, transforms +from sqlglot.dialects.dialect import Dialect, no_ilike_sql +from sqlglot.generator import Generator +from sqlglot.helper import csv +from sqlglot.tokens import Tokenizer, TokenType + + +def _limit_sql(self, expression): + return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression)) + + +class Oracle(Dialect): + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "NUMBER", + exp.DataType.Type.SMALLINT: "NUMBER", + exp.DataType.Type.INT: "NUMBER", + exp.DataType.Type.BIGINT: "NUMBER", + exp.DataType.Type.DECIMAL: "NUMBER", + exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", + exp.DataType.Type.VARCHAR: "VARCHAR2", + exp.DataType.Type.NVARCHAR: "NVARCHAR2", + exp.DataType.Type.TEXT: "CLOB", + exp.DataType.Type.BINARY: "BLOB", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, + exp.ILike: no_ilike_sql, + exp.Limit: _limit_sql, + } + + def query_modifiers(self, expression, *sqls): + return csv( + *sqls, + *[self.sql(sql) for sql in expression.args.get("laterals", [])], + *[self.sql(sql) for sql in expression.args.get("joins", [])], + self.sql(expression, "where"), + self.sql(expression, "group"), + self.sql(expression, "having"), + self.sql(expression, "qualify"), + self.sql(expression, "window"), + self.sql(expression, "distribute"), + self.sql(expression, "sort"), + self.sql(expression, "cluster"), + self.sql(expression, "order"), + self.sql(expression, "offset"), # offset before limit in oracle + self.sql(expression, "limit"), + sep="", + ) + + def offset_sql(self, expression): + return f"{super().offset_sql(expression)} ROWS" + + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + "TOP": TokenType.TOP, + "VARCHAR2": TokenType.VARCHAR, + "NVARCHAR2": TokenType.NVARCHAR, + } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py new file mode 100644 index 0000000..61dff86 --- /dev/null +++ b/sqlglot/dialects/postgres.py @@ -0,0 +1,109 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + arrow_json_extract_scalar_sql, + arrow_json_extract_sql, + format_time_lambda, + no_paren_current_date_sql, + no_tablesample_sql, + no_trycast_sql, +) +from sqlglot.generator import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _date_add_sql(kind): + def func(self, expression): + from sqlglot.optimizer.simplify import simplify + + this = self.sql(expression, "this") + unit = self.sql(expression, "unit") + expression = simplify(expression.args["expression"]) + + if not isinstance(expression, exp.Literal): + self.unsupported("Cannot add non literal") + + expression = expression.copy() + expression.args["is_string"] = True + expression = self.sql(expression) + return f"{this} {kind} INTERVAL {expression} {unit}" + + return func + + +class Postgres(Dialect): + null_ordering = "nulls_are_large" + time_format = "'YYYY-MM-DD HH24:MI:SS'" + time_mapping = { + "AM": "%p", # AM or PM + "D": "%w", # 1-based day of week + "DD": "%d", # day of month + "DDD": "%j", # zero padded day of year + "FMDD": "%-d", # - is no leading zero for Python; same for FM in postgres + "FMDDD": "%-j", # day of year + "FMHH12": "%-I", # 9 + "FMHH24": "%-H", # 9 + "FMMI": "%-M", # Minute + "FMMM": "%-m", # 1 + "FMSS": "%-S", # Second + "HH12": "%I", # 09 + 
"HH24": "%H", # 09 + "MI": "%M", # zero padded minute + "MM": "%m", # 01 + "OF": "%z", # utc offset + "SS": "%S", # zero padded second + "TMDay": "%A", # TM is locale dependent + "TMDy": "%a", + "TMMon": "%b", # Sep + "TMMonth": "%B", # September + "TZ": "%Z", # uppercase timezone name + "US": "%f", # zero padded microsecond + "WW": "%U", # 1-based week of year + "YY": "%y", # 15 + "YYYY": "%Y", # 2015 + } + + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + "SERIAL": TokenType.AUTO_INCREMENT, + "UUID": TokenType.UUID, + } + + class Parser(Parser): + STRICT_CAST = False + FUNCTIONS = { + **Parser.FUNCTIONS, + "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"), + "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), + } + + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "SMALLINT", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", + exp.DataType.Type.BINARY: "BYTEA", + } + + TOKEN_MAPPING = { + TokenType.AUTO_INCREMENT: "SERIAL", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.JSONExtract: arrow_json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", + exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}", + exp.CurrentDate: no_paren_current_date_sql, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.DateAdd: _date_add_sql("+"), + exp.DateSub: _date_add_sql("-"), + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TableSample: no_tablesample_sql, + exp.TryCast: no_trycast_sql, + } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py new file mode 100644 index 0000000..ca913e4 --- /dev/null +++ b/sqlglot/dialects/presto.py @@ -0,0 +1,216 @@ +from sqlglot import exp, transforms +from sqlglot.dialects.dialect import ( + Dialect, + format_time_lambda, + if_sql, + no_ilike_sql, + no_safe_divide_sql, + rename_func, + str_position_sql, + struct_extract_sql, +) +from sqlglot.dialects.mysql import MySQL +from sqlglot.generator import Generator +from sqlglot.helper import csv, list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _approx_distinct_sql(self, expression): + accuracy = expression.args.get("accuracy") + accuracy = ", " + self.sql(accuracy) if accuracy else "" + return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" + + +def _concat_ws_sql(self, expression): + sep, *args = expression.expressions + sep = self.sql(sep) + if len(args) > 1: + return f"ARRAY_JOIN(ARRAY[{csv(*(self.sql(e) for e in args))}], {sep})" + return f"ARRAY_JOIN({self.sql(args[0])}, {sep})" + + +def _datatype_sql(self, expression): + sql = self.datatype_sql(expression) + if expression.this == exp.DataType.Type.TIMESTAMPTZ: + sql = f"{sql} WITH TIME ZONE" + return sql + + +def _date_parse_sql(self, expression): + return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')" + + +def _explode_to_unnest_sql(self, expression): + if isinstance(expression.this, (exp.Explode, exp.Posexplode)): + return self.sql( + exp.Join( + this=exp.Unnest( + expressions=[expression.this.this], + alias=expression.args.get("alias"), + ordinality=isinstance(expression.this, exp.Posexplode), + ), + kind="cross", + ) + ) + return 
self.lateral_sql(expression) + + +def _initcap_sql(self, expression): + regex = r"(\w)(\w*)" + return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" + + +def _no_sort_array(self, expression): + if expression.args.get("asc") == exp.FALSE: + comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" + else: + comparator = None + args = csv(self.sql(expression, "this"), comparator) + return f"ARRAY_SORT({args})" + + +def _schema_sql(self, expression): + if isinstance(expression.parent, exp.Property): + columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions) + return f"ARRAY[{columns}]" + + for schema in expression.parent.find_all(exp.Schema): + if isinstance(schema.parent, exp.Property): + expression = expression.copy() + expression.expressions.extend(schema.expressions) + + return self.schema_sql(expression) + + +def _quantile_sql(self, expression): + self.unsupported("Presto does not support exact quantiles") + return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" + + +def _str_to_time_sql(self, expression): + return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" + + +def _ts_or_ds_to_date_sql(self, expression): + time_format = self.format_time(expression) + if time_format and time_format not in (Presto.time_format, Presto.date_format): + return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" + return ( + f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" + ) + + +def _ts_or_ds_add_sql(self, expression): + this = self.sql(expression, "this") + e = self.sql(expression, "expression") + unit = self.sql(expression, "unit") or "'day'" + return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" + + +class Presto(Dialect): + index_offset = 1 + null_ordering = "nulls_are_last" + time_format = "'%Y-%m-%d %H:%i:%S'" + time_mapping = MySQL.time_mapping + + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + "ROW": TokenType.STRUCT, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, + "CARDINALITY": exp.ArraySize.from_arg_list, + "CONTAINS": exp.ArrayContains.from_arg_list, + "DATE_ADD": lambda args: exp.DateAdd( + this=list_get(args, 2), + expression=list_get(args, 1), + unit=list_get(args, 0), + ), + "DATE_DIFF": lambda args: exp.DateDiff( + this=list_get(args, 2), + expression=list_get(args, 1), + unit=list_get(args, 0), + ), + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), + "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), + "FROM_UNIXTIME": exp.UnixToTime.from_arg_list, + "STRPOS": exp.StrPosition.from_arg_list, + "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, + } + + class Generator(Generator): + + STRUCT_DELIMITER = ("(", ")") + + WITH_PROPERTIES = [ + exp.PartitionedByProperty, + exp.FileFormatProperty, + exp.SchemaCommentProperty, + exp.AnonymousProperty, + exp.TableFormatProperty, + ] + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.BINARY: "VARBINARY", + exp.DataType.Type.TEXT: "VARCHAR", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.STRUCT: "ROW", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, + exp.ApproxDistinct: _approx_distinct_sql, + exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", + 
exp.ArrayContains: rename_func("CONTAINS"), + exp.ArraySize: rename_func("CARDINALITY"), + exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})", + exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.ConcatWs: _concat_ws_sql, + exp.DataType: _datatype_sql, + exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", + exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", + exp.FileFormatProperty: lambda self, e: self.property_sql(e), + exp.If: if_sql, + exp.ILike: no_ilike_sql, + exp.Initcap: _initcap_sql, + exp.Lateral: _explode_to_unnest_sql, + exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}", + exp.Quantile: _quantile_sql, + exp.SafeDivide: no_safe_divide_sql, + exp.Schema: _schema_sql, + exp.SortArray: _no_sort_array, + exp.StrPosition: str_position_sql, + exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", + exp.StrToTime: _str_to_time_sql, + exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.StructExtract: struct_extract_sql, + exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'", + exp.TimeStrToDate: _date_parse_sql, + exp.TimeStrToTime: _date_parse_sql, + exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", + exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToUnix: rename_func("TO_UNIXTIME"), + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: _ts_or_ds_add_sql, + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", + exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", + } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py new file mode 100644 index 0000000..148dfb5 --- /dev/null +++ b/sqlglot/dialects/snowflake.py @@ -0,0 +1,145 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func +from sqlglot.expressions import Literal +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _check_int(s): + if 
s[0] in ("-", "+"): + return s[1:].isdigit() + return s.isdigit() + + +# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html +def _snowflake_to_timestamp(args): + if len(args) == 2: + first_arg, second_arg = args + if second_arg.is_string: + # case: [ , ] + return format_time_lambda(exp.StrToTime, "snowflake")(args) + + # case: [ , ] + if second_arg.name not in ["0", "3", "9"]: + raise ValueError( + f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" + ) + + if second_arg.name == "0": + timescale = exp.UnixToTime.SECONDS + elif second_arg.name == "3": + timescale = exp.UnixToTime.MILLIS + elif second_arg.name == "9": + timescale = exp.UnixToTime.MICROS + + return exp.UnixToTime(this=first_arg, scale=timescale) + + first_arg = list_get(args, 0) + if not isinstance(first_arg, Literal): + # case: + return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + + if first_arg.is_string: + if _check_int(first_arg.this): + # case: + return exp.UnixToTime.from_arg_list(args) + + # case: + return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + + # case: + return exp.UnixToTime.from_arg_list(args) + + +def _unix_to_time(self, expression): + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale in [None, exp.UnixToTime.SECONDS]: + return f"TO_TIMESTAMP({timestamp})" + if scale == exp.UnixToTime.MILLIS: + return f"TO_TIMESTAMP({timestamp}, 3)" + if scale == exp.UnixToTime.MICROS: + return f"TO_TIMESTAMP({timestamp}, 9)" + + raise ValueError("Improper scale for timestamp") + + +class Snowflake(Dialect): + null_ordering = "nulls_are_large" + time_format = "'yyyy-mm-dd hh24:mi:ss'" + + time_mapping = { + "YYYY": "%Y", + "yyyy": "%Y", + "YY": "%y", + "yy": "%y", + "MMMM": "%B", + "mmmm": "%B", + "MON": "%b", + "mon": "%b", + "MM": "%m", + "mm": "%m", + "DD": "%d", + "dd": "%d", + "d": "%-d", + "DY": "%w", + "dy": "%w", + "HH24": "%H", + "hh24": "%H", + "HH12": "%I", + "hh12": "%I", + "MI": "%M", + "mi": "%M", + "SS": "%S", + "ss": "%S", + "FF": "%f", + "ff": "%f", + "FF6": "%f", + "ff6": "%f", + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "ARRAYAGG": exp.ArrayAgg.from_arg_list, + "IFF": exp.If.from_arg_list, + "TO_TIMESTAMP": _snowflake_to_timestamp, + } + + COLUMN_OPERATORS = { + **Parser.COLUMN_OPERATORS, + TokenType.COLON: lambda self, this, path: self.expression( + exp.Bracket, + this=this, + expressions=[path], + ), + } + + class Tokenizer(Tokenizer): + QUOTES = ["'", "$$"] + ESCAPE = "\\" + KEYWORDS = { + **Tokenizer.KEYWORDS, + "QUALIFY": TokenType.QUALIFY, + "DOUBLE PRECISION": TokenType.DOUBLE, + } + + class Generator(Generator): + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.If: rename_func("IFF"), + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: _unix_to_time, + } + + def except_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported("EXCEPT with All is not supported in Snowflake") + return super().except_op(expression) + + def intersect_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported("INTERSECT with All is not supported in Snowflake") + return super().intersect_op(expression) diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py new file mode 100644 index 0000000..89c7ed5 --- /dev/null +++ b/sqlglot/dialects/spark.py @@ -0,0 +1,106 @@ +from sqlglot import exp +from sqlglot.dialects.dialect 
import no_ilike_sql, rename_func +from sqlglot.dialects.hive import Hive, HiveMap +from sqlglot.helper import list_get + + +def _create_sql(self, e): + kind = e.args.get("kind") + temporary = e.args.get("temporary") + + if kind.upper() == "TABLE" and temporary is True: + return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" + return self.create_sql(e) + + +def _map_sql(self, expression): + keys = self.sql(expression.args["keys"]) + values = self.sql(expression.args["values"]) + return f"MAP_FROM_ARRAYS({keys}, {values})" + + +def _str_to_date(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format == Hive.date_format: + return f"TO_DATE({this})" + return f"TO_DATE({this}, {time_format})" + + +def _unix_to_time(self, expression): + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale is None: + return f"FROM_UNIXTIME({timestamp})" + if scale == exp.UnixToTime.SECONDS: + return f"TIMESTAMP_SECONDS({timestamp})" + if scale == exp.UnixToTime.MILLIS: + return f"TIMESTAMP_MILLIS({timestamp})" + if scale == exp.UnixToTime.MICROS: + return f"TIMESTAMP_MICROS({timestamp})" + + raise ValueError("Improper scale for timestamp") + + +class Spark(Hive): + class Parser(Hive.Parser): + FUNCTIONS = { + **Hive.Parser.FUNCTIONS, + "MAP_FROM_ARRAYS": exp.Map.from_arg_list, + "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, + "LEFT": lambda args: exp.Substring( + this=list_get(args, 0), + start=exp.Literal.number(1), + length=list_get(args, 1), + ), + "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( + this=list_get(args, 0), + expression=list_get(args, 1), + ), + "SHIFTRIGHT": lambda args: exp.BitwiseRightShift( + this=list_get(args, 0), + expression=list_get(args, 1), + ), + "RIGHT": lambda args: exp.Substring( + this=list_get(args, 0), + start=exp.Sub( + this=exp.Length(this=list_get(args, 0)), + expression=exp.Add( + this=list_get(args, 1), expression=exp.Literal.number(1) + ), + ), + length=list_get(args, 1), + ), + } + + class Generator(Hive.Generator): + TYPE_MAPPING = { + **Hive.Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "BYTE", + exp.DataType.Type.SMALLINT: "SHORT", + exp.DataType.Type.BIGINT: "LONG", + } + + TRANSFORMS = { + **{ + k: v + for k, v in Hive.Generator.TRANSFORMS.items() + if k not in {exp.ArraySort} + }, + exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), + exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), + exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", + exp.ILike: no_ilike_sql, + exp.StrToDate: _str_to_date, + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: _unix_to_time, + exp.Create: _create_sql, + exp.Map: _map_sql, + exp.Reduce: rename_func("AGGREGATE"), + exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", + HiveMap: _map_sql, + } + + def bitstring_sql(self, expression): + return f"X'{self.sql(expression, 'this')}'" diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py new file mode 100644 index 0000000..6cf5022 --- /dev/null +++ b/sqlglot/dialects/sqlite.py @@ -0,0 +1,63 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + arrow_json_extract_scalar_sql, + arrow_json_extract_sql, + no_ilike_sql, + no_tablesample_sql, + no_trycast_sql, + rename_func, +) +from sqlglot.generator 
import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +class SQLite(Dialect): + class Tokenizer(Tokenizer): + IDENTIFIERS = ['"', ("[", "]"), "`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "AUTOINCREMENT": TokenType.AUTO_INCREMENT, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "EDITDIST3": exp.Levenshtein.from_arg_list, + } + + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.BOOLEAN: "INTEGER", + exp.DataType.Type.TINYINT: "INTEGER", + exp.DataType.Type.SMALLINT: "INTEGER", + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.BIGINT: "INTEGER", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.DOUBLE: "REAL", + exp.DataType.Type.DECIMAL: "REAL", + exp.DataType.Type.CHAR: "TEXT", + exp.DataType.Type.NCHAR: "TEXT", + exp.DataType.Type.VARCHAR: "TEXT", + exp.DataType.Type.NVARCHAR: "TEXT", + exp.DataType.Type.BINARY: "BLOB", + } + + TOKEN_MAPPING = { + TokenType.AUTO_INCREMENT: "AUTOINCREMENT", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.ILike: no_ilike_sql, + exp.JSONExtract: arrow_json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONBExtract: arrow_json_extract_sql, + exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.Levenshtein: rename_func("EDITDIST3"), + exp.TableSample: no_tablesample_sql, + exp.TryCast: no_trycast_sql, + } diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py new file mode 100644 index 0000000..b9cd584 --- /dev/null +++ b/sqlglot/dialects/starrocks.py @@ -0,0 +1,12 @@ +from sqlglot import exp +from sqlglot.dialects.mysql import MySQL + + +class StarRocks(MySQL): + class Generator(MySQL.Generator): + TYPE_MAPPING = { + **MySQL.Generator.TYPE_MAPPING, + exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TIMESTAMP: "DATETIME", + exp.DataType.Type.TIMESTAMPTZ: "DATETIME", + } diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py new file mode 100644 index 0000000..e571749 --- /dev/null +++ b/sqlglot/dialects/tableau.py @@ -0,0 +1,37 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser + + +def _if_sql(self, expression): + return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END" + + +def _coalesce_sql(self, expression): + return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})" + + +def _count_sql(self, expression): + this = expression.this + if isinstance(this, exp.Distinct): + return f"COUNTD({self.sql(this, 'this')})" + return f"COUNT({self.sql(expression, 'this')})" + + +class Tableau(Dialect): + class Generator(Generator): + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.If: _if_sql, + exp.Coalesce: _coalesce_sql, + exp.Count: _count_sql, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "IFNULL": exp.Coalesce.from_arg_list, + "COUNTD": lambda args: exp.Count(this=exp.Distinct(this=list_get(args, 0))), + } diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py new file mode 100644 index 0000000..805106c --- /dev/null +++ b/sqlglot/dialects/trino.py @@ -0,0 +1,10 @@ +from sqlglot import exp +from sqlglot.dialects.presto import Presto + + +class Trino(Presto): + class Generator(Presto.Generator): + TRANSFORMS = { + **Presto.Generator.TRANSFORMS, + exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 
'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + } diff --git a/sqlglot/diff.py b/sqlglot/diff.py new file mode 100644 index 0000000..8eeb4e9 --- /dev/null +++ b/sqlglot/diff.py @@ -0,0 +1,314 @@ +from collections import defaultdict +from dataclasses import dataclass +from heapq import heappop, heappush + +from sqlglot import Dialect +from sqlglot import expressions as exp +from sqlglot.helper import ensure_list + + +@dataclass(frozen=True) +class Insert: + """Indicates that a new node has been inserted""" + + expression: exp.Expression + + +@dataclass(frozen=True) +class Remove: + """Indicates that an existing node has been removed""" + + expression: exp.Expression + + +@dataclass(frozen=True) +class Move: + """Indicates that an existing node's position within the tree has changed""" + + expression: exp.Expression + + +@dataclass(frozen=True) +class Update: + """Indicates that an existing node has been updated""" + + source: exp.Expression + target: exp.Expression + + +@dataclass(frozen=True) +class Keep: + """Indicates that an existing node hasn't been changed""" + + source: exp.Expression + target: exp.Expression + + +def diff(source, target): + """ + Returns the list of changes between the source and the target expressions. + + Examples: + >>> diff(parse_one("a + b"), parse_one("a + c")) + [ + Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))), + Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))), + Keep( + source=(ADD this: ...), + target=(ADD this: ...) + ), + Keep( + source=(COLUMN this: (IDENTIFIER this: a, quoted: False)), + target=(COLUMN this: (IDENTIFIER this: a, quoted: False)) + ), + ] + + Args: + source (sqlglot.Expression): the source expression. + target (sqlglot.Expression): the target expression against which the diff should be calculated. + + Returns: + the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees. + This list represents a sequence of steps needed to transform the source expression tree into the target one. + """ + return ChangeDistiller().diff(source.copy(), target.copy()) + + +LEAF_EXPRESSION_TYPES = ( + exp.Boolean, + exp.DataType, + exp.Identifier, + exp.Literal, +) + + +class ChangeDistiller: + """ + The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in + their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by + Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf. 
+ """ + + def __init__(self, f=0.6, t=0.6): + self.f = f + self.t = t + self._sql_generator = Dialect().generator() + + def diff(self, source, target): + self._source = source + self._target = target + self._source_index = {id(n[0]): n[0] for n in source.bfs()} + self._target_index = {id(n[0]): n[0] for n in target.bfs()} + self._unmatched_source_nodes = set(self._source_index) + self._unmatched_target_nodes = set(self._target_index) + self._bigram_histo_cache = {} + + matching_set = self._compute_matching_set() + return self._generate_edit_script(matching_set) + + def _generate_edit_script(self, matching_set): + edit_script = [] + for removed_node_id in self._unmatched_source_nodes: + edit_script.append(Remove(self._source_index[removed_node_id])) + for inserted_node_id in self._unmatched_target_nodes: + edit_script.append(Insert(self._target_index[inserted_node_id])) + for kept_source_node_id, kept_target_node_id in matching_set: + source_node = self._source_index[kept_source_node_id] + target_node = self._target_index[kept_target_node_id] + if ( + not isinstance(source_node, LEAF_EXPRESSION_TYPES) + or source_node == target_node + ): + edit_script.extend( + self._generate_move_edits(source_node, target_node, matching_set) + ) + edit_script.append(Keep(source_node, target_node)) + else: + edit_script.append(Update(source_node, target_node)) + + return edit_script + + def _generate_move_edits(self, source, target, matching_set): + source_args = [id(e) for e in _expression_only_args(source)] + target_args = [id(e) for e in _expression_only_args(target)] + + args_lcs = set( + _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set) + ) + + move_edits = [] + for a in source_args: + if a not in args_lcs and a not in self._unmatched_source_nodes: + move_edits.append(Move(self._source_index[a])) + + return move_edits + + def _compute_matching_set(self): + leaves_matching_set = self._compute_leaf_matching_set() + matching_set = leaves_matching_set.copy() + + ordered_unmatched_source_nodes = { + id(n[0]): None + for n in self._source.bfs() + if id(n[0]) in self._unmatched_source_nodes + } + ordered_unmatched_target_nodes = { + id(n[0]): None + for n in self._target.bfs() + if id(n[0]) in self._unmatched_target_nodes + } + + for source_node_id in ordered_unmatched_source_nodes: + for target_node_id in ordered_unmatched_target_nodes: + source_node = self._source_index[source_node_id] + target_node = self._target_index[target_node_id] + if _is_same_type(source_node, target_node): + source_leaf_ids = {id(l) for l in _get_leaves(source_node)} + target_leaf_ids = {id(l) for l in _get_leaves(target_node)} + + max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids)) + if max_leaves_num: + common_leaves_num = sum( + 1 if s in source_leaf_ids and t in target_leaf_ids else 0 + for s, t in leaves_matching_set + ) + leaf_similarity_score = common_leaves_num / max_leaves_num + else: + leaf_similarity_score = 0.0 + + adjusted_t = ( + self.t + if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 + else 0.4 + ) + + if leaf_similarity_score >= 0.8 or ( + leaf_similarity_score >= adjusted_t + and self._dice_coefficient(source_node, target_node) >= self.f + ): + matching_set.add((source_node_id, target_node_id)) + self._unmatched_source_nodes.remove(source_node_id) + self._unmatched_target_nodes.remove(target_node_id) + ordered_unmatched_target_nodes.pop(target_node_id, None) + break + + return matching_set + + def _compute_leaf_matching_set(self): + candidate_matchings = [] + source_leaves = 
list(_get_leaves(self._source)) + target_leaves = list(_get_leaves(self._target)) + for source_leaf in source_leaves: + for target_leaf in target_leaves: + if _is_same_type(source_leaf, target_leaf): + similarity_score = self._dice_coefficient(source_leaf, target_leaf) + if similarity_score >= self.f: + heappush( + candidate_matchings, + ( + -similarity_score, + len(candidate_matchings), + source_leaf, + target_leaf, + ), + ) + + # Pick best matchings based on the highest score + matching_set = set() + while candidate_matchings: + _, _, source_leaf, target_leaf = heappop(candidate_matchings) + if ( + id(source_leaf) in self._unmatched_source_nodes + and id(target_leaf) in self._unmatched_target_nodes + ): + matching_set.add((id(source_leaf), id(target_leaf))) + self._unmatched_source_nodes.remove(id(source_leaf)) + self._unmatched_target_nodes.remove(id(target_leaf)) + + return matching_set + + def _dice_coefficient(self, source, target): + source_histo = self._bigram_histo(source) + target_histo = self._bigram_histo(target) + + total_grams = sum(source_histo.values()) + sum(target_histo.values()) + if not total_grams: + return 1.0 if source == target else 0.0 + + overlap_len = 0 + overlapping_grams = set(source_histo) & set(target_histo) + for g in overlapping_grams: + overlap_len += min(source_histo[g], target_histo[g]) + + return 2 * overlap_len / total_grams + + def _bigram_histo(self, expression): + if id(expression) in self._bigram_histo_cache: + return self._bigram_histo_cache[id(expression)] + + expression_str = self._sql_generator.generate(expression) + count = max(0, len(expression_str) - 1) + bigram_histo = defaultdict(int) + for i in range(count): + bigram_histo[expression_str[i : i + 2]] += 1 + + self._bigram_histo_cache[id(expression)] = bigram_histo + return bigram_histo + + +def _get_leaves(expression): + has_child_exprs = False + + for a in expression.args.values(): + nodes = ensure_list(a) + for node in nodes: + if isinstance(node, exp.Expression): + has_child_exprs = True + yield from _get_leaves(node) + + if not has_child_exprs: + yield expression + + +def _is_same_type(source, target): + if type(source) is type(target): + if isinstance(source, exp.Join): + return source.args.get("side") == target.args.get("side") + + if isinstance(source, exp.Anonymous): + return source.this == target.this + + return True + + return False + + +def _expression_only_args(expression): + args = [] + if expression: + for a in expression.args.values(): + args.extend(ensure_list(a)) + return [a for a in args if isinstance(a, exp.Expression)] + + +def _lcs(seq_a, seq_b, equal): + """Calculates the longest common subsequence""" + + len_a = len(seq_a) + len_b = len(seq_b) + lcs_result = [[None] * (len_b + 1) for i in range(len_a + 1)] + + for i in range(len_a + 1): + for j in range(len_b + 1): + if i == 0 or j == 0: + lcs_result[i][j] = [] + elif equal(seq_a[i - 1], seq_b[j - 1]): + lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] + else: + lcs_result[i][j] = ( + lcs_result[i - 1][j] + if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) + else lcs_result[i][j - 1] + ) + + return lcs_result[len_a][len_b] diff --git a/sqlglot/errors.py b/sqlglot/errors.py new file mode 100644 index 0000000..89aa935 --- /dev/null +++ b/sqlglot/errors.py @@ -0,0 +1,38 @@ +from enum import auto + +from sqlglot.helper import AutoName + + +class ErrorLevel(AutoName): + IGNORE = auto() # Ignore any parser errors + WARN = auto() # Log any parser errors with ERROR level + RAISE = auto() # Collect all 
parser errors and raise a single exception + IMMEDIATE = auto() # Immediately raise an exception on the first parser error + + +class SqlglotError(Exception): + pass + + +class UnsupportedError(SqlglotError): + pass + + +class ParseError(SqlglotError): + pass + + +class TokenError(SqlglotError): + pass + + +class OptimizeError(SqlglotError): + pass + + +def concat_errors(errors, maximum): + msg = [str(e) for e in errors[:maximum]] + remaining = len(errors) - maximum + if remaining > 0: + msg.append(f"... and {remaining} more") + return "\n\n".join(msg) diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py new file mode 100644 index 0000000..a437431 --- /dev/null +++ b/sqlglot/executor/__init__.py @@ -0,0 +1,39 @@ +import logging +import time + +from sqlglot import parse_one +from sqlglot.executor.python import PythonExecutor +from sqlglot.optimizer import optimize +from sqlglot.planner import Plan + +logger = logging.getLogger("sqlglot") + + +def execute(sql, schema, read=None): + """ + Run a sql query against data. + + Args: + sql (str): a sql statement + schema (dict|sqlglot.optimizer.Schema): database schema. + This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of + the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + read (str): the SQL dialect to apply during parsing + (eg. "spark", "hive", "presto", "mysql"). + Returns: + sqlglot.executor.Table: Simple columnar data structure. + """ + expression = parse_one(sql, read=read) + now = time.time() + expression = optimize(expression, schema) + logger.debug("Optimization finished: %f", time.time() - now) + logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) + plan = Plan(expression) + logger.debug("Logical Plan: %s", plan) + now = time.time() + result = PythonExecutor().execute(plan) + logger.debug("Query finished: %f", time.time() - now) + return result diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py new file mode 100644 index 0000000..457bea7 --- /dev/null +++ b/sqlglot/executor/context.py @@ -0,0 +1,68 @@ +from sqlglot.executor.env import ENV + + +class Context: + """ + Execution context for sql expressions. + + Context is used to hold relevant data tables which can then be queried on with eval. + + References to columns can either be scalar or vectors. When set_row is used, column references + evaluate to scalars while set_range evaluates to vectors. This allows convenient and efficient + evaluation of aggregation functions. 
+ """ + + def __init__(self, tables, env=None): + """ + Args + tables (dict): table_name -> Table, representing the scope of the current execution context + env (Optional[dict]): dictionary of functions within the execution context + """ + self.tables = tables + self.range_readers = { + name: table.range_reader for name, table in self.tables.items() + } + self.row_readers = {name: table.reader for name, table in tables.items()} + self.env = {**(env or {}), "scope": self.row_readers} + + def eval(self, code): + return eval(code, ENV, self.env) + + def eval_tuple(self, codes): + return tuple(self.eval(code) for code in codes) + + def __iter__(self): + return self.table_iter(list(self.tables)[0]) + + def table_iter(self, table): + self.env["scope"] = self.row_readers + + for reader in self.tables[table]: + yield reader, self + + def sort(self, table, key): + table = self.tables[table] + + def sort_key(row): + table.reader.row = row + return self.eval_tuple(key) + + table.rows.sort(key=sort_key) + + def set_row(self, table, row): + self.row_readers[table].row = row + self.env["scope"] = self.row_readers + + def set_index(self, table, index): + self.row_readers[table].row = self.tables[table].rows[index] + self.env["scope"] = self.row_readers + + def set_range(self, table, start, end): + self.range_readers[table].range = range(start, end) + self.env["scope"] = self.range_readers + + def __getitem__(self, table): + return self.env["scope"][table] + + def __contains__(self, table): + return table in self.tables diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py new file mode 100644 index 0000000..72b0558 --- /dev/null +++ b/sqlglot/executor/env.py @@ -0,0 +1,32 @@ +import datetime +import re +import statistics + + +class reverse_key: + def __init__(self, obj): + self.obj = obj + + def __eq__(self, other): + return other.obj == self.obj + + def __lt__(self, other): + return other.obj < self.obj + + +ENV = { + "__builtins__": {}, + "datetime": datetime, + "locals": locals, + "re": re, + "float": float, + "int": int, + "str": str, + "desc": reverse_key, + "SUM": sum, + "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean, + "COUNT": lambda acc: sum(1 for e in acc if e is not None), + "MAX": max, + "MIN": min, + "POW": pow, +} diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py new file mode 100644 index 0000000..388a419 --- /dev/null +++ b/sqlglot/executor/python.py @@ -0,0 +1,360 @@ +import ast +import collections +import itertools + +from sqlglot import exp, planner +from sqlglot.dialects.dialect import Dialect, inline_array_sql +from sqlglot.executor.context import Context +from sqlglot.executor.env import ENV +from sqlglot.executor.table import Table +from sqlglot.generator import Generator +from sqlglot.helper import csv_reader +from sqlglot.tokens import Tokenizer + + +class PythonExecutor: + def __init__(self, env=None): + self.generator = Python().generator(identify=True) + self.env = {**ENV, **(env or {})} + + def execute(self, plan): + running = set() + finished = set() + queue = set(plan.leaves) + contexts = {} + + while queue: + node = queue.pop() + context = self.context( + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } + ) + running.add(node) + + if isinstance(node, planner.Scan): + contexts[node] = self.scan(node, context) + elif isinstance(node, planner.Aggregate): + contexts[node] = self.aggregate(node, context) + elif isinstance(node, planner.Join): + contexts[node] = 
self.join(node, context) + elif isinstance(node, planner.Sort): + contexts[node] = self.sort(node, context) + else: + raise NotImplementedError + + running.remove(node) + finished.add(node) + + for dep in node.dependents: + if dep not in running and all(d in contexts for d in dep.dependencies): + queue.add(dep) + + for dep in node.dependencies: + if all(d in finished for d in dep.dependents): + contexts.pop(dep) + + root = plan.root + return contexts[root].tables[root.name] + + def generate(self, expression): + """Convert a SQL expression into literal Python code and compile it into bytecode.""" + if not expression: + return None + + sql = self.generator.generate(expression) + return compile(sql, sql, "eval", optimize=2) + + def generate_tuple(self, expressions): + """Convert an array of SQL expressions into tuple of Python byte code.""" + if not expressions: + return tuple() + return tuple(self.generate(expression) for expression in expressions) + + def context(self, tables): + return Context(tables, env=self.env) + + def table(self, expressions): + return Table(expression.alias_or_name for expression in expressions) + + def scan(self, step, context): + if hasattr(step, "source"): + source = step.source + + if isinstance(source, exp.Expression): + source = source.this.name or source.alias + else: + source = step.name + condition = self.generate(step.condition) + projections = self.generate_tuple(step.projections) + + if source in context: + if not projections and not condition: + return self.context({step.name: context.tables[source]}) + table_iter = context.table_iter(source) + else: + table_iter = self.scan_csv(step) + + if projections: + sink = self.table(step.projections) + elif source in context: + sink = Table(context[source].columns) + else: + sink = None + + for reader, ctx in table_iter: + if sink is None: + sink = Table(ctx[source].columns) + + if condition and not ctx.eval(condition): + continue + + if projections: + sink.append(ctx.eval_tuple(projections)) + else: + sink.append(reader.row) + + if len(sink) >= step.limit: + break + + return self.context({step.name: sink}) + + def scan_csv(self, step): + source = step.source + alias = source.alias + + with csv_reader(source.this) as reader: + columns = next(reader) + table = Table(columns) + context = self.context({alias: table}) + types = [] + + for row in reader: + if not types: + for v in row: + try: + types.append(type(ast.literal_eval(v))) + except (ValueError, SyntaxError): + types.append(str) + context.set_row(alias, tuple(t(v) for t, v in zip(types, row))) + yield context[alias], context + + def join(self, step, context): + source = step.name + + join_context = self.context({source: context.tables[source]}) + + def merge_context(ctx, table): + # create a new context where all existing tables are mapped to a new one + return self.context({name: table for name in ctx.tables}) + + for name, join in step.joins.items(): + join_context = self.context( + {**join_context.tables, name: context.tables[name]} + ) + + if join.get("source_key"): + table = self.hash_join(join, source, name, join_context) + else: + table = self.nested_loop_join(join, source, name, join_context) + + join_context = merge_context(join_context, table) + + # apply projections or conditions + context = self.scan(step, join_context) + + # use the scan context since it returns a single table + # otherwise there are no projections so all other tables are still in scope + if step.projections: + return context + + return merge_context(join_context, 
context.tables[source]) + + def nested_loop_join(self, _join, a, b, context): + table = Table(context.tables[a].columns + context.tables[b].columns) + + for reader_a, _ in context.table_iter(a): + for reader_b, _ in context.table_iter(b): + table.append(reader_a.row + reader_b.row) + + return table + + def hash_join(self, join, a, b, context): + a_key = self.generate_tuple(join["source_key"]) + b_key = self.generate_tuple(join["join_key"]) + + results = collections.defaultdict(lambda: ([], [])) + + for reader, ctx in context.table_iter(a): + results[ctx.eval_tuple(a_key)][0].append(reader.row) + for reader, ctx in context.table_iter(b): + results[ctx.eval_tuple(b_key)][1].append(reader.row) + + table = Table(context.tables[a].columns + context.tables[b].columns) + for a_group, b_group in results.values(): + for a_row, b_row in itertools.product(a_group, b_group): + table.append(a_row + b_row) + + return table + + def sort_merge_join(self, join, a, b, context): + a_key = self.generate_tuple(join["source_key"]) + b_key = self.generate_tuple(join["join_key"]) + + context.sort(a, a_key) + context.sort(b, b_key) + + a_i = 0 + b_i = 0 + a_n = len(context.tables[a]) + b_n = len(context.tables[b]) + + table = Table(context.tables[a].columns + context.tables[b].columns) + + def get_key(source, key, i): + context.set_index(source, i) + return context.eval_tuple(key) + + while a_i < a_n and b_i < b_n: + key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i)) + + a_group = [] + + while a_i < a_n and key == get_key(a, a_key, a_i): + a_group.append(context[a].row) + a_i += 1 + + b_group = [] + + while b_i < b_n and key == get_key(b, b_key, b_i): + b_group.append(context[b].row) + b_i += 1 + + for a_row, b_row in itertools.product(a_group, b_group): + table.append(a_row + b_row) + + return table + + def aggregate(self, step, context): + source = step.source + group_by = self.generate_tuple(step.group) + aggregations = self.generate_tuple(step.aggregations) + operands = self.generate_tuple(step.operands) + + context.sort(source, group_by) + + if step.operands: + source_table = context.tables[source] + operand_table = Table( + source_table.columns + self.table(step.operands).columns + ) + + for reader, ctx in context: + operand_table.append(reader.row + ctx.eval_tuple(operands)) + + context = self.context({source: operand_table}) + + group = None + start = 0 + end = 1 + length = len(context.tables[source]) + table = self.table(step.group + step.aggregations) + + for i in range(length): + context.set_index(source, i) + key = context.eval_tuple(group_by) + group = key if group is None else group + end += 1 + + if i == length - 1: + context.set_range(source, start, end - 1) + elif key != group: + context.set_range(source, start, end - 2) + else: + continue + + table.append(group + context.eval_tuple(aggregations)) + group = key + start = end - 2 + + return self.scan(step, self.context({source: table})) + + def sort(self, step, context): + table = list(context.tables)[0] + key = self.generate_tuple(step.key) + context.sort(table, key) + return self.scan(step, context) + + +def _cast_py(self, expression): + to = expression.args["to"].this + this = self.sql(expression, "this") + + if to == exp.DataType.Type.DATE: + return f"datetime.date.fromisoformat({this})" + if to == exp.DataType.Type.TEXT: + return f"str({this})" + raise NotImplementedError + + +def _column_py(self, expression): + table = self.sql(expression, "table") + this = self.sql(expression, "this") + return f"scope[{table}][{this}]" + + +def 
_interval_py(self, expression): + this = self.sql(expression, "this") + unit = expression.text("unit").upper() + if unit == "DAY": + return f"datetime.timedelta(days=float({this}))" + raise NotImplementedError + + +def _like_py(self, expression): + this = self.sql(expression, "this") + expression = self.sql(expression, "expression") + return f"""re.match({expression}.replace("_", ".").replace("%", ".*"), {this})""" + + +def _ordered_py(self, expression): + this = self.sql(expression, "this") + desc = expression.args.get("desc") + return f"desc({this})" if desc else this + + +class Python(Dialect): + class Tokenizer(Tokenizer): + ESCAPE = "\\" + + class Generator(Generator): + TRANSFORMS = { + exp.Alias: lambda self, e: self.sql(e.this), + exp.Array: inline_array_sql, + exp.And: lambda self, e: self.binary(e, "and"), + exp.Cast: _cast_py, + exp.Column: _column_py, + exp.EQ: lambda self, e: self.binary(e, "=="), + exp.Interval: _interval_py, + exp.Is: lambda self, e: self.binary(e, "is"), + exp.Like: _like_py, + exp.Not: lambda self, e: f"not {self.sql(e.this)}", + exp.Null: lambda *_: "None", + exp.Or: lambda self, e: self.binary(e, "or"), + exp.Ordered: _ordered_py, + exp.Star: lambda *_: "1", + } + + def case_sql(self, expression): + this = self.sql(expression, "this") + chain = self.sql(expression, "default") or "None" + + for e in reversed(expression.args["ifs"]): + true = self.sql(e, "true") + condition = self.sql(e, "this") + condition = f"{this} = ({condition})" if this else condition + chain = f"{true} if {condition} else ({chain})" + + return chain diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py new file mode 100644 index 0000000..6df49f7 --- /dev/null +++ b/sqlglot/executor/table.py @@ -0,0 +1,81 @@ +class Table: + def __init__(self, *columns, rows=None): + self.columns = tuple(columns if isinstance(columns[0], str) else columns[0]) + self.rows = rows or [] + if rows: + assert len(rows[0]) == len(self.columns) + self.reader = RowReader(self.columns) + self.range_reader = RangeReader(self) + + def append(self, row): + assert len(row) == len(self.columns) + self.rows.append(row) + + def pop(self): + self.rows.pop() + + @property + def width(self): + return len(self.columns) + + def __len__(self): + return len(self.rows) + + def __iter__(self): + return TableIter(self) + + def __getitem__(self, index): + self.reader.row = self.rows[index] + return self.reader + + def __repr__(self): + widths = {column: len(column) for column in self.columns} + lines = [" ".join(column for column in self.columns)] + + for i, row in enumerate(self): + if i > 10: + break + + lines.append( + " ".join( + str(row[column]).rjust(widths[column])[0 : widths[column]] + for column in self.columns + ) + ) + return "\n".join(lines) + + +class TableIter: + def __init__(self, table): + self.table = table + self.index = -1 + + def __iter__(self): + return self + + def __next__(self): + self.index += 1 + if self.index < len(self.table): + return self.table[self.index] + raise StopIteration + + +class RangeReader: + def __init__(self, table): + self.table = table + self.range = range(0) + + def __len__(self): + return len(self.range) + + def __getitem__(self, column): + return (self.table[i][column] for i in self.range) + + +class RowReader: + def __init__(self, columns): + self.columns = {column: i for i, column in enumerate(columns)} + self.row = None + + def __getitem__(self, column): + return self.row[self.columns[column]] diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py new file 
mode 100644 index 0000000..7acc63d --- /dev/null +++ b/sqlglot/expressions.py @@ -0,0 +1,2945 @@ +import inspect +import re +import sys +from collections import deque +from copy import deepcopy +from enum import auto + +from sqlglot.errors import ParseError +from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list + + +class _Expression(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + klass.key = clsname.lower() + return klass + + +class Expression(metaclass=_Expression): + """ + The base class for all expressions in a syntax tree. + + Attributes: + arg_types (dict): determines arguments supported by this expression. + The key in a dictionary defines a unique key of an argument using + which the argument's value can be retrieved. The value is a boolean + flag which indicates whether the argument's value is required (True) + or optional (False). + """ + + key = None + arg_types = {"this": True} + __slots__ = ("args", "parent", "arg_key") + + def __init__(self, **args): + self.args = args + self.parent = None + self.arg_key = None + + for arg_key, value in self.args.items(): + self._set_parent(arg_key, value) + + def __eq__(self, other): + return type(self) is type(other) and _norm_args(self) == _norm_args(other) + + def __hash__(self): + return hash( + ( + self.key, + tuple( + (k, tuple(v) if isinstance(v, list) else v) + for k, v in _norm_args(self).items() + ), + ) + ) + + @property + def this(self): + return self.args.get("this") + + @property + def expression(self): + return self.args.get("expression") + + @property + def expressions(self): + return self.args.get("expressions") or [] + + def text(self, key): + field = self.args.get(key) + if isinstance(field, str): + return field + if isinstance(field, (Identifier, Literal, Var)): + return field.this + return "" + + @property + def is_string(self): + return isinstance(self, Literal) and self.args["is_string"] + + @property + def is_number(self): + return isinstance(self, Literal) and not self.args["is_string"] + + @property + def is_int(self): + if self.is_number: + try: + int(self.name) + return True + except ValueError: + pass + return False + + @property + def alias(self): + if isinstance(self.args.get("alias"), TableAlias): + return self.args["alias"].name + return self.text("alias") + + @property + def name(self): + return self.text("this") + + @property + def alias_or_name(self): + return self.alias or self.name + + def __deepcopy__(self, memo): + return self.__class__(**deepcopy(self.args)) + + def copy(self): + new = deepcopy(self) + for item, parent, _ in new.bfs(): + if isinstance(item, Expression) and parent: + item.parent = parent + return new + + def set(self, arg_key, value): + """ + Sets `arg` to `value`. + + Args: + arg_key (str): name of the expression arg + value: value to set the arg to. + """ + self.args[arg_key] = value + self._set_parent(arg_key, value) + + def _set_parent(self, arg_key, value): + if isinstance(value, Expression): + value.parent = self + value.arg_key = arg_key + elif isinstance(value, list): + for v in value: + if isinstance(v, Expression): + v.parent = self + v.arg_key = arg_key + + @property + def depth(self): + """ + Returns the depth of this tree. + """ + if self.parent: + return self.parent.depth + 1 + return 0 + + def find(self, *expression_types, bfs=True): + """ + Returns the first node in this tree which matches at least one of + the specified types. + + Args: + expression_types (type): the expression type(s) to match. 
+ + Returns: + the node which matches the criteria or None if no node matching + the criteria was found. + """ + return next(self.find_all(*expression_types, bfs=bfs), None) + + def find_all(self, *expression_types, bfs=True): + """ + Returns a generator object which visits all nodes in this tree and only + yields those that match at least one of the specified expression types. + + Args: + expression_types (type): the expression type(s) to match. + + Returns: + the generator object. + """ + for expression, _, _ in self.walk(bfs=bfs): + if isinstance(expression, expression_types): + yield expression + + def find_ancestor(self, *expression_types): + """ + Returns a nearest parent matching expression_types. + + Args: + expression_types (type): the expression type(s) to match. + + Returns: + the parent node + """ + ancestor = self.parent + while ancestor and not isinstance(ancestor, expression_types): + ancestor = ancestor.parent + return ancestor + + @property + def parent_select(self): + """ + Returns the parent select statement. + """ + return self.find_ancestor(Select) + + def walk(self, bfs=True): + """ + Returns a generator object which visits all nodes in this tree. + + Args: + bfs (bool): if set to True the BFS traversal order will be applied, + otherwise the DFS traversal will be used instead. + + Returns: + the generator object. + """ + if bfs: + yield from self.bfs() + else: + yield from self.dfs() + + def dfs(self, parent=None, key=None, prune=None): + """ + Returns a generator object which visits all nodes in this tree in + the DFS (Depth-first) order. + + Returns: + the generator object. + """ + parent = parent or self.parent + yield self, parent, key + if prune and prune(self, parent, key): + return + + for k, v in self.args.items(): + nodes = ensure_list(v) + + for node in nodes: + if isinstance(node, Expression): + yield from node.dfs(self, k, prune) + + def bfs(self, prune=None): + """ + Returns a generator object which visits all nodes in this tree in + the BFS (Breadth-first) order. + + Returns: + the generator object. + """ + queue = deque([(self, self.parent, None)]) + + while queue: + item, parent, key = queue.popleft() + + yield item, parent, key + if prune and prune(item, parent, key): + continue + + if isinstance(item, Expression): + for k, v in item.args.items(): + nodes = ensure_list(v) + + for node in nodes: + if isinstance(node, Expression): + queue.append((node, item, k)) + + def unnest(self): + """ + Returns the first non parenthesis child or self. + """ + expression = self + while isinstance(expression, Paren): + expression = expression.this + return expression + + def unnest_operands(self): + """ + Returns unnested operands as a tuple. + """ + return tuple(arg.unnest() for arg in self.args.values() if arg) + + def flatten(self, unnest=True): + """ + Returns a generator which yields child nodes who's parents are the same class. + + A AND B AND C -> [A, B, C] + """ + for node, _, _ in self.dfs( + prune=lambda n, p, *_: p and not isinstance(n, self.__class__) + ): + if not isinstance(node, self.__class__): + yield node.unnest() if unnest else node + + def __str__(self): + return self.sql() + + def __repr__(self): + return self.to_s() + + def sql(self, dialect=None, **opts): + """ + Returns SQL string representation of this tree. + + Args + dialect (str): the dialect of the output SQL string + (eg. "spark", "hive", "presto", "mysql"). + opts (dict): other :class:`~sqlglot.generator.Generator` options. + + Returns + the SQL string. 
+ """ + from sqlglot.dialects import Dialect + + return Dialect.get_or_raise(dialect)().generate(self, **opts) + + def to_s(self, hide_missing=True, level=0): + indent = "" if not level else "\n" + indent += "".join([" "] * level) + left = f"({self.key.upper()} " + + args = { + k: ", ".join( + v.to_s(hide_missing=hide_missing, level=level + 1) + if hasattr(v, "to_s") + else str(v) + for v in ensure_list(vs) + if v is not None + ) + for k, vs in self.args.items() + } + args = {k: v for k, v in args.items() if v or not hide_missing} + + right = ", ".join(f"{k}: {v}" for k, v in args.items()) + right += ")" + + return indent + left + right + + def transform(self, fun, *args, copy=True, **kwargs): + """ + Recursively visits all tree nodes (excluding already transformed ones) + and applies the given transformation function to each node. + + Args: + fun (function): a function which takes a node as an argument and returns a + new transformed node or the same node without modifications. + copy (bool): if set to True a new tree instance is constructed, otherwise the tree is + modified in place. + + Returns: + the transformed tree. + """ + node = self.copy() if copy else self + new_node = fun(node, *args, **kwargs) + + if new_node is None: + raise ValueError("A transformed node cannot be None") + if not isinstance(new_node, Expression): + return new_node + if new_node is not node: + new_node.parent = node.parent + return new_node + + replace_children( + new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs) + ) + return new_node + + def replace(self, expression): + """ + Swap out this expression with a new expression. + + For example:: + + >>> tree = Select().select("x").from_("tbl") + >>> tree.find(Column).replace(Column(this="y")) + (COLUMN this: y) + >>> tree.sql() + 'SELECT y FROM tbl' + + Args: + expression (Expression): new node + + Returns : + the new expression or expressions + """ + if not self.parent: + return expression + + parent = self.parent + self.parent = None + + replace_children(parent, lambda child: expression if child is self else child) + return expression + + def assert_is(self, type_): + """ + Assert that this `Expression` is an instance of `type_`. + + If it is NOT an instance of `type_`, this raises an assertion error. + Otherwise, this returns this expression. + + Examples: + This is useful for type security in chained expressions: + + >>> import sqlglot + >>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql() + 'SELECT x, z FROM y' + """ + assert isinstance(self, type_) + return self + + +class Condition(Expression): + def and_(self, *expressions, dialect=None, **opts): + """ + AND this condition with one or multiple expressions. + + Example: + >>> condition("x=1").and_("y=1").sql() + 'x = 1 AND y = 1' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + And: the new condition. + """ + return and_(self, *expressions, dialect=dialect, **opts) + + def or_(self, *expressions, dialect=None, **opts): + """ + OR this condition with one or multiple expressions. + + Example: + >>> condition("x=1").or_("y=1").sql() + 'x = 1 OR y = 1' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. 
+ dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Or: the new condition. + """ + return or_(self, *expressions, dialect=dialect, **opts) + + def not_(self): + """ + Wrap this condition with NOT. + + Example: + >>> condition("x=1").not_().sql() + 'NOT x = 1' + + Returns: + Not: the new condition. + """ + return not_(self) + + +class Predicate(Condition): + """Relationships like x = y, x > 1, x >= y.""" + + +class DerivedTable(Expression): + @property + def alias_column_names(self): + table_alias = self.args.get("alias") + if not table_alias: + return [] + column_list = table_alias.assert_is(TableAlias).args.get("columns") or [] + return [c.name for c in column_list] + + @property + def selects(self): + alias = self.args.get("alias") + + if alias: + return alias.columns + return [] + + @property + def named_selects(self): + return [select.alias_or_name for select in self.selects] + + +class Annotation(Expression): + arg_types = { + "this": True, + "expression": True, + } + + +class Cache(Expression): + arg_types = { + "with": False, + "this": True, + "lazy": False, + "options": False, + "expression": False, + } + + +class Uncache(Expression): + arg_types = {"this": True, "exists": False} + + +class Create(Expression): + arg_types = { + "with": False, + "this": True, + "kind": True, + "expression": False, + "exists": False, + "properties": False, + "temporary": False, + "replace": False, + "unique": False, + } + + +class CharacterSet(Expression): + arg_types = {"this": True, "default": False} + + +class With(Expression): + arg_types = {"expressions": True, "recursive": False} + + +class WithinGroup(Expression): + arg_types = {"this": True, "expression": False} + + +class CTE(DerivedTable): + arg_types = {"this": True, "alias": True} + + +class TableAlias(Expression): + arg_types = {"this": False, "columns": False} + + @property + def columns(self): + return self.args.get("columns") or [] + + +class BitString(Condition): + pass + + +class Column(Condition): + arg_types = {"this": True, "table": False} + + @property + def table(self): + return self.text("table") + + +class ColumnDef(Expression): + arg_types = { + "this": True, + "kind": True, + "constraints": False, + } + + +class ColumnConstraint(Expression): + arg_types = {"this": False, "kind": True} + + +class AutoIncrementColumnConstraint(Expression): + pass + + +class CheckColumnConstraint(Expression): + pass + + +class CollateColumnConstraint(Expression): + pass + + +class CommentColumnConstraint(Expression): + pass + + +class DefaultColumnConstraint(Expression): + pass + + +class NotNullColumnConstraint(Expression): + pass + + +class PrimaryKeyColumnConstraint(Expression): + pass + + +class UniqueColumnConstraint(Expression): + pass + + +class Constraint(Expression): + arg_types = {"this": True, "expressions": True} + + +class Delete(Expression): + arg_types = {"with": False, "this": True, "where": False} + + +class Drop(Expression): + arg_types = {"this": False, "kind": False, "exists": False} + + +class Filter(Expression): + arg_types = {"this": True, "expression": True} + + +class Check(Expression): + pass + + +class ForeignKey(Expression): + arg_types = { + "expressions": True, + "reference": False, + "delete": False, + "update": False, + } + + +class Unique(Expression): + arg_types = {"expressions": True} + + +class From(Expression): + arg_types = {"expressions": True} + + +class Having(Expression): + pass + + +class 
Hint(Expression): + arg_types = {"expressions": True} + + +class Identifier(Expression): + arg_types = {"this": True, "quoted": False} + + @property + def quoted(self): + return bool(self.args.get("quoted")) + + def __eq__(self, other): + return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg( + other.this + ) + + def __hash__(self): + return hash((self.key, self.this.lower())) + + +class Index(Expression): + arg_types = {"this": False, "table": False, "where": False, "columns": False} + + +class Insert(Expression): + arg_types = { + "with": False, + "this": True, + "expression": True, + "overwrite": False, + "exists": False, + "partition": False, + } + + +# https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html +class Introducer(Expression): + arg_types = {"this": True, "expression": True} + + +class Partition(Expression): + pass + + +class Fetch(Expression): + arg_types = {"direction": False, "count": True} + + +class Group(Expression): + arg_types = { + "expressions": False, + "grouping_sets": False, + "cube": False, + "rollup": False, + } + + +class Lambda(Expression): + arg_types = {"this": True, "expressions": True} + + +class Limit(Expression): + arg_types = {"this": False, "expression": True} + + +class Literal(Condition): + arg_types = {"this": True, "is_string": True} + + def __eq__(self, other): + return ( + isinstance(other, Literal) + and self.this == other.this + and self.args["is_string"] == other.args["is_string"] + ) + + def __hash__(self): + return hash((self.key, self.this, self.args["is_string"])) + + @classmethod + def number(cls, number): + return cls(this=str(number), is_string=False) + + @classmethod + def string(cls, string): + return cls(this=str(string), is_string=True) + + +class Join(Expression): + arg_types = { + "this": True, + "on": False, + "side": False, + "kind": False, + "using": False, + } + + @property + def kind(self): + return self.text("kind").upper() + + @property + def side(self): + return self.text("side").upper() + + def on(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the ON expressions. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("JOIN x", into=Join).on("y = 1").sql() + 'JOIN x ON y = 1' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append (bool): if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Join: the modified join expression. 
+ """ + join = _apply_conjunction_builder( + *expressions, + instance=self, + arg="on", + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + if join.kind == "CROSS": + join.set("kind", None) + + return join + + +class Lateral(DerivedTable): + arg_types = {"this": True, "outer": False, "alias": False} + + +# Clickhouse FROM FINAL modifier +# https://clickhouse.com/docs/en/sql-reference/statements/select/from/#final-modifier +class Final(Expression): + pass + + +class Offset(Expression): + arg_types = {"this": False, "expression": True} + + +class Order(Expression): + arg_types = {"this": False, "expressions": True} + + +# hive specific sorts +# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+SortBy +class Cluster(Order): + pass + + +class Distribute(Order): + pass + + +class Sort(Order): + pass + + +class Ordered(Expression): + arg_types = {"this": True, "desc": True, "nulls_first": True} + + +class Properties(Expression): + arg_types = {"expressions": True} + + +class Property(Expression): + arg_types = {"this": True, "value": True} + + +class TableFormatProperty(Property): + pass + + +class PartitionedByProperty(Property): + pass + + +class FileFormatProperty(Property): + pass + + +class LocationProperty(Property): + pass + + +class EngineProperty(Property): + pass + + +class AutoIncrementProperty(Property): + pass + + +class CharacterSetProperty(Property): + arg_types = {"this": True, "value": True, "default": True} + + +class CollateProperty(Property): + pass + + +class SchemaCommentProperty(Property): + pass + + +class AnonymousProperty(Property): + pass + + +class Qualify(Expression): + pass + + +class Reference(Expression): + arg_types = {"this": True, "expressions": True} + + +class Table(Expression): + arg_types = {"this": True, "db": False, "catalog": False} + + +class Tuple(Expression): + arg_types = {"expressions": False} + + +class Subqueryable: + def subquery(self, alias=None, copy=True): + """ + Convert this expression to an aliased expression that can be used as a Subquery. + + Example: + >>> subquery = Select().select("x").from_("tbl").subquery() + >>> Select().select("x").from_(subquery).sql() + 'SELECT x FROM (SELECT x FROM tbl)' + + Args: + alias (str or Identifier): an optional alias for the subquery + copy (bool): if `False`, modify this expression instance in-place. + + Returns: + Alias: the subquery + """ + instance = _maybe_copy(self, copy) + return Subquery( + this=instance, + alias=TableAlias(this=to_identifier(alias)), + ) + + @property + def ctes(self): + with_ = self.args.get("with") + if not with_: + return [] + return with_.expressions + + def with_( + self, + alias, + as_, + recursive=None, + append=True, + dialect=None, + copy=True, + **opts, + ): + """ + Append to or set the common table expressions. + + Example: + >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql() + 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' + + Args: + alias (str or Expression): the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_ (str or Expression): the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect (str): the dialect used to parse the input expression. 
+ copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + alias_expression = maybe_parse( + alias, + dialect=dialect, + into=TableAlias, + **opts, + ) + as_expression = maybe_parse( + as_, + dialect=dialect, + **opts, + ) + cte = CTE( + this=as_expression, + alias=alias_expression, + ) + return _apply_child_list_builder( + cte, + instance=self, + arg="with", + append=append, + copy=copy, + into=With, + properties={"recursive": recursive or False}, + ) + + +QUERY_MODIFIERS = { + "laterals": False, + "joins": False, + "where": False, + "group": False, + "having": False, + "qualify": False, + "window": False, + "distribute": False, + "sort": False, + "cluster": False, + "order": False, + "limit": False, + "offset": False, +} + + +class Union(Subqueryable, Expression): + arg_types = { + "with": False, + "this": True, + "expression": True, + "distinct": False, + **QUERY_MODIFIERS, + } + + @property + def named_selects(self): + return self.args["this"].unnest().named_selects + + @property + def left(self): + return self.this + + @property + def right(self): + return self.expression + + +class Except(Union): + pass + + +class Intersect(Union): + pass + + +class Unnest(DerivedTable): + arg_types = { + "expressions": True, + "ordinality": False, + "alias": False, + } + + +class Update(Expression): + arg_types = { + "with": False, + "this": True, + "expressions": True, + "from": False, + "where": False, + } + + +class Values(Expression): + arg_types = {"expressions": True} + + +class Var(Expression): + pass + + +class Schema(Expression): + arg_types = {"this": False, "expressions": True} + + +class Select(Subqueryable, Expression): + arg_types = { + "with": False, + "expressions": False, + "hint": False, + "distinct": False, + "from": False, + **QUERY_MODIFIERS, + } + + def from_(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the FROM expression. + + Example: + >>> Select().from_("tbl").select("x").sql() + 'SELECT x FROM tbl' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If a `From` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `From`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `From` expression into a single expression. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="from", + append=append, + copy=copy, + prefix="FROM", + into=From, + dialect=dialect, + **opts, + ) + + def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the GROUP BY expression. + + Example: + >>> Select().from_("tbl").select("x", "COUNT(1)").group_by("x").sql() + 'SELECT x, COUNT(1) FROM tbl GROUP BY x' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Group`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `Group` expression into a single expression. 
+ dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="group", + append=append, + copy=copy, + prefix="GROUP BY", + into=Group, + dialect=dialect, + **opts, + ) + + def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the ORDER BY expression. + + Example: + >>> Select().from_("tbl").select("x").order_by("x DESC").sql() + 'SELECT x FROM tbl ORDER BY x DESC' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Order` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in an `Order`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `Order` expressions into a single expression. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="order", + append=append, + copy=copy, + prefix="ORDER BY", + into=Order, + dialect=dialect, + **opts, + ) + + def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the SORT BY expression. + + Example: + >>> Select().from_("tbl").select("x").sort_by("x DESC").sql() + 'SELECT x FROM tbl SORT BY x DESC' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If a `Sort` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Sort`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `Sort` expressions into a single expression. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="sort", + append=append, + copy=copy, + prefix="SORT BY", + into=Sort, + dialect=dialect, + **opts, + ) + + def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the CLUSTER BY expression. + + Example: + >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql() + 'SELECT x FROM tbl CLUSTER BY x DESC' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If a `Cluster` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Cluster`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `Cluster` expressions into a single expression. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression.
+ """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="cluster", + append=append, + copy=copy, + prefix="CLUSTER BY", + into=Cluster, + dialect=dialect, + **opts, + ) + + def limit(self, expression, dialect=None, copy=True, **opts): + """ + Set the LIMIT expression. + + Example: + >>> Select().from_("tbl").select("x").limit(10).sql() + 'SELECT x FROM tbl LIMIT 10' + + Args: + expression (str or int or Expression): the SQL code string to parse. + This can also be an integer. + If a `Limit` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Limit`. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="limit", + into=Limit, + prefix="LIMIT", + dialect=dialect, + copy=copy, + **opts, + ) + + def offset(self, expression, dialect=None, copy=True, **opts): + """ + Set the OFFSET expression. + + Example: + >>> Select().from_("tbl").select("x").offset(10).sql() + 'SELECT x FROM tbl OFFSET 10' + + Args: + expression (str or int or Expression): the SQL code string to parse. + This can also be an integer. + If a `Offset` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Offset`. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="offset", + into=Offset, + prefix="OFFSET", + dialect=dialect, + copy=copy, + **opts, + ) + + def select(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the SELECT expressions. + + Example: + >>> Select().select("x", "y").sql() + 'SELECT x, y' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append (bool): if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_list_builder( + *expressions, + instance=self, + arg="expressions", + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the LATERAL expressions. + + Example: + >>> Select().select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl").sql() + 'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append (bool): if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. 
+ + Returns: + Select: the modified expression. + """ + return _apply_list_builder( + *expressions, + instance=self, + arg="laterals", + append=append, + into=Lateral, + prefix="LATERAL VIEW", + dialect=dialect, + copy=copy, + **opts, + ) + + def join( + self, + expression, + on=None, + append=True, + join_type=None, + join_alias=None, + dialect=None, + copy=True, + **opts, + ): + """ + Append to or set the JOIN expressions. + + Example: + >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y").sql() + 'SELECT * FROM tbl JOIN tbl2 ON tbl1.y = tbl2.y' + + Use `join_type` to change the type of join: + + >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y", join_type="left outer").sql() + 'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y' + + Args: + expression (str or Expression): the SQL code string to parse. + If an `Expression` instance is passed, it will be used as-is. + on (str or Expression): optionally specify the join criteria as a SQL string. + If an `Expression` instance is passed, it will be used as-is. + append (bool): if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + join_type (str): if set, alter the parsed join type + join_alias (str or Identifier): an optional alias for the joined source + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + parse_args = {"dialect": dialect, **opts} + + try: + expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) + except ParseError: + expression = maybe_parse(expression, into=(Join, Expression), **parse_args) + + join = expression if isinstance(expression, Join) else Join(this=expression) + + if isinstance(join.this, Select): + join.this.replace(join.this.subquery()) + + if join_type: + side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + if side: + join.set("side", side.text) + if kind: + join.set("kind", kind.text) + + if on: + on = and_(*ensure_list(on), dialect=dialect, **opts) + join.set("on", on) + + if join_alias: + join.set("this", alias_(join.args["this"], join_alias, table=True)) + return _apply_list_builder( + join, + instance=self, + arg="joins", + append=append, + copy=copy, + **opts, + ) + + def where(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the WHERE expressions. + + Example: + >>> Select().select("x").from_("tbl").where("x = 'a' OR x < 'b'").sql() + "SELECT x FROM tbl WHERE x = 'a' OR x < 'b'" + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append (bool): if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="where", + append=append, + into=Where, + dialect=dialect, + copy=copy, + **opts, + ) + + def having(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the HAVING expressions.
+ + Example: + >>> Select().select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 3").sql() + 'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append (bool): if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="having", + append=append, + into=Having, + dialect=dialect, + copy=copy, + **opts, + ) + + def distinct(self, distinct=True, copy=True): + """ + Set the OFFSET expression. + + Example: + >>> Select().from_("tbl").select("x").distinct().sql() + 'SELECT DISTINCT x FROM tbl' + + Args: + distinct (bool): whether the Select should be distinct + copy (bool): if `False`, modify this expression instance in-place. + + Returns: + Select: the modified expression. + """ + instance = _maybe_copy(self, copy) + instance.set("distinct", Distinct() if distinct else None) + return instance + + def ctas(self, table, properties=None, dialect=None, copy=True, **opts): + """ + Convert this expression to a CREATE TABLE AS statement. + + Example: + >>> Select().select("*").from_("tbl").ctas("x").sql() + 'CREATE TABLE x AS SELECT * FROM tbl' + + Args: + table (str or Expression): the SQL code string to parse as the table name. + If another `Expression` instance is passed, it will be used as-is. + properties (dict): an optional mapping of table properties + dialect (str): the dialect used to parse the input table. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input table. + + Returns: + Create: the CREATE TABLE AS expression + """ + instance = _maybe_copy(self, copy) + table_expression = maybe_parse( + table, + into=Table, + dialect=dialect, + **opts, + ) + properties_expression = None + if properties: + properties_str = " ".join( + [ + f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" + for k, v in properties.items() + ] + ) + properties_expression = maybe_parse( + properties_str, + into=Properties, + dialect=dialect, + **opts, + ) + + return Create( + this=table_expression, + kind="table", + expression=instance, + properties=properties_expression, + ) + + @property + def named_selects(self): + return [e.alias_or_name for e in self.expressions if e.alias_or_name] + + @property + def selects(self): + return self.expressions + + +class Subquery(DerivedTable): + arg_types = { + "this": True, + "alias": False, + **QUERY_MODIFIERS, + } + + def unnest(self): + """ + Returns the first non subquery. 
+ """ + expression = self + while isinstance(expression, Subquery): + expression = expression.this + return expression + + +class TableSample(Expression): + arg_types = { + "this": False, + "method": False, + "bucket_numerator": False, + "bucket_denominator": False, + "bucket_field": False, + "percent": False, + "rows": False, + "size": False, + } + + +class Window(Expression): + arg_types = { + "this": True, + "partition_by": False, + "order": False, + "spec": False, + "alias": False, + } + + +class WindowSpec(Expression): + arg_types = { + "kind": False, + "start": False, + "start_side": False, + "end": False, + "end_side": False, + } + + +class Where(Expression): + pass + + +class Star(Expression): + arg_types = {"except": False, "replace": False} + + @property + def name(self): + return "*" + + +class Placeholder(Expression): + arg_types = {} + + +class Null(Condition): + arg_types = {} + + +class Boolean(Condition): + pass + + +class DataType(Expression): + arg_types = { + "this": True, + "expressions": False, + "nested": False, + } + + class Type(AutoName): + CHAR = auto() + NCHAR = auto() + VARCHAR = auto() + NVARCHAR = auto() + TEXT = auto() + BINARY = auto() + INT = auto() + TINYINT = auto() + SMALLINT = auto() + BIGINT = auto() + FLOAT = auto() + DOUBLE = auto() + DECIMAL = auto() + BOOLEAN = auto() + JSON = auto() + TIMESTAMP = auto() + TIMESTAMPTZ = auto() + DATE = auto() + DATETIME = auto() + ARRAY = auto() + MAP = auto() + UUID = auto() + GEOGRAPHY = auto() + STRUCT = auto() + NULLABLE = auto() + + @classmethod + def build(cls, dtype, **kwargs): + return DataType( + this=dtype + if isinstance(dtype, DataType.Type) + else DataType.Type[dtype.upper()], + **kwargs, + ) + + +class StructKwarg(Expression): + arg_types = {"this": True, "expression": True} + + +# WHERE x EXISTS|ALL|ANY|SOME(SELECT ...) 
+class SubqueryPredicate(Predicate): + pass + + +class All(SubqueryPredicate): + pass + + +class Any(SubqueryPredicate): + pass + + +class Exists(SubqueryPredicate): + pass + + +# Commands to interact with the databases or engines +# These expressions don't truly parse the expression and consume +# whatever exists as a string until the end or a semicolon +class Command(Expression): + arg_types = {"this": True, "expression": False} + + +# Binary Expressions +# (ADD a b) +# (FROM table selects) +class Binary(Expression): + arg_types = {"this": True, "expression": True} + + @property + def left(self): + return self.this + + @property + def right(self): + return self.expression + + +class Add(Binary): + pass + + +class Connector(Binary, Condition): + pass + + +class And(Connector): + pass + + +class Or(Connector): + pass + + +class BitwiseAnd(Binary): + pass + + +class BitwiseLeftShift(Binary): + pass + + +class BitwiseOr(Binary): + pass + + +class BitwiseRightShift(Binary): + pass + + +class BitwiseXor(Binary): + pass + + +class Div(Binary): + pass + + +class Dot(Binary): + pass + + +class DPipe(Binary): + pass + + +class EQ(Binary, Predicate): + pass + + +class Escape(Binary): + pass + + +class GT(Binary, Predicate): + pass + + +class GTE(Binary, Predicate): + pass + + +class ILike(Binary, Predicate): + pass + + +class IntDiv(Binary): + pass + + +class Is(Binary, Predicate): + pass + + +class Like(Binary, Predicate): + pass + + +class LT(Binary, Predicate): + pass + + +class LTE(Binary, Predicate): + pass + + +class Mod(Binary): + pass + + +class Mul(Binary): + pass + + +class NEQ(Binary, Predicate): + pass + + +class Sub(Binary): + pass + + +# Unary Expressions +# (NOT a) +class Unary(Expression): + pass + + +class BitwiseNot(Unary): + pass + + +class Not(Unary, Condition): + pass + + +class Paren(Unary, Condition): + pass + + +class Neg(Unary): + pass + + +# Special Functions +class Alias(Expression): + arg_types = {"this": True, "alias": False} + + +class Aliases(Expression): + arg_types = {"this": True, "expressions": True} + + @property + def aliases(self): + return self.expressions + + +class AtTimeZone(Expression): + arg_types = {"this": True, "zone": True} + + +class Between(Predicate): + arg_types = {"this": True, "low": True, "high": True} + + +class Bracket(Condition): + arg_types = {"this": True, "expressions": True} + + +class Distinct(Expression): + arg_types = {"this": False, "on": False} + + +class In(Predicate): + arg_types = {"this": True, "expressions": False, "query": False, "unnest": False} + + +class TimeUnit(Expression): + """Automatically converts unit arg into a var.""" + + arg_types = {"unit": False} + + def __init__(self, **args): + unit = args.get("unit") + if isinstance(unit, Column): + args["unit"] = Var(this=unit.name) + elif isinstance(unit, Week): + unit.set("this", Var(this=unit.this.name)) + super().__init__(**args) + + +class Interval(TimeUnit): + arg_types = {"this": True, "unit": False} + + +class IgnoreNulls(Expression): + pass + + +# Functions +class Func(Condition): + """ + The base class for all function expressions. + + Attributes + is_var_len_args (bool): if set to True the last argument defined in + arg_types will be treated as a variable length argument and the + argument's value will be stored as a list. + _sql_names (list): determines the SQL name (1st item in the list) and + aliases (subsequent items) for this function expression. 
These + values are used to map this node to a name during parsing as well + as to provide the function's name during SQL string generation. By + default the SQL name is set to the expression's class name transformed + to snake case. + """ + + is_var_len_args = False + + @classmethod + def from_arg_list(cls, args): + args_num = len(args) + + all_arg_keys = list(cls.arg_types) + # If this function supports variable length argument treat the last argument as such. + non_var_len_arg_keys = ( + all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys + ) + + args_dict = {} + arg_idx = 0 + for arg_key in non_var_len_arg_keys: + if arg_idx >= args_num: + break + if args[arg_idx] is not None: + args_dict[arg_key] = args[arg_idx] + arg_idx += 1 + + if arg_idx < args_num and cls.is_var_len_args: + args_dict[all_arg_keys[-1]] = args[arg_idx:] + return cls(**args_dict) + + @classmethod + def sql_names(cls): + if cls is Func: + raise NotImplementedError( + "SQL name is only supported by concrete function implementations" + ) + if not hasattr(cls, "_sql_names"): + cls._sql_names = [camel_to_snake_case(cls.__name__)] + return cls._sql_names + + @classmethod + def sql_name(cls): + return cls.sql_names()[0] + + @classmethod + def default_parser_mappings(cls): + return {name: cls.from_arg_list for name in cls.sql_names()} + + +class AggFunc(Func): + pass + + +class Abs(Func): + pass + + +class Anonymous(Func): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +class ApproxDistinct(AggFunc): + arg_types = {"this": True, "accuracy": False} + + +class Array(Func): + arg_types = {"expressions": False} + is_var_len_args = True + + +class ArrayAgg(AggFunc): + pass + + +class ArrayAll(Func): + arg_types = {"this": True, "expression": True} + + +class ArrayAny(Func): + arg_types = {"this": True, "expression": True} + + +class ArrayContains(Func): + arg_types = {"this": True, "expression": True} + + +class ArrayFilter(Func): + arg_types = {"this": True, "expression": True} + _sql_names = ["FILTER", "ARRAY_FILTER"] + + +class ArraySize(Func): + pass + + +class ArraySort(Func): + arg_types = {"this": True, "expression": False} + + +class ArraySum(Func): + pass + + +class ArrayUnionAgg(AggFunc): + pass + + +class Avg(AggFunc): + pass + + +class AnyValue(AggFunc): + pass + + +class Case(Func): + arg_types = {"this": False, "ifs": True, "default": False} + + +class Cast(Func): + arg_types = {"this": True, "to": True} + + +class TryCast(Cast): + pass + + +class Ceil(Func): + _sql_names = ["CEIL", "CEILING"] + + +class Coalesce(Func): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +class ConcatWs(Func): + arg_types = {"expressions": False} + is_var_len_args = True + + +class Count(AggFunc): + pass + + +class CurrentDate(Func): + arg_types = {"this": False} + + +class CurrentDatetime(Func): + arg_types = {"this": False} + + +class CurrentTime(Func): + arg_types = {"this": False} + + +class CurrentTimestamp(Func): + arg_types = {"this": False} + + +class DateAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DateSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DateDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DateTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class DatetimeAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class 
DatetimeSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DatetimeDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DatetimeTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class Extract(Func): + arg_types = {"this": True, "expression": True} + + +class TimestampAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimestampSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimestampDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimestampTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class TimeAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimeSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimeDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimeTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class DateStrToDate(Func): + pass + + +class DateToDateStr(Func): + pass + + +class DateToDi(Func): + pass + + +class Day(Func): + pass + + +class DiToDate(Func): + pass + + +class Exp(Func): + pass + + +class Explode(Func): + pass + + +class Floor(Func): + pass + + +class Greatest(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +class If(Func): + arg_types = {"this": True, "true": True, "false": False} + + +class IfNull(Func): + arg_types = {"this": True, "expression": False} + _sql_names = ["IFNULL", "NVL"] + + +class Initcap(Func): + pass + + +class JSONExtract(Func): + arg_types = {"this": True, "path": True} + _sql_names = ["JSON_EXTRACT"] + + +class JSONExtractScalar(JSONExtract): + _sql_names = ["JSON_EXTRACT_SCALAR"] + + +class JSONBExtract(JSONExtract): + _sql_names = ["JSONB_EXTRACT"] + + +class JSONBExtractScalar(JSONExtract): + _sql_names = ["JSONB_EXTRACT_SCALAR"] + + +class Least(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +class Length(Func): + pass + + +class Levenshtein(Func): + arg_types = {"this": True, "expression": False} + + +class Ln(Func): + pass + + +class Log(Func): + arg_types = {"this": True, "expression": False} + + +class Log2(Func): + pass + + +class Log10(Func): + pass + + +class Lower(Func): + pass + + +class Map(Func): + arg_types = {"keys": True, "values": True} + + +class Max(AggFunc): + pass + + +class Min(AggFunc): + pass + + +class Month(Func): + pass + + +class Nvl2(Func): + arg_types = {"this": True, "true": True, "false": False} + + +class Posexplode(Func): + pass + + +class Pow(Func): + arg_types = {"this": True, "power": True} + _sql_names = ["POWER", "POW"] + + +class Quantile(AggFunc): + arg_types = {"this": True, "quantile": True} + + +class Reduce(Func): + arg_types = {"this": True, "initial": True, "merge": True, "finish": True} + + +class RegexpLike(Func): + arg_types = {"this": True, "expression": True} + + +class RegexpSplit(Func): + arg_types = {"this": True, "expression": True} + + +class Round(Func): + arg_types = {"this": True, "decimals": False} + + +class SafeDivide(Func): + arg_types = {"this": True, "expression": True} + + +class SetAgg(AggFunc): + pass + + +class SortArray(Func): + arg_types = {"this": True, "asc": False} + + +class Split(Func): + arg_types = {"this": True, 
"expression": True} + + +class Substring(Func): + arg_types = {"this": True, "start": True, "length": False} + + +class StrPosition(Func): + arg_types = {"this": True, "substr": True, "position": False} + + +class StrToDate(Func): + arg_types = {"this": True, "format": True} + + +class StrToTime(Func): + arg_types = {"this": True, "format": True} + + +class StrToUnix(Func): + arg_types = {"this": True, "format": True} + + +class Struct(Func): + arg_types = {"expressions": True} + is_var_len_args = True + + +class StructExtract(Func): + arg_types = {"this": True, "expression": True} + + +class Sum(AggFunc): + pass + + +class Sqrt(Func): + pass + + +class Stddev(AggFunc): + pass + + +class StddevPop(AggFunc): + pass + + +class StddevSamp(AggFunc): + pass + + +class TimeToStr(Func): + arg_types = {"this": True, "format": True} + + +class TimeToTimeStr(Func): + pass + + +class TimeToUnix(Func): + pass + + +class TimeStrToDate(Func): + pass + + +class TimeStrToTime(Func): + pass + + +class TimeStrToUnix(Func): + pass + + +class TsOrDsAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TsOrDsToDateStr(Func): + pass + + +class TsOrDsToDate(Func): + arg_types = {"this": True, "format": False} + + +class TsOrDiToDi(Func): + pass + + +class UnixToStr(Func): + arg_types = {"this": True, "format": True} + + +class UnixToTime(Func): + arg_types = {"this": True, "scale": False} + + SECONDS = Literal.string("seconds") + MILLIS = Literal.string("millis") + MICROS = Literal.string("micros") + + +class UnixToTimeStr(Func): + pass + + +class Upper(Func): + pass + + +class Variance(AggFunc): + _sql_names = ["VARIANCE", "VARIANCE_SAMP", "VAR_SAMP"] + + +class VariancePop(AggFunc): + _sql_names = ["VARIANCE_POP", "VAR_POP"] + + +class Week(Func): + arg_types = {"this": True, "mode": False} + + +class Year(Func): + pass + + +def _norm_args(expression): + args = {} + + for k, arg in expression.args.items(): + if isinstance(arg, list): + arg = [_norm_arg(a) for a in arg] + else: + arg = _norm_arg(arg) + + if arg is not None: + args[k] = arg + + return args + + +def _norm_arg(arg): + return arg.lower() if isinstance(arg, str) else arg + + +def _all_functions(): + return [ + obj + for _, obj in inspect.getmembers( + sys.modules[__name__], + lambda obj: inspect.isclass(obj) + and issubclass(obj, Func) + and obj not in (AggFunc, Anonymous, Func), + ) + ] + + +ALL_FUNCTIONS = _all_functions() + + +def maybe_parse( + sql_or_expression, + *, + into=None, + dialect=None, + prefix=None, + **opts, +): + """Gracefully handle a possible string or expression. + + Example: + >>> maybe_parse("1") + (LITERAL this: 1, is_string: False) + >>> maybe_parse(to_identifier("x")) + (IDENTIFIER this: x, quoted: False) + + Args: + sql_or_expression (str or Expression): the SQL code string or an expression + into (Expression): the SQLGlot Expression to parse into + dialect (str): the dialect used to parse the input expressions (in the case that an + input expression is a SQL string). + prefix (str): a string to prefix the sql with before it gets parsed + (automatically includes a space) + **opts: other options to use to parse the input expressions (again, in the case + that an input expression is a SQL string). + + Returns: + Expression: the parsed or given expression. 
+ """ + if isinstance(sql_or_expression, Expression): + return sql_or_expression + + import sqlglot + + sql = str(sql_or_expression) + if prefix: + sql = f"{prefix} {sql}" + return sqlglot.parse_one(sql, read=dialect, into=into, **opts) + + +def _maybe_copy(instance, copy=True): + return instance.copy() if copy else instance + + +def _is_wrong_expression(expression, into): + return isinstance(expression, Expression) and not isinstance(expression, into) + + +def _apply_builder( + expression, + instance, + arg, + copy=True, + prefix=None, + into=None, + dialect=None, + **opts, +): + if _is_wrong_expression(expression, into): + expression = into(this=expression) + instance = _maybe_copy(instance, copy) + expression = maybe_parse( + sql_or_expression=expression, + prefix=prefix, + into=into, + dialect=dialect, + **opts, + ) + instance.set(arg, expression) + return instance + + +def _apply_child_list_builder( + *expressions, + instance, + arg, + append=True, + copy=True, + prefix=None, + into=None, + dialect=None, + properties=None, + **opts, +): + instance = _maybe_copy(instance, copy) + parsed = [] + for expression in expressions: + if _is_wrong_expression(expression, into): + expression = into(expressions=[expression]) + expression = maybe_parse( + expression, + into=into, + dialect=dialect, + prefix=prefix, + **opts, + ) + parsed.extend(expression.expressions) + + existing = instance.args.get(arg) + if append and existing: + parsed = existing.expressions + parsed + + child = into(expressions=parsed) + for k, v in (properties or {}).items(): + child.set(k, v) + instance.set(arg, child) + return instance + + +def _apply_list_builder( + *expressions, + instance, + arg, + append=True, + copy=True, + prefix=None, + into=None, + dialect=None, + **opts, +): + inst = _maybe_copy(instance, copy) + + expressions = [ + maybe_parse( + sql_or_expression=expression, + into=into, + prefix=prefix, + dialect=dialect, + **opts, + ) + for expression in expressions + ] + + existing_expressions = inst.args.get(arg) + if append and existing_expressions: + expressions = existing_expressions + expressions + + inst.set(arg, expressions) + return inst + + +def _apply_conjunction_builder( + *expressions, + instance, + arg, + into=None, + append=True, + copy=True, + dialect=None, + **opts, +): + expressions = [exp for exp in expressions if exp is not None and exp != ""] + if not expressions: + return instance + + inst = _maybe_copy(instance, copy) + + existing = inst.args.get(arg) + if append and existing is not None: + expressions = [existing.this if into else existing] + list(expressions) + + node = and_(*expressions, dialect=dialect, **opts) + + inst.set(arg, into(this=node) if into else node) + return inst + + +def _combine(expressions, operator, dialect=None, **opts): + expressions = [ + condition(expression, dialect=dialect, **opts) for expression in expressions + ] + this = expressions[0] + if expressions[1:]: + this = _wrap_operator(this) + for expression in expressions[1:]: + this = operator(this=this, expression=_wrap_operator(expression)) + return this + + +def _wrap_operator(expression): + if isinstance(expression, (And, Or, Not)): + expression = Paren(this=expression) + return expression + + +def select(*expressions, dialect=None, **opts): + """ + Initializes a syntax tree from one or multiple SELECT expressions. 
+ + Example: + >>> select("col1", "col2").from_("tbl").sql() + 'SELECT col1, col2 FROM tbl' + + Args: + *expressions (str or Expression): the SQL code string to parse as the expressions of a + SELECT statement. If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expressions (in the case that an + input expression is a SQL string). + **opts: other options to use to parse the input expressions (again, in the case + that an input expression is a SQL string). + + Returns: + Select: the syntax tree for the SELECT statement. + """ + return Select().select(*expressions, dialect=dialect, **opts) + + +def from_(*expressions, dialect=None, **opts): + """ + Initializes a syntax tree from a FROM expression. + + Example: + >>> from_("tbl").select("col1", "col2").sql() + 'SELECT col1, col2 FROM tbl' + + Args: + *expressions (str or Expression): the SQL code string to parse as the FROM expressions of a + SELECT statement. If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression (in the case that the + input expression is a SQL string). + **opts: other options to use to parse the input expressions (again, in the case + that the input expression is a SQL string). + + Returns: + Select: the syntax tree for the SELECT statement. + """ + return Select().from_(*expressions, dialect=dialect, **opts) + + +def condition(expression, dialect=None, **opts): + """ + Initialize a logical condition expression. + + Example: + >>> condition("x=1").sql() + 'x = 1' + + This is helpful for composing larger logical syntax trees: + >>> where = condition("x=1") + >>> where = where.and_("y=1") + >>> Select().from_("tbl").select("*").where(where).sql() + 'SELECT * FROM tbl WHERE x = 1 AND y = 1' + + Args: + *expression (str or Expression): the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression (in the case that the + input expression is a SQL string). + **opts: other options to use to parse the input expressions (again, in the case + that the input expression is a SQL string). + + Returns: + Condition: the expression + """ + return maybe_parse( + expression, + into=Condition, + dialect=dialect, + **opts, + ) + + +def and_(*expressions, dialect=None, **opts): + """ + Combine multiple conditions with an AND logical operator. + + Example: + >>> and_("x=1", and_("y=1", "z=1")).sql() + 'x = 1 AND (y = 1 AND z = 1)' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + And: the new condition + """ + return _combine(expressions, And, dialect, **opts) + + +def or_(*expressions, dialect=None, **opts): + """ + Combine multiple conditions with an OR logical operator. + + Example: + >>> or_("x=1", or_("y=1", "z=1")).sql() + 'x = 1 OR (y = 1 OR z = 1)' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + Or: the new condition + """ + return _combine(expressions, Or, dialect, **opts) + + +def not_(expression, dialect=None, **opts): + """ + Wrap a condition with a NOT operator. 
+ + Example: + >>> not_("this_suit='black'").sql() + "NOT this_suit = 'black'" + + Args: + expression (str or Expression): the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + Not: the new condition + """ + this = condition( + expression, + dialect=dialect, + **opts, + ) + return Not(this=_wrap_operator(this)) + + +def paren(expression): + return Paren(this=expression) + + +SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") + + +def to_identifier(alias, quoted=None): + if alias is None: + return None + if isinstance(alias, Identifier): + identifier = alias + elif isinstance(alias, str): + if quoted is None: + quoted = not re.match(SAFE_IDENTIFIER_RE, alias) + identifier = Identifier(this=alias, quoted=quoted) + else: + raise ValueError( + f"Alias needs to be a string or an Identifier, got: {alias.__class__}" + ) + return identifier + + +def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): + """ + Create an Alias expression. + + Example: + >>> alias_('foo', 'bar').sql() + 'foo AS bar' + + Args: + expression (str or Expression): the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + alias (str or Identifier): the alias name to use. If the name has + special characters it is quoted. + table (bool): whether to create a table alias. Defaults to `False`. + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + Alias: the aliased expression + """ + exp = maybe_parse(expression, dialect=dialect, **opts) + alias = to_identifier(alias, quoted=quoted) + alias = TableAlias(this=alias) if table else alias + + if "alias" in exp.arg_types: + exp = exp.copy() + exp.set("alias", alias) + return exp + return Alias(this=exp, alias=alias) + + +def subquery(expression, alias=None, dialect=None, **opts): + """ + Build a subquery expression. + + Example: + >>> subquery('select x from tbl', 'bar').select('x').sql() + 'SELECT x FROM (SELECT x FROM tbl) AS bar' + + Args: + expression (str or Expression): the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + alias (str or Expression): the alias name to use. + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + Select: a new select with the subquery expression included + """ + + expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias) + return Select().from_(expression, dialect=dialect, **opts) + + +def column(col, table=None, quoted=None): + """ + Build a Column. + Args: + col (str or Expression): column name + table (str or Expression): table name + quoted (bool): whether or not to quote the column and table identifiers + Returns: + Column: column instance + """ + return Column( + this=to_identifier(col, quoted=quoted), + table=to_identifier(table, quoted=quoted), + ) + + +def table_(table, db=None, catalog=None, quoted=None): + """ + Build a Table.
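+ + Example: + >>> table_("tbl", db="db").sql() + 'db.tbl' +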
+ Args: + table (str or Expression): table name + db (str or Expression): db name + catalog (str or Expression): catalog name + quoted (bool): whether or not to quote the table identifiers + Returns: + Table: table instance + """ + return Table( + this=to_identifier(table, quoted=quoted), + db=to_identifier(db, quoted=quoted), + catalog=to_identifier(catalog, quoted=quoted), + ) + + +def replace_children(expression, fun): + """ + Replace children of an expression with the result of a lambda fun(child) -> exp. + """ + for k, v in expression.args.items(): + is_list_arg = isinstance(v, list) + + child_nodes = v if is_list_arg else [v] + new_child_nodes = [] + + for cn in child_nodes: + if isinstance(cn, Expression): + cns = ensure_list(fun(cn)) + for child_node in cns: + new_child_nodes.append(child_node) + child_node.parent = expression + child_node.arg_key = k + else: + new_child_nodes.append(cn) + + expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0] + + +def column_table_names(expression): + """ + Return all table names referenced through columns in an expression. + + Example: + >>> import sqlglot + >>> column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e")) + ['c', 'a'] + + Args: + expression (sqlglot.Expression): expression to find table names + + Returns: + list: A list of unique names + """ + return list(dict.fromkeys(column.table for column in expression.find_all(Column))) + + +TRUE = Boolean(this=True) +FALSE = Boolean(this=False) +NULL = Null() diff --git a/sqlglot/generator.py b/sqlglot/generator.py new file mode 100644 index 0000000..793cff0 --- /dev/null +++ b/sqlglot/generator.py @@ -0,0 +1,1124 @@ +import logging + +from sqlglot import exp +from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors +from sqlglot.helper import apply_index_offset, csv, ensure_list +from sqlglot.time import format_time +from sqlglot.tokens import TokenType + +logger = logging.getLogger("sqlglot") + + +class Generator: + """ + Generator interprets the given syntax tree and produces a SQL string as an output. + + Args + time_mapping (dict): the dictionary of custom time mappings in which the key + represents a python time format and the output the target time format + time_trie (trie): a trie of the time_mapping keys + pretty (bool): if set to True the returned string will be formatted. Default: False. + quote_start (str): specifies which starting character to use to delimit quotes. Default: '. + quote_end (str): specifies which ending character to use to delimit quotes. Default: '. + identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ". + identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ". + identify (bool): if set to True all identifiers will be delimited by the corresponding + character. + normalize (bool): if set to True all identifiers will be lower cased + escape (str): specifies an escape character. Default: '. + pad (int): determines padding in a formatted string. Default: 2. + indent (int): determines the size of indentation in a formatted string. Default: 2. + unnest_column_only (bool): if set to True, unnest table aliases are considered only as column aliases + normalize_functions (str): normalize function names, "upper", "lower", or None + Default: "upper" + alias_post_tablesample (bool): if the table alias comes after tablesample + Default: False + unsupported_level (ErrorLevel): determines the generator's behavior when it encounters + unsupported expressions. Default ErrorLevel.WARN.
+ null_ordering (str): Indicates the default null ordering method to use if not explicitly set. + Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". + Default: "nulls_are_small" + max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError. + This is only relevant if unsupported_level is ErrorLevel.RAISE. + Default: 3 + """ + + TRANSFORMS = { + exp.AnonymousProperty: lambda self, e: self.property_sql(e), + exp.AutoIncrementProperty: lambda self, e: f"AUTO_INCREMENT={self.sql(e, 'value')}", + exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", + exp.CollateProperty: lambda self, e: f"COLLATE={self.sql(e, 'value')}", + exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", + exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.EngineProperty: lambda self, e: f"ENGINE={self.sql(e, 'value')}", + exp.FileFormatProperty: lambda self, e: f"FORMAT={self.sql(e, 'value')}", + exp.LocationProperty: lambda self, e: f"LOCATION {self.sql(e, 'value')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY={self.sql(e.args['value'])}", + exp.SchemaCommentProperty: lambda self, e: f"COMMENT={self.sql(e, 'value')}", + exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT={self.sql(e, 'value')}", + exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", + } + + NULL_ORDERING_SUPPORTED = True + + TYPE_MAPPING = { + exp.DataType.Type.NCHAR: "CHAR", + exp.DataType.Type.NVARCHAR: "VARCHAR", + } + + TOKEN_MAPPING = {} + + STRUCT_DELIMITER = ("<", ">") + + ROOT_PROPERTIES = [ + exp.AutoIncrementProperty, + exp.CharacterSetProperty, + exp.CollateProperty, + exp.EngineProperty, + exp.SchemaCommentProperty, + ] + WITH_PROPERTIES = [ + exp.AnonymousProperty, + exp.FileFormatProperty, + exp.PartitionedByProperty, + exp.TableFormatProperty, + ] + + __slots__ = ( + "time_mapping", + "time_trie", + "pretty", + "configured_pretty", + "quote_start", + "quote_end", + "identifier_start", + "identifier_end", + "identify", + "normalize", + "escape", + "pad", + "index_offset", + "unnest_column_only", + "alias_post_tablesample", + "normalize_functions", + "unsupported_level", + "unsupported_messages", + "null_ordering", + "max_unsupported", + "_indent", + "_replace_backslash", + "_escaped_quote_end", + ) + + def __init__( + self, + time_mapping=None, + time_trie=None, + pretty=None, + quote_start=None, + quote_end=None, + identifier_start=None, + identifier_end=None, + identify=False, + normalize=False, + escape=None, + pad=2, + indent=2, + index_offset=0, + unnest_column_only=False, + alias_post_tablesample=False, + normalize_functions="upper", + unsupported_level=ErrorLevel.WARN, + null_ordering=None, + max_unsupported=3, + ): + import sqlglot + + self.time_mapping = time_mapping or {} + self.time_trie = time_trie + self.pretty = pretty if pretty is not None else sqlglot.pretty + self.configured_pretty = self.pretty + self.quote_start = quote_start or "'" + self.quote_end = quote_end or "'" + self.identifier_start = identifier_start or '"' + self.identifier_end = identifier_end or '"' + self.identify = identify + self.normalize = normalize + self.escape = escape or "'" + self.pad = pad + self.index_offset = index_offset + self.unnest_column_only = unnest_column_only + self.alias_post_tablesample = 
alias_post_tablesample + self.normalize_functions = normalize_functions + self.unsupported_level = unsupported_level + self.unsupported_messages = [] + self.max_unsupported = max_unsupported + self.null_ordering = null_ordering + self._indent = indent + self._replace_backslash = self.escape == "\\" + self._escaped_quote_end = self.escape + self.quote_end + + def generate(self, expression): + """ + Generates a SQL string by interpreting the given syntax tree. + + Args + expression (Expression): the syntax tree. + + Returns + the SQL string. + """ + self.unsupported_messages = [] + sql = self.sql(expression).strip() + + if self.unsupported_level == ErrorLevel.IGNORE: + return sql + + if self.unsupported_level == ErrorLevel.WARN: + for msg in self.unsupported_messages: + logger.warning(msg) + elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: + raise UnsupportedError( + concat_errors(self.unsupported_messages, self.max_unsupported) + ) + + return sql + + def unsupported(self, message): + if self.unsupported_level == ErrorLevel.IMMEDIATE: + raise UnsupportedError(message) + self.unsupported_messages.append(message) + + def sep(self, sep=" "): + return f"{sep.strip()}\n" if self.pretty else sep + + def seg(self, sql, sep=" "): + return f"{self.sep(sep)}{sql}" + + def wrap(self, expression): + this_sql = self.indent( + self.sql(expression) + if isinstance(expression, (exp.Select, exp.Union)) + else self.sql(expression, "this"), + level=1, + pad=0, + ) + return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" + + def no_identify(self, func): + original = self.identify + self.identify = False + result = func() + self.identify = original + return result + + def normalize_func(self, name): + if self.normalize_functions == "upper": + return name.upper() + if self.normalize_functions == "lower": + return name.lower() + return name + + def indent(self, sql, level=0, pad=None, skip_first=False, skip_last=False): + if not self.pretty: + return sql + + pad = self.pad if pad is None else pad + lines = sql.split("\n") + + return "\n".join( + line + if (skip_first and i == 0) or (skip_last and i == len(lines) - 1) + else f"{' ' * (level * self._indent + pad)}{line}" + for i, line in enumerate(lines) + ) + + def sql(self, expression, key=None): + if not expression: + return "" + + if isinstance(expression, str): + return expression + + if key: + return self.sql(expression.args.get(key)) + + transform = self.TRANSFORMS.get(expression.__class__) + + if callable(transform): + return transform(self, expression) + if transform: + return transform + + if not isinstance(expression, exp.Expression): + raise ValueError( + f"Expected an Expression. 
Received {type(expression)}: {expression}" + ) + + exp_handler_name = f"{expression.key}_sql" + if hasattr(self, exp_handler_name): + return getattr(self, exp_handler_name)(expression) + + if isinstance(expression, exp.Func): + return self.function_fallback_sql(expression) + + raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") + + def annotation_sql(self, expression): + return self.sql(expression, "expression") + + def uncache_sql(self, expression): + table = self.sql(expression, "this") + exists_sql = " IF EXISTS" if expression.args.get("exists") else "" + return f"UNCACHE TABLE{exists_sql} {table}" + + def cache_sql(self, expression): + lazy = " LAZY" if expression.args.get("lazy") else "" + table = self.sql(expression, "this") + options = expression.args.get("options") + options = ( + f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" + if options + else "" + ) + sql = self.sql(expression, "expression") + sql = f" AS{self.sep()}{sql}" if sql else "" + sql = f"CACHE{lazy} TABLE {table}{options}{sql}" + return self.prepend_ctes(expression, sql) + + def characterset_sql(self, expression): + if isinstance(expression.parent, exp.Cast): + return f"CHAR CHARACTER SET {self.sql(expression, 'this')}" + default = "DEFAULT " if expression.args.get("default") else "" + return f"{default}CHARACTER SET={self.sql(expression, 'this')}" + + def column_sql(self, expression): + return ".".join( + part + for part in [ + self.sql(expression, "db"), + self.sql(expression, "table"), + self.sql(expression, "this"), + ] + if part + ) + + def columndef_sql(self, expression): + column = self.sql(expression, "this") + kind = self.sql(expression, "kind") + constraints = self.expressions( + expression, key="constraints", sep=" ", flat=True + ) + + if not constraints: + return f"{column} {kind}" + return f"{column} {kind} {constraints}" + + def columnconstraint_sql(self, expression): + this = self.sql(expression, "this") + kind_sql = self.sql(expression, "kind") + return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql + + def autoincrementcolumnconstraint_sql(self, _): + return self.token_sql(TokenType.AUTO_INCREMENT) + + def checkcolumnconstraint_sql(self, expression): + this = self.sql(expression, "this") + return f"CHECK ({this})" + + def commentcolumnconstraint_sql(self, expression): + comment = self.sql(expression, "this") + return f"COMMENT {comment}" + + def collatecolumnconstraint_sql(self, expression): + collate = self.sql(expression, "this") + return f"COLLATE {collate}" + + def defaultcolumnconstraint_sql(self, expression): + default = self.sql(expression, "this") + return f"DEFAULT {default}" + + def notnullcolumnconstraint_sql(self, _): + return "NOT NULL" + + def primarykeycolumnconstraint_sql(self, _): + return "PRIMARY KEY" + + def uniquecolumnconstraint_sql(self, _): + return "UNIQUE" + + def create_sql(self, expression): + this = self.sql(expression, "this") + kind = self.sql(expression, "kind").upper() + expression_sql = self.sql(expression, "expression") + expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" + temporary = " TEMPORARY" if expression.args.get("temporary") else "" + replace = " OR REPLACE" if expression.args.get("replace") else "" + exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" + unique = " UNIQUE" if expression.args.get("unique") else "" + properties = self.sql(expression, "properties") + + expression_sql = f"CREATE{replace}{temporary}{unique} {kind}{exists_sql} {this}{properties} 
{expression_sql}" + return self.prepend_ctes(expression, expression_sql) + + def prepend_ctes(self, expression, sql): + with_ = self.sql(expression, "with") + if with_: + sql = f"{with_}{self.sep()}{sql}" + return sql + + def with_sql(self, expression): + sql = self.expressions(expression, flat=True) + recursive = "RECURSIVE " if expression.args.get("recursive") else "" + + return f"WITH {recursive}{sql}" + + def cte_sql(self, expression): + alias = self.sql(expression, "alias") + return f"{alias} AS {self.wrap(expression)}" + + def tablealias_sql(self, expression): + alias = self.sql(expression, "this") + columns = self.expressions(expression, key="columns", flat=True) + columns = f"({columns})" if columns else "" + return f"{alias}{columns}" + + def bitstring_sql(self, expression): + return f"b'{self.sql(expression, 'this')}'" + + def datatype_sql(self, expression): + type_value = expression.this + type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) + nested = "" + interior = self.expressions(expression, flat=True) + if interior: + nested = ( + f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" + if expression.args.get("nested") + else f"({interior})" + ) + return f"{type_sql}{nested}" + + def delete_sql(self, expression): + this = self.sql(expression, "this") + where_sql = self.sql(expression, "where") + sql = f"DELETE FROM {this}{where_sql}" + return self.prepend_ctes(expression, sql) + + def drop_sql(self, expression): + this = self.sql(expression, "this") + kind = expression.args["kind"] + exists_sql = " IF EXISTS " if expression.args.get("exists") else " " + return f"DROP {kind}{exists_sql}{this}" + + def except_sql(self, expression): + return self.prepend_ctes( + expression, + self.set_operation(expression, self.except_op(expression)), + ) + + def except_op(self, expression): + return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}" + + def fetch_sql(self, expression): + direction = expression.args.get("direction") + direction = f" {direction.upper()}" if direction else "" + count = expression.args.get("count") + count = f" {count}" if count else "" + return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY" + + def filter_sql(self, expression): + this = self.sql(expression, "this") + where = self.sql(expression, "expression")[1:] # where has a leading space + return f"{this} FILTER({where})" + + def hint_sql(self, expression): + if self.sql(expression, "this"): + self.unsupported("Hints are not supported") + return "" + + def index_sql(self, expression): + this = self.sql(expression, "this") + table = self.sql(expression, "table") + columns = self.sql(expression, "columns") + return f"{this} ON {table} {columns}" + + def identifier_sql(self, expression): + value = expression.name + value = value.lower() if self.normalize else value + if expression.args.get("quoted") or self.identify: + return f"{self.identifier_start}{value}{self.identifier_end}" + return value + + def partition_sql(self, expression): + keys = csv( + *[ + f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] + for k, v in expression.args.get("this") + ] + ) + return f"PARTITION({keys})" + + def properties_sql(self, expression): + root_properties = [] + with_properties = [] + + for p in expression.expressions: + p_class = p.__class__ + if p_class in self.ROOT_PROPERTIES: + root_properties.append(p) + elif p_class in self.WITH_PROPERTIES: + with_properties.append(p) + + return self.root_properties( + exp.Properties(expressions=root_properties) + ) + 
self.with_properties(exp.Properties(expressions=with_properties)) + + def root_properties(self, properties): + if properties.expressions: + return self.sep() + self.expressions( + properties, + indent=False, + sep=" ", + ) + return "" + + def properties(self, properties, prefix="", sep=", "): + if properties.expressions: + expressions = self.expressions( + properties, + sep=sep, + indent=False, + ) + return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}" + return "" + + def with_properties(self, properties): + return self.properties( + properties, + prefix="WITH", + ) + + def property_sql(self, expression): + key = expression.name + value = self.sql(expression, "value") + return f"{key} = {value}" + + def insert_sql(self, expression): + kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO" + this = self.sql(expression, "this") + exists = " IF EXISTS " if expression.args.get("exists") else " " + partition_sql = ( + self.sql(expression, "partition") + if expression.args.get("partition") + else "" + ) + expression_sql = self.sql(expression, "expression") + sep = self.sep() if partition_sql else "" + sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}" + return self.prepend_ctes(expression, sql) + + def intersect_sql(self, expression): + return self.prepend_ctes( + expression, + self.set_operation(expression, self.intersect_op(expression)), + ) + + def intersect_op(self, expression): + return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}" + + def introducer_sql(self, expression): + return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + + def table_sql(self, expression): + return ".".join( + part + for part in [ + self.sql(expression, "catalog"), + self.sql(expression, "db"), + self.sql(expression, "this"), + ] + if part + ) + + def tablesample_sql(self, expression): + if self.alias_post_tablesample and isinstance(expression.this, exp.Alias): + this = self.sql(expression.this, "this") + alias = f" AS {self.sql(expression.this, 'alias')}" + else: + this = self.sql(expression, "this") + alias = "" + method = self.sql(expression, "method") + method = f" {method.upper()} " if method else "" + numerator = self.sql(expression, "bucket_numerator") + denominator = self.sql(expression, "bucket_denominator") + field = self.sql(expression, "bucket_field") + field = f" ON {field}" if field else "" + bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else "" + percent = self.sql(expression, "percent") + percent = f"{percent} PERCENT" if percent else "" + rows = self.sql(expression, "rows") + rows = f"{rows} ROWS" if rows else "" + size = self.sql(expression, "size") + return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){alias}" + + def tuple_sql(self, expression): + return f"({self.expressions(expression, flat=True)})" + + def update_sql(self, expression): + this = self.sql(expression, "this") + set_sql = self.expressions(expression, flat=True) + from_sql = self.sql(expression, "from") + where_sql = self.sql(expression, "where") + sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}" + return self.prepend_ctes(expression, sql) + + def values_sql(self, expression): + return f"VALUES{self.seg('')}{self.expressions(expression)}" + + def var_sql(self, expression): + return self.sql(expression, "this") + + def from_sql(self, expression): + expressions = self.expressions(expression, flat=True) + return f"{self.seg('FROM')} {expressions}" + + def group_sql(self, expression): 
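+ # Assemble GROUP BY first, then append GROUPING SETS, CUBE and ROLLUP clauses when present.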
+ group_by = self.op_expressions("GROUP BY", expression) + grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) + grouping_sets = ( + f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" + if grouping_sets + else "" + ) + cube = self.expressions(expression, key="cube", indent=False) + cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" + rollup = self.expressions(expression, key="rollup", indent=False) + rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" + return f"{group_by}{grouping_sets}{cube}{rollup}" + + def having_sql(self, expression): + this = self.indent(self.sql(expression, "this")) + return f"{self.seg('HAVING')}{self.sep()}{this}" + + def join_sql(self, expression): + op_sql = self.seg( + " ".join(op for op in (expression.side, expression.kind, "JOIN") if op) + ) + on_sql = self.sql(expression, "on") + using = expression.args.get("using") + + if not on_sql and using: + on_sql = csv(*(self.sql(column) for column in using)) + + if on_sql: + on_sql = self.indent(on_sql, skip_first=True) + space = self.seg(" " * self.pad) if self.pretty else " " + if using: + on_sql = f"{space}USING ({on_sql})" + else: + on_sql = f"{space}ON {on_sql}" + + expression_sql = self.sql(expression, "expression") + this_sql = self.sql(expression, "this") + return f"{expression_sql}{op_sql} {this_sql}{on_sql}" + + def lambda_sql(self, expression): + args = self.expressions(expression, flat=True) + args = f"({args})" if len(args.split(",")) > 1 else args + return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}") + + def lateral_sql(self, expression): + this = self.sql(expression, "this") + op_sql = self.seg( + f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}" + ) + alias = expression.args["alias"] + table = alias.name + table = f" {table}" if table else table + columns = self.expressions(alias, key="columns", flat=True) + columns = f" AS {columns}" if columns else "" + return f"{op_sql}{self.sep()}{this}{table}{columns}" + + def limit_sql(self, expression): + this = self.sql(expression, "this") + return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}" + + def offset_sql(self, expression): + this = self.sql(expression, "this") + return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" + + def literal_sql(self, expression): + text = expression.this or "" + if expression.is_string: + if self._replace_backslash: + text = text.replace("\\", "\\\\") + text = text.replace(self.quote_end, self._escaped_quote_end) + return f"{self.quote_start}{text}{self.quote_end}" + return text + + def null_sql(self, *_): + return "NULL" + + def boolean_sql(self, expression): + return "TRUE" if expression.this else "FALSE" + + def order_sql(self, expression, flat=False): + this = self.sql(expression, "this") + this = f"{this} " if this else this + return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) + + def cluster_sql(self, expression): + return self.op_expressions("CLUSTER BY", expression) + + def distribute_sql(self, expression): + return self.op_expressions("DISTRIBUTE BY", expression) + + def sort_sql(self, expression): + return self.op_expressions("SORT BY", expression) + + def ordered_sql(self, expression): + desc = expression.args.get("desc") + asc = not desc + nulls_first = expression.args.get("nulls_first") + nulls_last = not nulls_first + nulls_are_large = self.null_ordering == "nulls_are_large" + nulls_are_small = self.null_ordering == "nulls_are_small" + 
nulls_are_last = self.null_ordering == "nulls_are_last" + + sort_order = " DESC" if desc else "" + nulls_sort_change = "" + if nulls_first and ( + (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last + ): + nulls_sort_change = " NULLS FIRST" + elif ( + nulls_last + and ((asc and nulls_are_small) or (desc and nulls_are_large)) + and not nulls_are_last + ): + nulls_sort_change = " NULLS LAST" + + if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: + self.unsupported( + "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect" + ) + nulls_sort_change = "" + + return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" + + def query_modifiers(self, expression, *sqls): + return csv( + *sqls, + *[self.sql(sql) for sql in expression.args.get("laterals", [])], + *[self.sql(sql) for sql in expression.args.get("joins", [])], + self.sql(expression, "where"), + self.sql(expression, "group"), + self.sql(expression, "having"), + self.sql(expression, "qualify"), + self.sql(expression, "window"), + self.sql(expression, "distribute"), + self.sql(expression, "sort"), + self.sql(expression, "cluster"), + self.sql(expression, "order"), + self.sql(expression, "limit"), + self.sql(expression, "offset"), + sep="", + ) + + def select_sql(self, expression): + hint = self.sql(expression, "hint") + distinct = self.sql(expression, "distinct") + distinct = f" {distinct}" if distinct else "" + expressions = self.expressions(expression) + expressions = f"{self.sep()}{expressions}" if expressions else expressions + sql = self.query_modifiers( + expression, + f"SELECT{hint}{distinct}{expressions}", + self.sql(expression, "from"), + ) + return self.prepend_ctes(expression, sql) + + def schema_sql(self, expression): + this = self.sql(expression, "this") + this = f"{this} " if this else "" + sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" + return f"{this}{sql}" + + def star_sql(self, expression): + except_ = self.expressions(expression, key="except", flat=True) + except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else "" + replace = self.expressions(expression, key="replace", flat=True) + replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" + return f"*{except_}{replace}" + + def structkwarg_sql(self, expression): + return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + + def placeholder_sql(self, *_): + return "?" 
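The null-ordering branch in ordered_sql above is easiest to see end to end through the public API. A minimal sketch, not part of the patch, assuming the default dialect's null_ordering of "nulls_are_small" and an illustrative table t: an explicit NULLS qualifier should only survive generation when it differs from what that ordering already implies.

import sqlglot

# DESC with NULLS FIRST is not the default order under "nulls_are_small",
# so the generator keeps the explicit qualifier.
print(sqlglot.transpile("SELECT x FROM t ORDER BY x DESC NULLS FIRST")[0])

# ASC with NULLS FIRST already matches that ordering, so the qualifier
# should be dropped as redundant.
print(sqlglot.transpile("SELECT x FROM t ORDER BY x NULLS FIRST")[0])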
+
+    def subquery_sql(self, expression):
+        alias = self.sql(expression, "alias")
+
+        return self.query_modifiers(
+            expression,
+            self.wrap(expression),
+            f" AS {alias}" if alias else "",
+        )
+
+    def qualify_sql(self, expression):
+        this = self.indent(self.sql(expression, "this"))
+        return f"{self.seg('QUALIFY')}{self.sep()}{this}"
+
+    def union_sql(self, expression):
+        return self.prepend_ctes(
+            expression,
+            self.set_operation(expression, self.union_op(expression)),
+        )
+
+    def union_op(self, expression):
+        return f"UNION{'' if expression.args.get('distinct') else ' ALL'}"
+
+    def unnest_sql(self, expression):
+        args = self.expressions(expression, flat=True)
+        alias = expression.args.get("alias")
+        if alias and self.unnest_column_only:
+            columns = alias.columns
+            alias = self.sql(columns[0]) if columns else ""
+        else:
+            alias = self.sql(expression, "alias")
+            alias = f" AS {alias}" if alias else alias
+        ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
+        return f"UNNEST({args}){ordinality}{alias}"
+
+    def where_sql(self, expression):
+        this = self.indent(self.sql(expression, "this"))
+        return f"{self.seg('WHERE')}{self.sep()}{this}"
+
+    def window_sql(self, expression):
+        this = self.sql(expression, "this")
+        partition = self.expressions(expression, key="partition_by", flat=True)
+        partition = f"PARTITION BY {partition}" if partition else ""
+        order = expression.args.get("order")
+        order_sql = self.order_sql(order, flat=True) if order else ""
+        partition_sql = partition + " " if partition and order else partition
+        spec = expression.args.get("spec")
+        spec_sql = " " + self.window_spec_sql(spec) if spec else ""
+        alias = self.sql(expression, "alias")
+        if expression.arg_key == "window":
+            this = f"{self.seg('WINDOW')} {this} AS"
+        else:
+            this = f"{this} OVER"
+
+        if not partition and not order and not spec and alias:
+            return f"{this} {alias}"
+
+        return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})"
+
+    def window_spec_sql(self, expression):
+        kind = self.sql(expression, "kind")
+        start = csv(
+            self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" "
+        )
+        end = (
+            csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
+            or "CURRENT ROW"
+        )
+        return f"{kind} BETWEEN {start} AND {end}"
+
+    def withingroup_sql(self, expression):
+        this = self.sql(expression, "this")
+        expression = self.sql(expression, "expression")[1:]  # order has a leading space
+        return f"{this} WITHIN GROUP ({expression})"
+
+    def between_sql(self, expression):
+        this = self.sql(expression, "this")
+        low = self.sql(expression, "low")
+        high = self.sql(expression, "high")
+        return f"{this} BETWEEN {low} AND {high}"
+
+    def bracket_sql(self, expression):
+        expressions = apply_index_offset(expression.expressions, self.index_offset)
+        expressions = ", ".join(self.sql(e) for e in expressions)
+
+        return f"{self.sql(expression, 'this')}[{expressions}]"
+
+    def all_sql(self, expression):
+        return f"ALL {self.wrap(expression)}"
+
+    def any_sql(self, expression):
+        return f"ANY {self.wrap(expression)}"
+
+    def exists_sql(self, expression):
+        return f"EXISTS{self.wrap(expression)}"
+
+    def case_sql(self, expression):
+        this = self.indent(self.sql(expression, "this"), skip_first=True)
+        this = f" {this}" if this else ""
+        ifs = []
+
+        for e in expression.args["ifs"]:
+            ifs.append(self.indent(f"WHEN {self.sql(e, 'this')}"))
+            ifs.append(self.indent(f"THEN {self.sql(e, 'true')}"))
+
+        if expression.args.get("default") is not None:
+
ifs.append(self.indent(f"ELSE {self.sql(expression, 'default')}")) + + ifs = "".join(self.seg(self.indent(e, skip_first=True)) for e in ifs) + statement = f"CASE{this}{ifs}{self.seg('END')}" + return statement + + def constraint_sql(self, expression): + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + return f"CONSTRAINT {this} {expressions}" + + def extract_sql(self, expression): + this = self.sql(expression, "this") + expression_sql = self.sql(expression, "expression") + return f"EXTRACT({this} FROM {expression_sql})" + + def check_sql(self, expression): + this = self.sql(expression, key="this") + return f"CHECK ({this})" + + def foreignkey_sql(self, expression): + expressions = self.expressions(expression, flat=True) + reference = self.sql(expression, "reference") + reference = f" {reference}" if reference else "" + delete = self.sql(expression, "delete") + delete = f" ON DELETE {delete}" if delete else "" + update = self.sql(expression, "update") + update = f" ON UPDATE {update}" if update else "" + return f"FOREIGN KEY ({expressions}){reference}{delete}{update}" + + def unique_sql(self, expression): + columns = self.expressions(expression, key="expressions") + return f"UNIQUE ({columns})" + + def if_sql(self, expression): + return self.case_sql( + exp.Case(ifs=[expression], default=expression.args.get("false")) + ) + + def in_sql(self, expression): + query = expression.args.get("query") + unnest = expression.args.get("unnest") + if query: + in_sql = self.wrap(query) + elif unnest: + in_sql = self.in_unnest_op(unnest) + else: + in_sql = f"({self.expressions(expression, flat=True)})" + return f"{self.sql(expression, 'this')} IN {in_sql}" + + def in_unnest_op(self, unnest): + return f"(SELECT {self.sql(unnest)})" + + def interval_sql(self, expression): + return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}" + + def reference_sql(self, expression): + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + return f"REFERENCES {this}({expressions})" + + def anonymous_sql(self, expression): + args = self.indent( + self.expressions(expression, flat=True), skip_first=True, skip_last=True + ) + return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" + + def paren_sql(self, expression): + if isinstance(expression.unnest(), exp.Select): + return self.wrap(expression) + sql = self.seg(self.indent(self.sql(expression, "this")), sep="") + return f"({sql}{self.seg(')', sep='')}" + + def neg_sql(self, expression): + return f"-{self.sql(expression, 'this')}" + + def not_sql(self, expression): + return f"NOT {self.sql(expression, 'this')}" + + def alias_sql(self, expression): + to_sql = self.sql(expression, "alias") + to_sql = f" AS {to_sql}" if to_sql else "" + return f"{self.sql(expression, 'this')}{to_sql}" + + def aliases_sql(self, expression): + return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" + + def attimezone_sql(self, expression): + this = self.sql(expression, "this") + zone = self.sql(expression, "zone") + return f"{this} AT TIME ZONE {zone}" + + def add_sql(self, expression): + return self.binary(expression, "+") + + def and_sql(self, expression): + return self.connector_sql(expression, "AND") + + def connector_sql(self, expression, op): + if not self.pretty: + return self.binary(expression, op) + + return f"\n{op} ".join(self.sql(e) for e in expression.flatten(unnest=False)) + + def bitwiseand_sql(self, expression): + return 
self.binary(expression, "&") + + def bitwiseleftshift_sql(self, expression): + return self.binary(expression, "<<") + + def bitwisenot_sql(self, expression): + return f"~{self.sql(expression, 'this')}" + + def bitwiseor_sql(self, expression): + return self.binary(expression, "|") + + def bitwiserightshift_sql(self, expression): + return self.binary(expression, ">>") + + def bitwisexor_sql(self, expression): + return self.binary(expression, "^") + + def cast_sql(self, expression): + return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" + + def currentdate_sql(self, expression): + zone = self.sql(expression, "this") + return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" + + def command_sql(self, expression): + return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" + + def distinct_sql(self, expression): + this = self.sql(expression, "this") + this = f" {this}" if this else "" + + on = self.sql(expression, "on") + on = f" ON {on}" if on else "" + return f"DISTINCT{this}{on}" + + def ignorenulls_sql(self, expression): + return f"{self.sql(expression, 'this')} IGNORE NULLS" + + def intdiv_sql(self, expression): + return self.sql( + exp.Cast( + this=exp.Div( + this=expression.args["this"], + expression=expression.args["expression"], + ), + to=exp.DataType(this=exp.DataType.Type.INT), + ) + ) + + def dpipe_sql(self, expression): + return self.binary(expression, "||") + + def div_sql(self, expression): + return self.binary(expression, "/") + + def dot_sql(self, expression): + return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" + + def eq_sql(self, expression): + return self.binary(expression, "=") + + def escape_sql(self, expression): + return self.binary(expression, "ESCAPE") + + def gt_sql(self, expression): + return self.binary(expression, ">") + + def gte_sql(self, expression): + return self.binary(expression, ">=") + + def ilike_sql(self, expression): + return self.binary(expression, "ILIKE") + + def is_sql(self, expression): + return self.binary(expression, "IS") + + def like_sql(self, expression): + return self.binary(expression, "LIKE") + + def lt_sql(self, expression): + return self.binary(expression, "<") + + def lte_sql(self, expression): + return self.binary(expression, "<=") + + def mod_sql(self, expression): + return self.binary(expression, "%") + + def mul_sql(self, expression): + return self.binary(expression, "*") + + def neq_sql(self, expression): + return self.binary(expression, "<>") + + def or_sql(self, expression): + return self.connector_sql(expression, "OR") + + def sub_sql(self, expression): + return self.binary(expression, "-") + + def trycast_sql(self, expression): + return ( + f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" + ) + + def binary(self, expression, op): + return ( + f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" + ) + + def function_fallback_sql(self, expression): + args = [] + for arg_key in expression.arg_types: + arg_value = ensure_list(expression.args.get(arg_key) or []) + for a in arg_value: + args.append(self.sql(a)) + + args_str = self.indent(", ".join(args), skip_first=True, skip_last=True) + return f"{self.normalize_func(expression.sql_name())}({args_str})" + + def format_time(self, expression): + return format_time( + self.sql(expression, "format"), self.time_mapping, self.time_trie + ) + + def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): + expressions = 
expression.args.get(key or "expressions")
+
+        if not expressions:
+            return ""
+
+        if flat:
+            return sep.join(self.sql(e) for e in expressions)
+
+        expressions = self.sep(sep).join(self.sql(e) for e in expressions)
+        if indent:
+            return self.indent(expressions, skip_first=False)
+        return expressions
+
+    def op_expressions(self, op, expression, flat=False):
+        expressions_sql = self.expressions(expression, flat=flat)
+        if flat:
+            return f"{op} {expressions_sql}"
+        return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
+
+    def set_operation(self, expression, op):
+        this = self.sql(expression, "this")
+        op = self.seg(op)
+        return self.query_modifiers(
+            expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
+        )
+
+    def token_sql(self, token_type):
+        return self.TOKEN_MAPPING.get(token_type, token_type.name)
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
new file mode 100644
index 0000000..5d90c49
--- /dev/null
+++ b/sqlglot/helper.py
@@ -0,0 +1,123 @@
+import logging
+import re
+from contextlib import contextmanager
+from enum import Enum
+
+CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
+        >>> import sqlglot
+        >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
+        >>> eliminate_subqueries(expression).sql()
+        'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
+
+    Args:
+        expression (sqlglot.Expression): expression to optimize
+    Returns:
+        sqlglot.Expression: optimized expression
+    """
+    expression = simplify(expression)
+    queries = {}
+
+    for scope in traverse_scope(expression):
+        query = scope.expression
+        queries[query] = queries.get(query, []) + [query]
+
+    sequence = itertools.count()
+
+    for query, duplicates in queries.items():
+        if len(duplicates) == 1:
+            continue
+
+        alias_ = f"_e_{next(sequence)}"
+
+        for dup in duplicates:
+            parent = dup.parent
+            if isinstance(parent, exp.Subquery):
+                parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
+            elif isinstance(parent, exp.Union):
+                dup.replace(select("*").from_(alias_))
+
+        expression.with_(alias_, as_=query, copy=False)
+
+    return expression
diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py
new file mode 100644
index 0000000..ba562df
--- /dev/null
+++ b/sqlglot/optimizer/expand_multi_table_selects.py
@@ -0,0 +1,16 @@
+from sqlglot import exp
+
+
+def expand_multi_table_selects(expression):
+    for from_ in expression.find_all(exp.From):
+        parent = from_.parent
+
+        for query in from_.expressions[1:]:
+            parent.join(
+                query,
+                join_type="CROSS",
+                copy=False,
+            )
+            from_.expressions.remove(query)
+
+    return expression
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
new file mode 100644
index 0000000..c2e021e
--- /dev/null
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -0,0 +1,31 @@
+from sqlglot import alias, exp
+from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.scope import traverse_scope
+
+
+def isolate_table_selects(expression):
+    for scope in traverse_scope(expression):
+        if len(scope.selected_sources) == 1:
+            continue
+
+        for (_, source) in scope.selected_sources.values():
+            if not isinstance(source, exp.Table):
+                continue
+
+            if not isinstance(source.parent, exp.Alias):
+                raise OptimizeError(
+                    "Tables require an alias. Run qualify_tables optimization."
+ ) + + parent = source.parent + + parent.replace( + exp.select("*") + .from_( + alias(source, source.name or parent.alias, table=True), + copy=False, + ) + .subquery(parent.alias, copy=False) + ) + + return expression diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py new file mode 100644 index 0000000..2c9f89c --- /dev/null +++ b/sqlglot/optimizer/normalize.py @@ -0,0 +1,136 @@ +from sqlglot import exp +from sqlglot.helper import while_changing +from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort + + +def normalize(expression, dnf=False, max_distance=128): + """ + Rewrite sqlglot AST into conjunctive normal form. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("(x AND y) OR z") + >>> normalize(expression).sql() + '(x OR z) AND (y OR z)' + + Args: + expression (sqlglot.Expression): expression to normalize + dnf (bool): rewrite in disjunctive normal form instead + max_distance (int): the maximal estimated distance from cnf to attempt conversion + Returns: + sqlglot.Expression: normalized expression + """ + expression = simplify(expression) + + expression = while_changing( + expression, lambda e: distributive_law(e, dnf, max_distance) + ) + return simplify(expression) + + +def normalized(expression, dnf=False): + ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) + + return not any( + connector.find_ancestor(ancestor) for connector in expression.find_all(root) + ) + + +def normalization_distance(expression, dnf=False): + """ + The difference in the number of predicates between the current expression and the normalized form. + + This is used as an estimate of the cost of the conversion which is exponential in complexity. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") + >>> normalization_distance(expression) + 4 + + Args: + expression (sqlglot.Expression): expression to compute distance + dnf (bool): compute to dnf distance instead + Returns: + int: difference + """ + return sum(_predicate_lengths(expression, dnf)) - ( + len(list(expression.find_all(exp.Connector))) + 1 + ) + + +def _predicate_lengths(expression, dnf): + """ + Returns a list of predicate lengths when expanded to normalized form. + + (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 
+ """ + expression = expression.unnest() + + if not isinstance(expression, exp.Connector): + return [1] + + left, right = expression.args.values() + + if isinstance(expression, exp.And if dnf else exp.Or): + x = [ + a + b + for a in _predicate_lengths(left, dnf) + for b in _predicate_lengths(right, dnf) + ] + return x + return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) + + +def distributive_law(expression, dnf, max_distance): + """ + x OR (y AND z) -> (x OR y) AND (x OR z) + (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) + """ + if isinstance(expression.unnest(), exp.Connector): + if normalization_distance(expression, dnf) > max_distance: + return expression + + to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) + + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) + + if isinstance(expression, from_exp): + a, b = expression.unnest_operands() + + from_func = exp.and_ if from_exp == exp.And else exp.or_ + to_func = exp.and_ if to_exp == exp.And else exp.or_ + + if isinstance(a, to_exp) and isinstance(b, to_exp): + if len(tuple(a.find_all(exp.Connector))) > len( + tuple(b.find_all(exp.Connector)) + ): + return _distribute(a, b, from_func, to_func) + return _distribute(b, a, from_func, to_func) + if isinstance(a, to_exp): + return _distribute(b, a, from_func, to_func) + if isinstance(b, to_exp): + return _distribute(a, b, from_func, to_func) + + return expression + + +def _distribute(a, b, from_func, to_func): + if isinstance(a, exp.Connector): + exp.replace_children( + a, + lambda c: to_func( + exp.paren(from_func(c, b.left)), + exp.paren(from_func(c, b.right)), + ), + ) + else: + a = to_func(from_func(a, b.left), from_func(a, b.right)) + + return _simplify(a) + + +def _simplify(node): + node = uniq_sort(flatten(node)) + exp.replace_children(node, _simplify) + return node diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py new file mode 100644 index 0000000..40e4ab1 --- /dev/null +++ b/sqlglot/optimizer/optimize_joins.py @@ -0,0 +1,75 @@ +from sqlglot import exp +from sqlglot.helper import tsort +from sqlglot.optimizer.simplify import simplify + + +def optimize_joins(expression): + """ + Removes cross joins if possible and reorder joins based on predicate dependencies. + """ + for select in expression.find_all(exp.Select): + references = {} + cross_joins = [] + + for join in select.args.get("joins", []): + name = join.this.alias_or_name + tables = other_table_names(join, name) + + if tables: + for table in tables: + references[table] = references.get(table, []) + [join] + else: + cross_joins.append((name, join)) + + for name, join in cross_joins: + for dep in references.get(name, []): + on = dep.args["on"] + on = on.replace(simplify(on)) + + if isinstance(on, exp.Connector): + for predicate in on.flatten(): + if name in exp.column_table_names(predicate): + predicate.replace(exp.TRUE) + join.on(predicate, copy=False) + + expression = reorder_joins(expression) + expression = normalize(expression) + return expression + + +def reorder_joins(expression): + """ + Reorder joins by topological sort order based on predicate references. 
+ """ + for from_ in expression.find_all(exp.From): + head = from_.expressions[0] + parent = from_.parent + joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])} + dag = {head.alias_or_name: []} + + for name, join in joins.items(): + dag[name] = other_table_names(join, name) + + parent.set( + "joins", + [joins[name] for name in tsort(dag) if name != head.alias_or_name], + ) + return expression + + +def normalize(expression): + """ + Remove INNER and OUTER from joins as they are optional. + """ + for join in expression.find_all(exp.Join): + if join.kind != "CROSS": + join.set("kind", None) + return expression + + +def other_table_names(join, exclude): + return [ + name + for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) + if name != exclude + ] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py new file mode 100644 index 0000000..c03fe3c --- /dev/null +++ b/sqlglot/optimizer/optimizer.py @@ -0,0 +1,43 @@ +from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries +from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects +from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.normalize import normalize +from sqlglot.optimizer.optimize_joins import optimize_joins +from sqlglot.optimizer.pushdown_predicates import pushdown_predicates +from sqlglot.optimizer.pushdown_projections import pushdown_projections +from sqlglot.optimizer.qualify_columns import qualify_columns +from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.optimizer.quote_identities import quote_identities +from sqlglot.optimizer.unnest_subqueries import unnest_subqueries + + +def optimize(expression, schema=None, db=None, catalog=None): + """ + Rewrite a sqlglot AST into an optimized form. + + Args: + expression (sqlglot.Expression): expression to optimize + schema (dict|sqlglot.optimizer.Schema): database schema. + This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of + the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. 
{catalog: {db: {table: {col: type}}}} + db (str): specify the default database, as might be set by a `USE DATABASE db` statement + catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement + Returns: + sqlglot.Expression: optimized expression + """ + expression = expression.copy() + expression = qualify_tables(expression, db=db, catalog=catalog) + expression = isolate_table_selects(expression) + expression = qualify_columns(expression, schema) + expression = pushdown_projections(expression) + expression = normalize(expression) + expression = unnest_subqueries(expression) + expression = expand_multi_table_selects(expression) + expression = pushdown_predicates(expression) + expression = optimize_joins(expression) + expression = eliminate_subqueries(expression) + expression = quote_identities(expression) + return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py new file mode 100644 index 0000000..e757322 --- /dev/null +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -0,0 +1,176 @@ +from sqlglot import exp +from sqlglot.optimizer.normalize import normalized +from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def pushdown_predicates(expression): + """ + Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS + + Example: + >>> import sqlglot + >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1" + >>> expression = sqlglot.parse_one(sql) + >>> pushdown_predicates(expression).sql() + 'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + for scope in reversed(traverse_scope(expression)): + select = scope.expression + where = select.args.get("where") + if where: + pushdown(where.this, scope.selected_sources) + + # joins should only pushdown into itself, not to other joins + # so we limit the selected sources to only itself + for join in select.args.get("joins") or []: + name = join.this.alias_or_name + pushdown(join.args.get("on"), {name: scope.selected_sources[name]}) + + return expression + + +def pushdown(condition, sources): + if not condition: + return + + condition = condition.replace(simplify(condition)) + cnf_like = normalized(condition) or not normalized(condition, dnf=True) + + predicates = list( + condition.flatten() + if isinstance(condition, exp.And if cnf_like else exp.Or) + else [condition] + ) + + if cnf_like: + pushdown_cnf(predicates, sources) + else: + pushdown_dnf(predicates, sources) + + +def pushdown_cnf(predicates, scope): + """ + If the predicates are in CNF like form, we can simply replace each block in the parent. + """ + for predicate in predicates: + for node in nodes_for_predicate(predicate, scope).values(): + if isinstance(node, exp.Join): + predicate.replace(exp.TRUE) + node.on(predicate, copy=False) + break + if isinstance(node, exp.Select): + predicate.replace(exp.TRUE) + node.where(replace_aliases(node, predicate), copy=False) + + +def pushdown_dnf(predicates, scope): + """ + If the predicates are in DNF form, we can only push down conditions that are in all blocks. + Additionally, we can't remove predicates from their original form. 
+ """ + # find all the tables that can be pushdown too + # these are tables that are referenced in all blocks of a DNF + # (a.x AND b.x) OR (a.y AND c.y) + # only table a can be push down + pushdown_tables = set() + + for a in predicates: + a_tables = set(exp.column_table_names(a)) + + for b in predicates: + a_tables &= set(exp.column_table_names(b)) + + pushdown_tables.update(a_tables) + + conditions = {} + + # for every pushdown table, find all related conditions in all predicates + # combine them with ORS + # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z) + for table in sorted(pushdown_tables): + for predicate in predicates: + nodes = nodes_for_predicate(predicate, scope) + + if table not in nodes: + continue + + predicate_condition = None + + for column in predicate.find_all(exp.Column): + if column.table == table: + condition = column.find_ancestor(exp.Condition) + predicate_condition = ( + exp.and_(predicate_condition, condition) + if predicate_condition + else condition + ) + + if predicate_condition: + conditions[table] = ( + exp.or_(conditions[table], predicate_condition) + if table in conditions + else predicate_condition + ) + + for name, node in nodes.items(): + if name not in conditions: + continue + + predicate = conditions[name] + + if isinstance(node, exp.Join): + node.on(predicate, copy=False) + elif isinstance(node, exp.Select): + node.where(replace_aliases(node, predicate), copy=False) + + +def nodes_for_predicate(predicate, sources): + nodes = {} + tables = exp.column_table_names(predicate) + where_condition = isinstance( + predicate.find_ancestor(exp.Join, exp.Where), exp.Where + ) + + for table in tables: + node, source = sources.get(table) or (None, None) + + # if the predicate is in a where statement we can try to push it down + # we want to find the root join or from statement + if node and where_condition: + node = node.find_ancestor(exp.Join, exp.From) + + # a node can reference a CTE which should be push down + if isinstance(node, exp.From) and not isinstance(source, exp.Table): + node = source.expression + + if isinstance(node, exp.Join): + if node.side: + return {} + nodes[table] = node + elif isinstance(node, exp.Select) and len(tables) == 1: + if not node.args.get("group"): + nodes[table] = node + return nodes + + +def replace_aliases(source, predicate): + aliases = {} + + for select in source.selects: + if isinstance(select, exp.Alias): + aliases[select.alias] = select.this + else: + aliases[select.name] = select + + def _replace_alias(column): + if isinstance(column, exp.Column) and column.name in aliases: + return aliases[column.name] + return column + + return predicate.transform(_replace_alias) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py new file mode 100644 index 0000000..097ce04 --- /dev/null +++ b/sqlglot/optimizer/pushdown_projections.py @@ -0,0 +1,85 @@ +from collections import defaultdict + +from sqlglot import alias, exp +from sqlglot.optimizer.scope import Scope, traverse_scope + +# Sentinel value that means an outer query selecting ALL columns +SELECT_ALL = object() + + +def pushdown_projections(expression): + """ + Rewrite sqlglot AST to remove unused columns projections. 
+
+    Example:
+        >>> import sqlglot
+        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
+        >>> expression = sqlglot.parse_one(sql)
+        >>> pushdown_projections(expression).sql()
+        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
+
+    Args:
+        expression (sqlglot.Expression): expression to optimize
+    Returns:
+        sqlglot.Expression: optimized expression
+    """
+    # Map of Scope to all columns being selected by outer queries.
+    referenced_columns = defaultdict(set)
+
+    # We build the scope tree (which is traversed in DFS postorder), then iterate
+    # over the result in reverse order. This should ensure that the set of selected
+    # columns for a particular scope is completely built by the time we get to it.
+    for scope in reversed(traverse_scope(expression)):
+        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
+
+        if scope.expression.args.get("distinct"):
+            # We can't remove columns from SELECT DISTINCT or UNION DISTINCT
+            parent_selections = {SELECT_ALL}
+
+        if isinstance(scope.expression, exp.Union):
+            left, right = scope.union
+            referenced_columns[left] = parent_selections
+            referenced_columns[right] = parent_selections
+
+        if isinstance(scope.expression, exp.Select):
+            _remove_unused_selections(scope, parent_selections)
+
+        # Group columns by source name
+        selects = defaultdict(set)
+        for col in scope.columns:
+            table_name = col.table
+            col_name = col.name
+            selects[table_name].add(col_name)
+
+        # Push the selected columns down to the next scope
+        for name, (_, source) in scope.selected_sources.items():
+            if isinstance(source, Scope):
+                columns = selects.get(name) or set()
+                referenced_columns[source].update(columns)
+
+    return expression
+
+
+def _remove_unused_selections(scope, parent_selections):
+    order = scope.expression.args.get("order")
+
+    if order:
+        # Assume columns without a qualified table are references to output columns
+        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
+    else:
+        order_refs = set()
+
+    new_selections = []
+    for selection in scope.selects:
+        if (
+            SELECT_ALL in parent_selections
+            or selection.alias_or_name in parent_selections
+            or selection.alias_or_name in order_refs
+        ):
+            new_selections.append(selection)
+
+    # If there are no remaining selections, just select a single constant
+    if not new_selections:
+        new_selections.append(alias("1", "_"))
+
+    scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
new file mode 100644
index 0000000..394f49e
--- /dev/null
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -0,0 +1,422 @@
+import itertools
+
+from sqlglot import alias, exp
+from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.schema import ensure_schema
+from sqlglot.optimizer.scope import traverse_scope
+
+SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
+
+
+def qualify_columns(expression, schema):
+    """
+    Rewrite sqlglot AST to have fully qualified columns.
+ + Example: + >>> import sqlglot + >>> schema = {"tbl": {"col": "INT"}} + >>> expression = sqlglot.parse_one("SELECT col FROM tbl") + >>> qualify_columns(expression, schema).sql() + 'SELECT tbl.col AS col FROM tbl' + + Args: + expression (sqlglot.Expression): expression to qualify + schema (dict|sqlglot.optimizer.Schema): Database schema + Returns: + sqlglot.Expression: qualified expression + """ + schema = ensure_schema(schema) + + for scope in traverse_scope(expression): + resolver = _Resolver(scope, schema) + _pop_table_column_aliases(scope.ctes) + _pop_table_column_aliases(scope.derived_tables) + _expand_using(scope, resolver) + _expand_group_by(scope, resolver) + _expand_order_by(scope) + _qualify_columns(scope, resolver) + if not isinstance(scope.expression, SKIP_QUALIFY): + _expand_stars(scope, resolver) + _qualify_outputs(scope) + _check_unknown_tables(scope) + + return expression + + +def _pop_table_column_aliases(derived_tables): + """ + Remove table column aliases. + + (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2) + """ + for derived_table in derived_tables: + if isinstance(derived_table, SKIP_QUALIFY): + continue + table_alias = derived_table.args.get("alias") + if table_alias: + table_alias.args.pop("columns", None) + + +def _expand_using(scope, resolver): + joins = list(scope.expression.find_all(exp.Join)) + names = {join.this.alias for join in joins} + ordered = [key for key in scope.selected_sources if key not in names] + + # Mapping of automatically joined column names to source names + column_tables = {} + + for join in joins: + using = join.args.get("using") + + if not using: + continue + + join_table = join.this.alias_or_name + + columns = {} + + for k in scope.selected_sources: + if k in ordered: + for column in resolver.get_source_columns(k): + if column not in columns: + columns[column] = k + + ordered.append(join_table) + join_columns = resolver.get_source_columns(join_table) + conditions = [] + + for identifier in using: + identifier = identifier.name + table = columns.get(identifier) + + if not table or identifier not in join_columns: + raise OptimizeError(f"Cannot automatically join: {identifier}") + + conditions.append( + exp.condition( + exp.EQ( + this=exp.column(identifier, table=table), + expression=exp.column(identifier, table=join_table), + ) + ) + ) + + tables = column_tables.setdefault(identifier, []) + if table not in tables: + tables.append(table) + if join_table not in tables: + tables.append(join_table) + + join.args.pop("using") + join.set("on", exp.and_(*conditions)) + + if column_tables: + for column in scope.columns: + if not column.table and column.name in column_tables: + tables = column_tables[column.name] + coalesce = [exp.column(column.name, table=table) for table in tables] + replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]) + + # Ensure selects keep their output name + if isinstance(column.parent, exp.Select): + replacement = exp.alias_(replacement, alias=column.name) + + scope.replace(column, replacement) + + +def _expand_group_by(scope, resolver): + group = scope.expression.args.get("group") + if not group: + return + + # Replace references to select aliases + def transform(node, *_): + if isinstance(node, exp.Column) and not node.table: + table = resolver.get_table(node.name) + + # Source columns get priority over select aliases + if table: + node.set("table", exp.to_identifier(table)) + return node + + selects = {s.alias_or_name: s for s in scope.selects} + + select = selects.get(node.name) + if select: + 
scope.clear_cache() + if isinstance(select, exp.Alias): + select = select.this + return select.copy() + + return node + + group.transform(transform, copy=False) + group.set("expressions", _expand_positional_references(scope, group.expressions)) + scope.expression.set("group", group) + + +def _expand_order_by(scope): + order = scope.expression.args.get("order") + if not order: + return + + ordereds = order.expressions + for ordered, new_expression in zip( + ordereds, + _expand_positional_references(scope, (o.this for o in ordereds)), + ): + ordered.set("this", new_expression) + + +def _expand_positional_references(scope, expressions): + new_nodes = [] + for node in expressions: + if node.is_int: + try: + select = scope.selects[int(node.name) - 1] + except IndexError: + raise OptimizeError(f"Unknown output column: {node.name}") + if isinstance(select, exp.Alias): + select = select.this + new_nodes.append(select.copy()) + scope.clear_cache() + else: + new_nodes.append(node) + + return new_nodes + + +def _qualify_columns(scope, resolver): + """Disambiguate columns, ensuring each column specifies a source""" + for column in scope.columns: + column_table = column.table + column_name = column.name + + if ( + column_table + and column_table in scope.sources + and column_name not in resolver.get_source_columns(column_table) + ): + raise OptimizeError(f"Unknown column: {column_name}") + + if not column_table: + column_table = resolver.get_table(column_name) + + if not scope.is_subquery and not scope.is_unnest: + if column_name not in resolver.all_columns: + raise OptimizeError(f"Unknown column: {column_name}") + + if column_table is None: + raise OptimizeError(f"Ambiguous column: {column_name}") + + # column_table can be a '' because bigquery unnest has no table alias + if column_table: + column.set("table", exp.to_identifier(column_table)) + + +def _expand_stars(scope, resolver): + """Expand stars to lists of column selections""" + + new_selections = [] + except_columns = {} + replace_columns = {} + + for expression in scope.selects: + if isinstance(expression, exp.Star): + tables = list(scope.selected_sources) + _add_except_columns(expression, tables, except_columns) + _add_replace_columns(expression, tables, replace_columns) + elif isinstance(expression, exp.Column) and isinstance( + expression.this, exp.Star + ): + tables = [expression.table] + _add_except_columns(expression.this, tables, except_columns) + _add_replace_columns(expression.this, tables, replace_columns) + else: + new_selections.append(expression) + continue + + for table in tables: + if table not in scope.sources: + raise OptimizeError(f"Unknown table: {table}") + columns = resolver.get_source_columns(table) + table_id = id(table) + for name in columns: + if name not in except_columns.get(table_id, set()): + alias_ = replace_columns.get(table_id, {}).get(name, name) + column = exp.column(name, table) + new_selections.append( + alias(column, alias_) if alias_ != name else column + ) + + scope.expression.set("expressions", new_selections) + + +def _add_except_columns(expression, tables, except_columns): + except_ = expression.args.get("except") + + if not except_: + return + + columns = {e.name for e in except_} + + for table in tables: + except_columns[id(table)] = columns + + +def _add_replace_columns(expression, tables, replace_columns): + replace = expression.args.get("replace") + + if not replace: + return + + columns = {e.this.name: e.alias for e in replace} + + for table in tables: + replace_columns[id(table)] = columns + + 
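Taken together, _expand_stars, _add_except_columns, and _add_replace_columns turn a bare star into explicit, qualified selections, and _qualify_outputs below then aliases every output column. A minimal usage sketch, not part of the patch; the table t and its two-column schema are made up for illustration.

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# The schema supplies the column list for t, so the star expands into
# qualified columns and each output column gets an explicit alias.
expression = sqlglot.parse_one("SELECT * FROM t")
print(qualify_columns(expression, {"t": {"a": "INT", "b": "INT"}}).sql())
# expected: 'SELECT t.a AS a, t.b AS b FROM t'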
+def _qualify_outputs(scope): + """Ensure all output columns are aliased""" + new_selections = [] + + for i, (selection, aliased_column) in enumerate( + itertools.zip_longest(scope.selects, scope.outer_column_list) + ): + if isinstance(selection, exp.Column): + # convoluted setter because a simple selection.replace(alias) would require a copy + alias_ = alias(exp.column(""), alias=selection.name) + alias_.set("this", selection) + selection = alias_ + elif not isinstance(selection, exp.Alias): + alias_ = alias(exp.column(""), f"_col_{i}") + alias_.set("this", selection) + selection = alias_ + + if aliased_column: + selection.set("alias", exp.to_identifier(aliased_column)) + + new_selections.append(selection) + + scope.expression.set("expressions", new_selections) + + +def _check_unknown_tables(scope): + if ( + scope.external_columns + and not scope.is_unnest + and not scope.is_correlated_subquery + ): + raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}") + + +class _Resolver: + """ + Helper for resolving columns. + + This is a class so we can lazily load some things and easily share them across functions. + """ + + def __init__(self, scope, schema): + self.scope = scope + self.schema = schema + self._source_columns = None + self._unambiguous_columns = None + self._all_columns = None + + def get_table(self, column_name): + """ + Get the table for a column name. + + Args: + column_name (str) + Returns: + (str) table name + """ + if self._unambiguous_columns is None: + self._unambiguous_columns = self._get_unambiguous_columns( + self._get_all_source_columns() + ) + return self._unambiguous_columns.get(column_name) + + @property + def all_columns(self): + """All available columns of all sources in this scope""" + if self._all_columns is None: + self._all_columns = set( + column + for columns in self._get_all_source_columns().values() + for column in columns + ) + return self._all_columns + + def get_source_columns(self, name): + """Resolve the source columns for a given source `name`""" + if name not in self.scope.sources: + raise OptimizeError(f"Unknown table: {name}") + + source = self.scope.sources[name] + + # If referencing a table, return the columns from the schema + if isinstance(source, exp.Table): + try: + return self.schema.column_names(source) + except Exception as e: + raise OptimizeError(str(e)) from e + + # Otherwise, if referencing another scope, return that scope's named selects + return source.expression.named_selects + + def _get_all_source_columns(self): + if self._source_columns is None: + self._source_columns = { + k: self.get_source_columns(k) for k in self.scope.selected_sources + } + return self._source_columns + + def _get_unambiguous_columns(self, source_columns): + """ + Find all the unambiguous columns in sources. 
+ + Args: + source_columns (dict): Mapping of names to source columns + Returns: + dict: Mapping of column name to source name + """ + if not source_columns: + return {} + + source_columns = list(source_columns.items()) + + first_table, first_columns = source_columns[0] + unambiguous_columns = { + col: first_table for col in self._find_unique_columns(first_columns) + } + all_columns = set(unambiguous_columns) + + for table, columns in source_columns[1:]: + unique = self._find_unique_columns(columns) + ambiguous = set(all_columns).intersection(unique) + all_columns.update(columns) + for column in ambiguous: + unambiguous_columns.pop(column, None) + for column in unique.difference(ambiguous): + unambiguous_columns[column] = table + + return unambiguous_columns + + @staticmethod + def _find_unique_columns(columns): + """ + Find the unique columns in a list of columns. + + Example: + >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"])) + ['a', 'c'] + + This is necessary because duplicate column names are ambiguous. + """ + counts = {} + for column in columns: + counts[column] = counts.get(column, 0) + 1 + return {column for column, count in counts.items() if count == 1} diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py new file mode 100644 index 0000000..9f8b9f5 --- /dev/null +++ b/sqlglot/optimizer/qualify_tables.py @@ -0,0 +1,54 @@ +import itertools + +from sqlglot import alias, exp +from sqlglot.optimizer.scope import traverse_scope + + +def qualify_tables(expression, db=None, catalog=None): + """ + Rewrite sqlglot AST to have fully qualified tables. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl") + >>> qualify_tables(expression, db="db").sql() + 'SELECT 1 FROM db.tbl AS tbl' + + Args: + expression (sqlglot.Expression): expression to qualify + db (str): Database name + catalog (str): Catalog name + Returns: + sqlglot.Expression: qualified expression + """ + sequence = itertools.count() + + for scope in traverse_scope(expression): + for derived_table in scope.ctes + scope.derived_tables: + if not derived_table.args.get("alias"): + alias_ = f"_q_{next(sequence)}" + derived_table.set( + "alias", exp.TableAlias(this=exp.to_identifier(alias_)) + ) + scope.rename_source(None, alias_) + + for source in scope.sources.values(): + if isinstance(source, exp.Table): + identifier = isinstance(source.this, exp.Identifier) + + if identifier: + if not source.args.get("db"): + source.set("db", exp.to_identifier(db)) + if not source.args.get("catalog"): + source.set("catalog", exp.to_identifier(catalog)) + + if not isinstance(source.parent, exp.Alias): + source.replace( + alias( + source.copy(), + source.this if identifier else f"_q_{next(sequence)}", + table=True, + ) + ) + + return expression diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py new file mode 100644 index 0000000..17623cc --- /dev/null +++ b/sqlglot/optimizer/quote_identities.py @@ -0,0 +1,25 @@ +from sqlglot import exp + + +def quote_identities(expression): + """ + Rewrite sqlglot AST to ensure all identities are quoted. 
+ + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x") + >>> quote_identities(expression).sql() + 'SELECT "x"."a" AS "a" FROM "db"."x"' + + Args: + expression (sqlglot.Expression): expression to quote + Returns: + sqlglot.Expression: quoted expression + """ + + def qualify(node): + if isinstance(node, exp.Identifier): + node.set("quoted", True) + return node + + return expression.transform(qualify, copy=False) diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py new file mode 100644 index 0000000..9968108 --- /dev/null +++ b/sqlglot/optimizer/schema.py @@ -0,0 +1,129 @@ +import abc + +from sqlglot import exp +from sqlglot.errors import OptimizeError +from sqlglot.helper import csv_reader + + +class Schema(abc.ABC): + """Abstract base class for database schemas""" + + @abc.abstractmethod + def column_names(self, table): + """ + Get the column names for a table. + + Args: + table (sqlglot.expressions.Table): Table expression instance + Returns: + list[str]: list of column names + """ + + +class MappingSchema(Schema): + """ + Schema based on a nested mapping. + + Args: + schema (dict): Mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + """ + + def __init__(self, schema): + self.schema = schema + + depth = _dict_depth(schema) + + if not depth: # {} + self.supported_table_args = [] + elif depth == 2: # {table: {col: type}} + self.supported_table_args = ("this",) + elif depth == 3: # {db: {table: {col: type}}} + self.supported_table_args = ("db", "this") + elif depth == 4: # {catalog: {db: {table: {col: type}}}} + self.supported_table_args = ("catalog", "db", "this") + else: + raise OptimizeError(f"Invalid schema shape. Depth: {depth}") + + self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args) + + def column_names(self, table): + if not isinstance(table.this, exp.Identifier): + return fs_get(table) + + args = tuple(table.text(p) for p in self.supported_table_args) + + for forbidden in self.forbidden_args: + if table.text(forbidden): + raise ValueError( + f"Schema doesn't support {forbidden}. Received: {table.sql()}" + ) + return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + + +def ensure_schema(schema): + if isinstance(schema, Schema): + return schema + + return MappingSchema(schema) + + +def fs_get(table): + name = table.this.name.upper() + + if name.upper() == "READ_CSV": + with csv_reader(table) as reader: + return next(reader) + + raise ValueError(f"Cannot read schema for {table}") + + +def _nested_get(d, *path): + """ + Get a value for a nested dictionary. + + Args: + d (dict): dictionary + *path (tuple[str, str]): tuples of (name, key) + `key` is the key in the dictionary to get. + `name` is a string to use in the error if `key` isn't found. + """ + for name, key in path: + d = d.get(key) + if d is None: + name = "table" if name == "this" else name + raise ValueError(f"Unknown {name}") + return d + + +def _dict_depth(d): + """ + Get the nesting depth of a dictionary. 
+ + For example: + >>> _dict_depth(None) + 0 + >>> _dict_depth({}) + 1 + >>> _dict_depth({"a": "b"}) + 1 + >>> _dict_depth({"a": {}}) + 2 + >>> _dict_depth({"a": {"b": {}}}) + 3 + + Args: + d (dict): dictionary + Returns: + int: depth + """ + try: + return 1 + _dict_depth(next(iter(d.values()))) + except AttributeError: + # d doesn't have attribute "values" + return 0 + except StopIteration: + # d.values() returns an empty sequence + return 1 diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py new file mode 100644 index 0000000..f6f59e8 --- /dev/null +++ b/sqlglot/optimizer/scope.py @@ -0,0 +1,438 @@ +from copy import copy +from enum import Enum, auto + +from sqlglot import exp +from sqlglot.errors import OptimizeError + + +class ScopeType(Enum): + ROOT = auto() + SUBQUERY = auto() + DERIVED_TABLE = auto() + CTE = auto() + UNION = auto() + UNNEST = auto() + + +class Scope: + """ + Selection scope. + + Attributes: + expression (exp.Select|exp.Union): Root expression of this scope + sources (dict[str, exp.Table|Scope]): Mapping of source name to either + a Table expression or another Scope instance. For example: + SELECT * FROM x {"x": Table(this="x")} + SELECT * FROM x AS y {"y": Table(this="x")} + SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} + outer_column_list (list[str]): If this is a derived table or CTE, and the outer query + defines a column list of it's alias of this scope, this is that list of columns. + For example: + SELECT * FROM (SELECT ...) AS y(col1, col2) + The inner query would have `["col1", "col2"]` for its `outer_column_list` + parent (Scope): Parent scope + scope_type (ScopeType): Type of this scope, relative to it's parent + subquery_scopes (list[Scope]): List of all child scopes for subqueries. + This does not include derived tables or CTEs. + union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be + a tuple of the left and right child scopes. + """ + + def __init__( + self, + expression, + sources=None, + outer_column_list=None, + parent=None, + scope_type=ScopeType.ROOT, + ): + self.expression = expression + self.sources = sources or {} + self.outer_column_list = outer_column_list or [] + self.parent = parent + self.scope_type = scope_type + self.subquery_scopes = [] + self.union = None + self.clear_cache() + + def clear_cache(self): + self._collected = False + self._raw_columns = None + self._derived_tables = None + self._tables = None + self._ctes = None + self._subqueries = None + self._selected_sources = None + self._columns = None + self._external_columns = None + + def branch(self, expression, scope_type, add_sources=None, **kwargs): + """Branch from the current scope to a new, inner scope""" + sources = copy(self.sources) + if add_sources: + sources.update(add_sources) + return Scope( + expression=expression.unnest(), + sources=sources, + parent=self, + scope_type=scope_type, + **kwargs, + ) + + def _collect(self): + self._tables = [] + self._ctes = [] + self._subqueries = [] + self._derived_tables = [] + self._raw_columns = [] + + # We'll use this variable to pass state into the dfs generator. + # Whenever we set it to True, we exclude a subtree from traversal. 
+ prune = False + + for node, parent, _ in self.expression.dfs(prune=lambda *_: prune): + prune = False + + if node is self.expression: + continue + if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): + self._raw_columns.append(node) + elif isinstance(node, exp.Table): + self._tables.append(node) + elif isinstance(node, (exp.Unnest, exp.Lateral)): + self._derived_tables.append(node) + elif isinstance(node, exp.CTE): + self._ctes.append(node) + prune = True + elif isinstance(node, exp.Subquery) and isinstance( + parent, (exp.From, exp.Join) + ): + self._derived_tables.append(node) + prune = True + elif isinstance(node, exp.Subqueryable): + self._subqueries.append(node) + prune = True + + self._collected = True + + def _ensure_collected(self): + if not self._collected: + self._collect() + + def replace(self, old, new): + """ + Replace `old` with `new`. + + This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. + + Args: + old (exp.Expression): old node + new (exp.Expression): new node + """ + old.replace(new) + self.clear_cache() + + @property + def tables(self): + """ + List of tables in this scope. + + Returns: + list[exp.Table]: tables + """ + self._ensure_collected() + return self._tables + + @property + def ctes(self): + """ + List of CTEs in this scope. + + Returns: + list[exp.CTE]: ctes + """ + self._ensure_collected() + return self._ctes + + @property + def derived_tables(self): + """ + List of derived tables in this scope. + + For example: + SELECT * FROM (SELECT ...) <- that's a derived table + + Returns: + list[exp.Subquery]: derived tables + """ + self._ensure_collected() + return self._derived_tables + + @property + def subqueries(self): + """ + List of subqueries in this scope. + + For example: + SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery + + Returns: + list[exp.Subqueryable]: subqueries + """ + self._ensure_collected() + return self._subqueries + + @property + def columns(self): + """ + List of columns in this scope. + + Returns: + list[exp.Column]: Column instances in this scope, plus any + Columns that reference this scope from correlated subqueries. + """ + if self._columns is None: + self._ensure_collected() + columns = self._raw_columns + + external_columns = [ + column + for scope in self.subquery_scopes + for column in scope.external_columns + ] + + named_outputs = {e.alias_or_name for e in self.expression.expressions} + + self._columns = [ + c + for c in columns + external_columns + if not ( + c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs + ) + ] + return self._columns + + @property + def selected_sources(self): + """ + Mapping of nodes and sources that are actually selected from in this scope. + + That is, all tables in a schema are selectable at any point. But a + table only becomes a selected source if it's included in a FROM or JOIN clause. 
+
+        Returns:
+            dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
+        """
+        if self._selected_sources is None:
+            referenced_names = []
+
+            for table in self.tables:
+                referenced_names.append(
+                    (
+                        table.parent.alias
+                        if isinstance(table.parent, exp.Alias)
+                        else table.name,
+                        table,
+                    )
+                )
+            for derived_table in self.derived_tables:
+                referenced_names.append((derived_table.alias, derived_table.unnest()))
+
+            result = {}
+
+            for name, node in referenced_names:
+                if name in self.sources:
+                    result[name] = (node, self.sources[name])
+
+            self._selected_sources = result
+        return self._selected_sources
+
+    @property
+    def selects(self):
+        """
+        Select expressions of this scope.
+
+        For example, for the following expression:
+            SELECT 1 as a, 2 as b FROM x
+
+        The outputs are the "1 as a" and "2 as b" expressions.
+
+        Returns:
+            list[exp.Expression]: expressions
+        """
+        if isinstance(self.expression, exp.Union):
+            return []
+        return self.expression.selects
+
+    @property
+    def external_columns(self):
+        """
+        Columns that appear to reference sources in outer scopes.
+
+        Returns:
+            list[exp.Column]: Column instances that don't reference
+                sources in the current scope.
+        """
+        if self._external_columns is None:
+            self._external_columns = [
+                c for c in self.columns if c.table not in self.selected_sources
+            ]
+        return self._external_columns
+
+    def source_columns(self, source_name):
+        """
+        Get all columns in the current scope for a particular source.
+
+        Args:
+            source_name (str): Name of the source
+        Returns:
+            list[exp.Column]: Column instances that reference `source_name`
+        """
+        return [column for column in self.columns if column.table == source_name]
+
+    @property
+    def is_subquery(self):
+        """Determine if this scope is a subquery"""
+        return self.scope_type == ScopeType.SUBQUERY
+
+    @property
+    def is_unnest(self):
+        """Determine if this scope is an unnest"""
+        return self.scope_type == ScopeType.UNNEST
+
+    @property
+    def is_correlated_subquery(self):
+        """Determine if this scope is a correlated subquery"""
+        return bool(self.is_subquery and self.external_columns)
+
+    def rename_source(self, old_name, new_name):
+        """Rename a source in this scope"""
+        columns = self.sources.pop(old_name or "", [])
+        self.sources[new_name] = columns
+
+
+def traverse_scope(expression):
+    """
+    Traverse an expression by its "scopes".
+
+    "Scope" represents the current context of a Select statement.
+
+    This is helpful for optimizing queries, where we need more information than
+    the expression tree itself. For example, we might care about the source
+    names within a subquery. This returns a list rather than a generator, because a
+    partially consumed generator could leave scopes with incomplete properties.
+
+    Examples:
+        >>> import sqlglot
+        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
+        >>> scopes = traverse_scope(expression)
+        >>> scopes[0].expression.sql(), list(scopes[0].sources)
+        ('SELECT a FROM x', ['x'])
+        >>> scopes[1].expression.sql(), list(scopes[1].sources)
+        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
+
+    Args:
+        expression (exp.Expression): expression to traverse
+    Returns:
+        List[Scope]: scope instances
+    """
+    return list(_traverse_scope(Scope(expression)))
+
+
+def _traverse_scope(scope):
+    if isinstance(scope.expression, exp.Select):
+        yield from _traverse_select(scope)
+    elif isinstance(scope.expression, exp.Union):
+        yield from _traverse_union(scope)
+    elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
+        pass
+    elif isinstance(scope.expression, exp.Subquery):
+        yield from _traverse_subqueries(scope)
+    else:
+        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
+    yield scope
+
+
+def _traverse_select(scope):
+    yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
+    yield from _traverse_subqueries(scope)
+    yield from _traverse_derived_tables(
+        scope.derived_tables, scope, ScopeType.DERIVED_TABLE
+    )
+    _add_table_sources(scope)
+
+
+def _traverse_union(scope):
+    yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
+
+    # The last scope to be yielded should be the topmost scope
+    left = None
+    for left in _traverse_scope(
+        scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
+    ):
+        yield left
+
+    right = None
+    for right in _traverse_scope(
+        scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
+    ):
+        yield right
+
+    scope.union = (left, right)
+
+
+def _traverse_derived_tables(derived_tables, scope, scope_type):
+    sources = {}
+
+    for derived_table in derived_tables:
+        for child_scope in _traverse_scope(
+            scope.branch(
+                derived_table
+                if isinstance(derived_table, (exp.Unnest, exp.Lateral))
+                else derived_table.this,
+                add_sources=sources if scope_type == ScopeType.CTE else None,
+                outer_column_list=derived_table.alias_column_names,
+                scope_type=ScopeType.UNNEST
+                if isinstance(derived_table, exp.Unnest)
+                else scope_type,
+            )
+        ):
+            yield child_scope
+            # Tables without aliases will be set as ""
+            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
+            # Until then, this means that only a single, unaliased derived table is allowed
+            # (rather, the latest one wins).
+            sources[derived_table.alias] = child_scope
+    scope.sources.update(sources)
+
+
+def _add_table_sources(scope):
+    sources = {}
+    for table in scope.tables:
+        table_name = table.name
+
+        if isinstance(table.parent, exp.Alias):
+            source_name = table.parent.alias
+        else:
+            source_name = table_name
+
+        if table_name in scope.sources:
+            # This is a reference to a parent source (e.g. a CTE), not an actual table.
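+            # For instance, in WITH t AS (SELECT 1) SELECT * FROM t, the name
+            # "t" resolves to the CTE's Scope instead of a physical table "t".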
+            scope.sources[source_name] = scope.sources[table_name]
+        elif source_name in scope.sources:
+            raise OptimizeError(f"Duplicate table name: {source_name}")
+        else:
+            sources[source_name] = table
+
+    scope.sources.update(sources)
+
+
+def _traverse_subqueries(scope):
+    for subquery in scope.subqueries:
+        top = None
+        for child_scope in _traverse_scope(
+            scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
+        ):
+            yield child_scope
+            top = child_scope
+        scope.subquery_scopes.append(top)
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
new file mode 100644
index 0000000..6771153
--- /dev/null
+++ b/sqlglot/optimizer/simplify.py
@@ -0,0 +1,383 @@
+import datetime
+import functools
+import itertools
+from collections import deque
+from decimal import Decimal
+
+from sqlglot import exp
+from sqlglot.expressions import FALSE, NULL, TRUE
+from sqlglot.generator import Generator
+from sqlglot.helper import while_changing
+
+GENERATOR = Generator(normalize=True, identify=True)
+
+
+def simplify(expression):
+    """
+    Rewrite sqlglot AST to simplify expressions.
+
+    Example:
+        >>> import sqlglot
+        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
+        >>> simplify(expression).sql()
+        'TRUE'
+
+    Args:
+        expression (sqlglot.Expression): expression to simplify
+    Returns:
+        sqlglot.Expression: simplified expression
+    """
+
+    def _simplify(expression, root=True):
+        node = expression
+        node = uniq_sort(node)
+        node = absorb_and_eliminate(node)
+        exp.replace_children(node, lambda e: _simplify(e, False))
+        node = simplify_not(node)
+        node = flatten(node)
+        node = simplify_connectors(node)
+        node = remove_compliments(node)
+        node.parent = expression.parent
+        node = simplify_literals(node)
+        node = simplify_parens(node)
+        if root:
+            expression.replace(node)
+        return node
+
+    expression = while_changing(expression, _simplify)
+    remove_where_true(expression)
+    return expression
+
+
+def simplify_not(expression):
+    """
+    De Morgan's laws:
+        NOT (x OR y) -> NOT x AND NOT y
+        NOT (x AND y) -> NOT x OR NOT y
+    """
+    if isinstance(expression, exp.Not):
+        if isinstance(expression.this, exp.Paren):
+            condition = expression.this.unnest()
+            if isinstance(condition, exp.And):
+                return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
+            if isinstance(condition, exp.Or):
+                return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
+        if always_true(expression.this):
+            return FALSE
+        if expression.this == FALSE:
+            return TRUE
+        if isinstance(expression.this, exp.Not):
+            # double negation
+            # NOT NOT x -> x
+            return expression.this.this
+    return expression
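+
+
+# For example, a quick doctest-style sketch of the rewrite above
+# (illustrative only; parse_one comes from the package root):
+#
+#     >>> import sqlglot
+#     >>> from sqlglot.optimizer.simplify import simplify_not
+#     >>> simplify_not(sqlglot.parse_one("NOT (a OR b)")).sql()
+#     'NOT a AND NOT b'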
+
+
+def flatten(expression):
+    """
+    A AND (B AND C) -> A AND B AND C
+    A OR (B OR C) -> A OR B OR C
+    """
+    if isinstance(expression, exp.Connector):
+        for node in expression.args.values():
+            child = node.unnest()
+            if isinstance(child, expression.__class__):
+                node.replace(child)
+    return expression
+
+
+def simplify_connectors(expression):
+    if isinstance(expression, exp.Connector):
+        left = expression.left
+        right = expression.right
+
+        if left == right:
+            return left
+
+        if isinstance(expression, exp.And):
+            if NULL in (left, right):
+                return NULL
+            if FALSE in (left, right):
+                return FALSE
+            if always_true(left) and always_true(right):
+                return TRUE
+            if always_true(left):
+                return right
+            if always_true(right):
+                return left
+        elif isinstance(expression, exp.Or):
+            if always_true(left) or always_true(right):
+                return TRUE
+            if left == FALSE and right == FALSE:
+                return FALSE
+            if (
+                (left == NULL and right == NULL)
+                or (left == NULL and right == FALSE)
+                or (left == FALSE and right == NULL)
+            ):
+                return NULL
+            if left == FALSE:
+                return right
+            if right == FALSE:
+                return left
+    return expression
+
+
+def remove_compliments(expression):
+    """
+    Removing complements.
+
+    A AND NOT A -> FALSE
+    A OR NOT A -> TRUE
+    """
+    if isinstance(expression, exp.Connector):
+        compliment = FALSE if isinstance(expression, exp.And) else TRUE
+
+        for a, b in itertools.permutations(expression.flatten(), 2):
+            if is_complement(a, b):
+                return compliment
+    return expression
+
+
+def uniq_sort(expression):
+    """
+    Uniq and sort a connector.
+
+    C AND A AND B AND B -> A AND B AND C
+    """
+    if isinstance(expression, exp.Connector):
+        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
+        flattened = tuple(expression.flatten())
+        deduped = {GENERATOR.generate(e): e for e in flattened}
+        arr = tuple(deduped.items())
+
+        # check if the operands are already sorted, if not sort them
+        # A AND C AND B -> A AND B AND C
+        for i, (sql, e) in enumerate(arr[1:]):
+            if sql < arr[i][0]:
+                expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
+                break
+        else:
+            # we didn't have to sort but maybe we need to dedup
+            if len(deduped) < len(flattened):
+                expression = result_func(*deduped.values())
+
+    return expression
+
+
+def absorb_and_eliminate(expression):
+    """
+    absorption:
+        A AND (A OR B) -> A
+        A OR (A AND B) -> A
+        A AND (NOT A OR B) -> A AND B
+        A OR (NOT A AND B) -> A OR B
+    elimination:
+        (A AND B) OR (A AND NOT B) -> A
+        (A OR B) AND (A OR NOT B) -> A
+    """
+    if isinstance(expression, exp.Connector):
+        kind = exp.Or if isinstance(expression, exp.And) else exp.And
+
+        for a, b in itertools.permutations(expression.flatten(), 2):
+            if isinstance(a, kind):
+                aa, ab = a.unnest_operands()
+
+                # absorb
+                if is_complement(b, aa):
+                    aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+                elif is_complement(b, ab):
+                    ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
+                    a.flatten()
+                ):
+                    a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
+                elif isinstance(b, kind):
+                    # eliminate
+                    rhs = b.unnest_operands()
+                    ba, bb = rhs
+
+                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
+                        a.replace(aa)
+                        b.replace(aa)
+                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
+                        a.replace(ab)
+                        b.replace(ab)
+
+    return expression
+
+
+def simplify_literals(expression):
+    if isinstance(expression, exp.Binary):
+        operands = []
+        queue = deque(expression.flatten(unnest=False))
+        size = len(queue)
+
+        while queue:
+            a = queue.popleft()
+
+            for b in queue:
+                result = _simplify_binary(expression, a, b)
+
+                if result:
+                    queue.remove(b)
+                    queue.append(result)
+                    break
+            else:
+                operands.append(a)
+
+        if len(operands) < size:
+            return functools.reduce(
+                lambda a, b: expression.__class__(this=a, expression=b), operands
+            )
+    elif isinstance(expression, exp.Neg):
+        this = expression.this
+        if this.is_number:
+            value = this.name
+            if value[0] == "-":
+                return exp.Literal.number(value[1:])
+            return exp.Literal.number(f"-{value}")
+
+    return expression
+
+
+def _simplify_binary(expression, a, b):
+    if isinstance(expression, exp.Is):
+        if isinstance(b, exp.Not):
+            c = b.this
+            not_ = True
+        else:
+            c = b
+            not_ = False
+
+        if c == NULL:
+            if isinstance(a, exp.Literal):
+                return TRUE if not_ else FALSE
+            if a == NULL:
+                return FALSE if not_ else TRUE
+    elif NULL in (a, b):
+        return NULL
+
+    if isinstance(expression, exp.EQ) and a == b:
+        
return TRUE + + if a.is_number and b.is_number: + a = int(a.name) if a.is_int else Decimal(a.name) + b = int(b.name) if b.is_int else Decimal(b.name) + + if isinstance(expression, exp.Add): + return exp.Literal.number(a + b) + if isinstance(expression, exp.Sub): + return exp.Literal.number(a - b) + if isinstance(expression, exp.Mul): + return exp.Literal.number(a * b) + if isinstance(expression, exp.Div): + if isinstance(a, int) and isinstance(b, int): + return exp.Literal.number(a // b) + return exp.Literal.number(a / b) + + boolean = eval_boolean(expression, a, b) + + if boolean: + return boolean + elif a.is_string and b.is_string: + boolean = eval_boolean(expression, a, b) + + if boolean: + return boolean + elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): + a, b = extract_date(a), extract_interval(b) + if b: + if isinstance(expression, exp.Add): + return date_literal(a + b) + if isinstance(expression, exp.Sub): + return date_literal(a - b) + elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): + a, b = extract_interval(a), extract_date(b) + # you cannot subtract a date from an interval + if a and isinstance(expression, exp.Add): + return date_literal(a + b) + + return None + + +def simplify_parens(expression): + if ( + isinstance(expression, exp.Paren) + and not isinstance(expression.this, exp.Select) + and ( + not isinstance(expression.parent, (exp.Condition, exp.Binary)) + or isinstance(expression.this, (exp.Is, exp.Like)) + or not isinstance(expression.this, exp.Binary) + ) + ): + return expression.this + return expression + + +def remove_where_true(expression): + for where in expression.find_all(exp.Where): + if always_true(where.this): + where.parent.set("where", None) + for join in expression.find_all(exp.Join): + if always_true(join.args.get("on")): + join.set("kind", "CROSS") + join.set("on", None) + + +def always_true(expression): + return expression == TRUE or isinstance(expression, exp.Literal) + + +def is_complement(a, b): + return isinstance(b, exp.Not) and b.this == a + + +def eval_boolean(expression, a, b): + if isinstance(expression, (exp.EQ, exp.Is)): + return boolean_literal(a == b) + if isinstance(expression, exp.NEQ): + return boolean_literal(a != b) + if isinstance(expression, exp.GT): + return boolean_literal(a > b) + if isinstance(expression, exp.GTE): + return boolean_literal(a >= b) + if isinstance(expression, exp.LT): + return boolean_literal(a < b) + if isinstance(expression, exp.LTE): + return boolean_literal(a <= b) + return None + + +def extract_date(cast): + if cast.args["to"].this == exp.DataType.Type.DATE: + return datetime.date.fromisoformat(cast.name) + return None + + +def extract_interval(interval): + try: + from dateutil.relativedelta import relativedelta + except ModuleNotFoundError: + return None + + n = int(interval.name) + unit = interval.text("unit").lower() + + if unit == "year": + return relativedelta(years=n) + if unit == "month": + return relativedelta(months=n) + if unit == "week": + return relativedelta(weeks=n) + if unit == "day": + return relativedelta(days=n) + return None + + +def date_literal(date): + return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE")) + + +def boolean_literal(condition): + return TRUE if condition else FALSE diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py new file mode 100644 index 0000000..55c81c5 --- /dev/null +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -0,0 +1,220 @@ +import itertools + +from sqlglot import exp +from 
sqlglot.optimizer.scope import traverse_scope + + +def unnest_subqueries(expression): + """ + Rewrite sqlglot AST to convert some predicates with subqueries into joins. + + Convert the subquery into a group by so it is not a many to many left join. + Unnesting can only occur if the subquery does not have LIMIT or OFFSET. + Unnesting non correlated subqueries only happens on IN statements or = ANY statements. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") + >>> unnest_subqueries(expression).sql() + 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ + AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)' + + Args: + expression (sqlglot.Expression): expression to unnest + Returns: + sqlglot.Expression: unnested expression + """ + sequence = itertools.count() + + for scope in traverse_scope(expression): + select = scope.expression + parent = select.parent_select + if scope.external_columns: + decorrelate(select, parent, scope.external_columns, sequence) + else: + unnest(select, parent, sequence) + + return expression + + +def unnest(select, parent_select, sequence): + predicate = select.find_ancestor(exp.In, exp.Any) + + if not predicate or parent_select is not predicate.parent_select: + return + + if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset): + return + + if isinstance(predicate, exp.Any): + predicate = predicate.find_ancestor(exp.EQ) + + if not predicate or parent_select is not predicate.parent_select: + return + + column = _other_operand(predicate) + value = select.selects[0] + alias = _alias(sequence) + + on = exp.condition(f'{column} = "{alias}"."{value.alias}"') + _replace(predicate, f"NOT {on.right} IS NULL") + + parent_select.join( + select.group_by(value.this, copy=False), + on=on, + join_type="LEFT", + join_alias=alias, + copy=False, + ) + + +def decorrelate(select, parent_select, external_columns, sequence): + where = select.args.get("where") + + if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset): + return + + table_alias = _alias(sequence) + keys = [] + + # for all external columns in the where statement, + # split out the relevant data to convert it into a join + for column in external_columns: + if column.find_ancestor(exp.Where) is not where: + return + + predicate = column.find_ancestor(exp.Predicate) + + if not predicate or predicate.find_ancestor(exp.Where) is not where: + return + + if isinstance(predicate, exp.Binary): + key = ( + predicate.right + if any(node is column for node, *_ in predicate.left.walk()) + else predicate.left + ) + else: + return + + keys.append((key, column, predicate)) + + if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): + return + + value = select.selects[0] + key_aliases = {} + group_by = [] + + for key, _, predicate in keys: + # if we filter on the value of the subquery, it needs to be unique + if key == value.this: + key_aliases[key] = value.alias + group_by.append(key) + else: + if key not in key_aliases: + key_aliases[key] = _alias(sequence) + # all predicates that are equalities must also be in the unique + # so that we don't do a many to many join + if isinstance(predicate, exp.EQ) and key not in group_by: + group_by.append(key) + + parent_predicate = select.find_ancestor(exp.Predicate) + + # if the value of the subquery is not an agg or a key, we need to collect it into an array + # so that it can be grouped + if not 
value.find(exp.AggFunc) and value.this not in group_by: + select.select( + f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False + ) + + # exists queries should not have any selects as it only checks if there are any rows + # all selects will be added by the optimizer and only used for join keys + if isinstance(parent_predicate, exp.Exists): + select.args["expressions"] = [] + + for key, alias in key_aliases.items(): + if key in group_by: + # add all keys to the projections of the subquery + # so that we can use it as a join key + if isinstance(parent_predicate, exp.Exists) or key != value.this: + select.select(f"{key} AS {alias}", copy=False) + else: + select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False) + + alias = exp.column(value.alias, table_alias) + other = _other_operand(parent_predicate) + + if isinstance(parent_predicate, exp.Exists): + if value.this in group_by: + parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL") + else: + parent_predicate = _replace(parent_predicate, "TRUE") + elif isinstance(parent_predicate, exp.All): + parent_predicate = _replace( + parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" + ) + elif isinstance(parent_predicate, exp.Any): + if value.this in group_by: + parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}") + else: + parent_predicate = _replace( + parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})" + ) + elif isinstance(parent_predicate, exp.In): + if value.this in group_by: + parent_predicate = _replace(parent_predicate, f"{other} = {alias}") + else: + parent_predicate = _replace( + parent_predicate, + f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", + ) + else: + select.parent.replace(alias) + + for key, column, predicate in keys: + predicate.replace(exp.TRUE) + nested = exp.column(key_aliases[key], table_alias) + + if key in group_by: + key.replace(nested) + parent_predicate = _replace( + parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" + ) + elif isinstance(predicate, exp.EQ): + parent_predicate = _replace( + parent_predicate, + f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))", + ) + else: + key.replace(exp.to_identifier("_x")) + parent_predicate = _replace( + parent_predicate, + f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))', + ) + + parent_select.join( + select.group_by(*group_by, copy=False), + on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)], + join_type="LEFT", + join_alias=table_alias, + copy=False, + ) + + +def _alias(sequence): + return f"_u_{next(sequence)}" + + +def _replace(expression, condition): + return expression.replace(exp.condition(condition)) + + +def _other_operand(expression): + if isinstance(expression, exp.In): + return expression.this + + if isinstance(expression, exp.Binary): + return expression.right if expression.arg_key == "this" else expression.left + + return None diff --git a/sqlglot/parser.py b/sqlglot/parser.py new file mode 100644 index 0000000..9396c50 --- /dev/null +++ b/sqlglot/parser.py @@ -0,0 +1,2190 @@ +import logging + +from sqlglot import exp +from sqlglot.errors import ErrorLevel, ParseError, concat_errors +from sqlglot.helper import apply_index_offset, ensure_list, list_get +from sqlglot.tokens import Token, Tokenizer, TokenType + +logger = logging.getLogger("sqlglot") + + +class Parser: + """ + Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` + and produces a parsed syntax tree. 
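+
+    Example (a minimal sketch, using the tokenizer this parser expects):
+        >>> from sqlglot.parser import Parser
+        >>> from sqlglot.tokens import Tokenizer
+        >>> Parser().parse(Tokenizer().tokenize("SELECT 1"))[0].sql()
+        'SELECT 1'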
+ + Args + error_level (ErrorLevel): the desired error level. Default: ErrorLevel.RAISE. + error_message_context (int): determines the amount of context to capture from + a query string when displaying the error message (in number of characters). + Default: 50. + index_offset (int): Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list + Default: 0 + alias_post_tablesample (bool): If the table alias comes after tablesample + Default: False + max_errors (int): Maximum number of error messages to include in a raised ParseError. + This is only relevant if error_level is ErrorLevel.RAISE. + Default: 3 + null_ordering (str): Indicates the default null ordering method to use if not explicitly set. + Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". + Default: "nulls_are_small" + """ + + FUNCTIONS = { + **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, + "DATE_TO_DATE_STR": lambda args: exp.Cast( + this=list_get(args, 0), + to=exp.DataType(this=exp.DataType.Type.TEXT), + ), + "TIME_TO_TIME_STR": lambda args: exp.Cast( + this=list_get(args, 0), + to=exp.DataType(this=exp.DataType.Type.TEXT), + ), + "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring( + this=exp.Cast( + this=list_get(args, 0), + to=exp.DataType(this=exp.DataType.Type.TEXT), + ), + start=exp.Literal.number(1), + length=exp.Literal.number(10), + ), + } + + NO_PAREN_FUNCTIONS = { + TokenType.CURRENT_DATE: exp.CurrentDate, + TokenType.CURRENT_DATETIME: exp.CurrentDate, + TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp, + } + + NESTED_TYPE_TOKENS = { + TokenType.ARRAY, + TokenType.MAP, + TokenType.STRUCT, + TokenType.NULLABLE, + } + + TYPE_TOKENS = { + TokenType.BOOLEAN, + TokenType.TINYINT, + TokenType.SMALLINT, + TokenType.INT, + TokenType.BIGINT, + TokenType.FLOAT, + TokenType.DOUBLE, + TokenType.CHAR, + TokenType.NCHAR, + TokenType.VARCHAR, + TokenType.NVARCHAR, + TokenType.TEXT, + TokenType.BINARY, + TokenType.JSON, + TokenType.TIMESTAMP, + TokenType.TIMESTAMPTZ, + TokenType.DATETIME, + TokenType.DATE, + TokenType.DECIMAL, + TokenType.UUID, + TokenType.GEOGRAPHY, + *NESTED_TYPE_TOKENS, + } + + SUBQUERY_PREDICATES = { + TokenType.ANY: exp.Any, + TokenType.ALL: exp.All, + TokenType.EXISTS: exp.Exists, + TokenType.SOME: exp.Any, + } + + RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT} + + ID_VAR_TOKENS = { + TokenType.VAR, + TokenType.ALTER, + TokenType.BEGIN, + TokenType.BUCKET, + TokenType.CACHE, + TokenType.COLLATE, + TokenType.COMMIT, + TokenType.CONSTRAINT, + TokenType.CONVERT, + TokenType.DEFAULT, + TokenType.DELETE, + TokenType.ENGINE, + TokenType.ESCAPE, + TokenType.EXPLAIN, + TokenType.FALSE, + TokenType.FIRST, + TokenType.FOLLOWING, + TokenType.FORMAT, + TokenType.FUNCTION, + TokenType.IF, + TokenType.INDEX, + TokenType.ISNULL, + TokenType.INTERVAL, + TokenType.LAZY, + TokenType.LOCATION, + TokenType.NEXT, + TokenType.ONLY, + TokenType.OPTIMIZE, + TokenType.OPTIONS, + TokenType.ORDINALITY, + TokenType.PERCENT, + TokenType.PRECEDING, + TokenType.RANGE, + TokenType.REFERENCES, + TokenType.ROWS, + TokenType.SCHEMA_COMMENT, + TokenType.SET, + TokenType.SHOW, + TokenType.STORED, + TokenType.TABLE, + TokenType.TABLE_FORMAT, + TokenType.TEMPORARY, + TokenType.TOP, + TokenType.TRUNCATE, + TokenType.TRUE, + TokenType.UNBOUNDED, + TokenType.UNIQUE, + TokenType.PROPERTIES, + *SUBQUERY_PREDICATES, + *TYPE_TOKENS, + } + + CASTS = { + TokenType.CAST, + TokenType.TRY_CAST, + } + + FUNC_TOKENS = { + TokenType.CONVERT, + TokenType.CURRENT_DATE, + 
TokenType.CURRENT_DATETIME, + TokenType.CURRENT_TIMESTAMP, + TokenType.CURRENT_TIME, + TokenType.EXTRACT, + TokenType.FILTER, + TokenType.FIRST, + TokenType.FORMAT, + TokenType.ISNULL, + TokenType.OFFSET, + TokenType.PRIMARY_KEY, + TokenType.REPLACE, + TokenType.ROW, + TokenType.UNNEST, + TokenType.VAR, + TokenType.LEFT, + TokenType.RIGHT, + TokenType.DATE, + TokenType.DATETIME, + TokenType.TIMESTAMP, + TokenType.TIMESTAMPTZ, + *CASTS, + *NESTED_TYPE_TOKENS, + *SUBQUERY_PREDICATES, + } + + CONJUNCTION = { + TokenType.AND: exp.And, + TokenType.OR: exp.Or, + } + + EQUALITY = { + TokenType.EQ: exp.EQ, + TokenType.NEQ: exp.NEQ, + } + + COMPARISON = { + TokenType.GT: exp.GT, + TokenType.GTE: exp.GTE, + TokenType.LT: exp.LT, + TokenType.LTE: exp.LTE, + } + + BITWISE = { + TokenType.AMP: exp.BitwiseAnd, + TokenType.CARET: exp.BitwiseXor, + TokenType.PIPE: exp.BitwiseOr, + TokenType.DPIPE: exp.DPipe, + } + + TERM = { + TokenType.DASH: exp.Sub, + TokenType.PLUS: exp.Add, + TokenType.MOD: exp.Mod, + } + + FACTOR = { + TokenType.DIV: exp.IntDiv, + TokenType.SLASH: exp.Div, + TokenType.STAR: exp.Mul, + } + + TIMESTAMPS = { + TokenType.TIMESTAMP, + TokenType.TIMESTAMPTZ, + } + + SET_OPERATIONS = { + TokenType.UNION, + TokenType.INTERSECT, + TokenType.EXCEPT, + } + + JOIN_SIDES = { + TokenType.LEFT, + TokenType.RIGHT, + TokenType.FULL, + } + + JOIN_KINDS = { + TokenType.INNER, + TokenType.OUTER, + TokenType.CROSS, + } + + COLUMN_OPERATORS = { + TokenType.DOT: None, + TokenType.ARROW: lambda self, this, path: self.expression( + exp.JSONExtract, + this=this, + path=path, + ), + TokenType.DARROW: lambda self, this, path: self.expression( + exp.JSONExtractScalar, + this=this, + path=path, + ), + TokenType.HASH_ARROW: lambda self, this, path: self.expression( + exp.JSONBExtract, + this=this, + path=path, + ), + TokenType.DHASH_ARROW: lambda self, this, path: self.expression( + exp.JSONBExtractScalar, + this=this, + path=path, + ), + } + + EXPRESSION_PARSERS = { + exp.DataType: lambda self: self._parse_types(), + exp.From: lambda self: self._parse_from(), + exp.Group: lambda self: self._parse_group(), + exp.Lateral: lambda self: self._parse_lateral(), + exp.Join: lambda self: self._parse_join(), + exp.Order: lambda self: self._parse_order(), + exp.Cluster: lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), + exp.Sort: lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), + exp.Lambda: lambda self: self._parse_lambda(), + exp.Limit: lambda self: self._parse_limit(), + exp.Offset: lambda self: self._parse_offset(), + exp.TableAlias: lambda self: self._parse_table_alias(), + exp.Table: lambda self: self._parse_table(), + exp.Condition: lambda self: self._parse_conjunction(), + exp.Expression: lambda self: self._parse_statement(), + exp.Properties: lambda self: self._parse_properties(), + "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), + } + + STATEMENT_PARSERS = { + TokenType.CREATE: lambda self: self._parse_create(), + TokenType.DROP: lambda self: self._parse_drop(), + TokenType.INSERT: lambda self: self._parse_insert(), + TokenType.UPDATE: lambda self: self._parse_update(), + TokenType.DELETE: lambda self: self._parse_delete(), + TokenType.CACHE: lambda self: self._parse_cache(), + TokenType.UNCACHE: lambda self: self._parse_uncache(), + } + + PRIMARY_PARSERS = { + TokenType.STRING: lambda _, token: exp.Literal.string(token.text), + TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text), + TokenType.STAR: lambda self, _: exp.Star( + **{"except": self._parse_except(), 
"replace": self._parse_replace()} + ), + TokenType.NULL: lambda *_: exp.Null(), + TokenType.TRUE: lambda *_: exp.Boolean(this=True), + TokenType.FALSE: lambda *_: exp.Boolean(this=False), + TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(), + TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), + TokenType.INTRODUCER: lambda self, token: self.expression( + exp.Introducer, + this=token.text, + expression=self._parse_var_or_string(), + ), + } + + RANGE_PARSERS = { + TokenType.BETWEEN: lambda self, this: self._parse_between(this), + TokenType.IN: lambda self, this: self._parse_in(this), + TokenType.IS: lambda self, this: self._parse_is(this), + TokenType.LIKE: lambda self, this: self._parse_escape( + self.expression(exp.Like, this=this, expression=self._parse_type()) + ), + TokenType.ILIKE: lambda self, this: self._parse_escape( + self.expression(exp.ILike, this=this, expression=self._parse_type()) + ), + TokenType.RLIKE: lambda self, this: self.expression( + exp.RegexpLike, this=this, expression=self._parse_type() + ), + } + + PROPERTY_PARSERS = { + TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(), + TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), + TokenType.COLLATE: lambda self: self._parse_collate(), + TokenType.ENGINE: lambda self: self._parse_engine(), + TokenType.FORMAT: lambda self: self._parse_format(), + TokenType.LOCATION: lambda self: self.expression( + exp.LocationProperty, + this=exp.Literal.string("LOCATION"), + value=self._parse_string(), + ), + TokenType.PARTITIONED_BY: lambda self: self.expression( + exp.PartitionedByProperty, + this=exp.Literal.string("PARTITIONED_BY"), + value=self._parse_schema(), + ), + TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), + TokenType.STORED: lambda self: self._parse_stored(), + TokenType.TABLE_FORMAT: lambda self: self._parse_table_format(), + TokenType.USING: lambda self: self._parse_table_format(), + } + + CONSTRAINT_PARSERS = { + TokenType.CHECK: lambda self: self._parse_check(), + TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(), + TokenType.UNIQUE: lambda self: self._parse_unique(), + } + + NO_PAREN_FUNCTION_PARSERS = { + TokenType.CASE: lambda self: self._parse_case(), + TokenType.IF: lambda self: self._parse_if(), + } + + FUNCTION_PARSERS = { + TokenType.CONVERT: lambda self, _: self._parse_convert(), + TokenType.EXTRACT: lambda self, _: self._parse_extract(), + **{ + token_type: lambda self, token_type: self._parse_cast( + self.STRICT_CAST and token_type == TokenType.CAST + ) + for token_type in CASTS + }, + } + + QUERY_MODIFIER_PARSERS = { + "laterals": lambda self: self._parse_laterals(), + "joins": lambda self: self._parse_joins(), + "where": lambda self: self._parse_where(), + "group": lambda self: self._parse_group(), + "having": lambda self: self._parse_having(), + "qualify": lambda self: self._parse_qualify(), + "window": lambda self: self._match(TokenType.WINDOW) + and self._parse_window(self._parse_id_var(), alias=True), + "distribute": lambda self: self._parse_sort( + TokenType.DISTRIBUTE_BY, exp.Distribute + ), + "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), + "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), + "order": lambda self: self._parse_order(), + "limit": lambda self: self._parse_limit(), + "offset": lambda self: self._parse_offset(), + } + + CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX} + + STRICT_CAST = True + + __slots__ = ( + 
"error_level", + "error_message_context", + "sql", + "errors", + "index_offset", + "unnest_column_only", + "alias_post_tablesample", + "max_errors", + "null_ordering", + "_tokens", + "_chunks", + "_index", + "_curr", + "_next", + "_prev", + "_greedy_subqueries", + ) + + def __init__( + self, + error_level=None, + error_message_context=100, + index_offset=0, + unnest_column_only=False, + alias_post_tablesample=False, + max_errors=3, + null_ordering=None, + ): + self.error_level = error_level or ErrorLevel.RAISE + self.error_message_context = error_message_context + self.index_offset = index_offset + self.unnest_column_only = unnest_column_only + self.alias_post_tablesample = alias_post_tablesample + self.max_errors = max_errors + self.null_ordering = null_ordering + self.reset() + + def reset(self): + self.sql = "" + self.errors = [] + self._tokens = [] + self._chunks = [[]] + self._index = 0 + self._curr = None + self._next = None + self._prev = None + self._greedy_subqueries = False + + def parse(self, raw_tokens, sql=None): + """ + Parses the given list of tokens and returns a list of syntax trees, one tree + per parsed SQL statement. + + Args + raw_tokens (list): the list of tokens (:class:`~sqlglot.tokens.Token`). + sql (str): the original SQL string. Used to produce helpful debug messages. + + Returns + the list of syntax trees (:class:`~sqlglot.expressions.Expression`). + """ + return self._parse( + parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql + ) + + def parse_into(self, expression_types, raw_tokens, sql=None): + for expression_type in ensure_list(expression_types): + parser = self.EXPRESSION_PARSERS.get(expression_type) + if not parser: + raise TypeError(f"No parser registered for {expression_type}") + try: + return self._parse(parser, raw_tokens, sql) + except ParseError as e: + error = e + raise ParseError(f"Failed to parse into {expression_types}") from error + + def _parse(self, parse_method, raw_tokens, sql=None): + self.reset() + self.sql = sql or "" + total = len(raw_tokens) + + for i, token in enumerate(raw_tokens): + if token.token_type == TokenType.SEMICOLON: + if i < total - 1: + self._chunks.append([]) + else: + self._chunks[-1].append(token) + + expressions = [] + + for tokens in self._chunks: + self._index = -1 + self._tokens = tokens + self._advance() + expressions.append(parse_method(self)) + + if self._index < len(self._tokens): + self.raise_error("Invalid expression / Unexpected token") + + self.check_errors() + + return expressions + + def check_errors(self): + if self.error_level == ErrorLevel.WARN: + for error in self.errors: + logger.error(str(error)) + elif self.error_level == ErrorLevel.RAISE and self.errors: + raise ParseError(concat_errors(self.errors, self.max_errors)) + + def raise_error(self, message, token=None): + token = token or self._curr or self._prev or Token.string("") + start = self._find_token(token, self.sql) + end = start + len(token.text) + start_context = self.sql[max(start - self.error_message_context, 0) : start] + highlight = self.sql[start:end] + end_context = self.sql[end : end + self.error_message_context] + error = ParseError( + f"{message}. 
Line {token.line}, Col: {token.col}.\n" + f" {start_context}\033[4m{highlight}\033[0m{end_context}" + ) + if self.error_level == ErrorLevel.IMMEDIATE: + raise error + self.errors.append(error) + + def expression(self, exp_class, **kwargs): + instance = exp_class(**kwargs) + self.validate_expression(instance) + return instance + + def validate_expression(self, expression, args=None): + if self.error_level == ErrorLevel.IGNORE: + return + + for k in expression.args: + if k not in expression.arg_types: + self.raise_error( + f"Unexpected keyword: '{k}' for {expression.__class__}" + ) + for k, mandatory in expression.arg_types.items(): + v = expression.args.get(k) + if mandatory and (v is None or (isinstance(v, list) and not v)): + self.raise_error( + f"Required keyword: '{k}' missing for {expression.__class__}" + ) + + if ( + args + and len(args) > len(expression.arg_types) + and not expression.is_var_len_args + ): + self.raise_error( + f"The number of provided arguments ({len(args)}) is greater than " + f"the maximum number of supported arguments ({len(expression.arg_types)})" + ) + + def _find_token(self, token, sql): + line = 1 + col = 1 + index = 0 + + while line < token.line or col < token.col: + if Tokenizer.WHITE_SPACE.get(sql[index]) == TokenType.BREAK: + line += 1 + col = 1 + else: + col += 1 + index += 1 + + return index + + def _get_token(self, index): + return list_get(self._tokens, index) + + def _advance(self, times=1): + self._index += times + self._curr = self._get_token(self._index) + self._next = self._get_token(self._index + 1) + self._prev = self._get_token(self._index - 1) if self._index > 0 else None + + def _retreat(self, index): + self._advance(index - self._index) + + def _parse_statement(self): + if self._curr is None: + return None + + if self._match_set(self.STATEMENT_PARSERS): + return self.STATEMENT_PARSERS[self._prev.token_type](self) + + if self._match_set(Tokenizer.COMMANDS): + return self.expression( + exp.Command, + this=self._prev.text, + expression=self._parse_string(), + ) + + expression = self._parse_expression() + expression = ( + self._parse_set_operations(expression) + if expression + else self._parse_select() + ) + self._parse_query_modifiers(expression) + return expression + + def _parse_drop(self): + if self._match(TokenType.TABLE): + kind = "TABLE" + elif self._match(TokenType.VIEW): + kind = "VIEW" + else: + self.raise_error("Expected TABLE or View") + + return self.expression( + exp.Drop, + exists=self._parse_exists(), + this=self._parse_table(schema=True), + kind=kind, + ) + + def _parse_exists(self, not_=False): + return ( + self._match(TokenType.IF) + and (not not_ or self._match(TokenType.NOT)) + and self._match(TokenType.EXISTS) + ) + + def _parse_create(self): + replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) + temporary = self._match(TokenType.TEMPORARY) + unique = self._match(TokenType.UNIQUE) + + create_token = self._match_set(self.CREATABLES) and self._prev + + if not create_token: + self.raise_error("Expected TABLE, VIEW, INDEX, or FUNCTION") + + exists = self._parse_exists(not_=True) + this = None + expression = None + properties = None + + if create_token.token_type == TokenType.FUNCTION: + this = self._parse_var() + if self._match(TokenType.ALIAS): + expression = self._parse_string() + elif create_token.token_type == TokenType.INDEX: + this = self._parse_index() + elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW): + this = self._parse_table(schema=True) + properties = self._parse_properties( + 
this if isinstance(this, exp.Schema) else None + ) + if self._match(TokenType.ALIAS): + expression = self._parse_select() + + return self.expression( + exp.Create, + this=this, + kind=create_token.text, + expression=expression, + exists=exists, + properties=properties, + temporary=temporary, + replace=replace, + unique=unique, + ) + + def _parse_property(self, schema): + if self._match_set(self.PROPERTY_PARSERS): + return self.PROPERTY_PARSERS[self._prev.token_type](self) + if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): + return self._parse_character_set(True) + + if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False): + key = self._parse_var().this + self._match(TokenType.EQ) + + if key.upper() == "PARTITIONED_BY": + expression = exp.PartitionedByProperty + value = self._parse_schema() or self._parse_bracket(self._parse_field()) + + if schema and not isinstance(value, exp.Schema): + columns = {v.name.upper() for v in value.expressions} + partitions = [ + expression + for expression in schema.expressions + if expression.this.name.upper() in columns + ] + schema.set( + "expressions", + [e for e in schema.expressions if e not in partitions], + ) + value = self.expression(exp.Schema, expressions=partitions) + else: + value = self._parse_column() + expression = exp.AnonymousProperty + + return self.expression( + expression, + this=exp.Literal.string(key), + value=value, + ) + return None + + def _parse_stored(self): + self._match(TokenType.ALIAS) + self._match(TokenType.EQ) + return self.expression( + exp.FileFormatProperty, + this=exp.Literal.string("FORMAT"), + value=exp.Literal.string(self._parse_var().name), + ) + + def _parse_format(self): + self._match(TokenType.EQ) + return self.expression( + exp.FileFormatProperty, + this=exp.Literal.string("FORMAT"), + value=self._parse_string() or self._parse_var(), + ) + + def _parse_engine(self): + self._match(TokenType.EQ) + return self.expression( + exp.EngineProperty, + this=exp.Literal.string("ENGINE"), + value=self._parse_var_or_string(), + ) + + def _parse_auto_increment(self): + self._match(TokenType.EQ) + return self.expression( + exp.AutoIncrementProperty, + this=exp.Literal.string("AUTO_INCREMENT"), + value=self._parse_var() or self._parse_number(), + ) + + def _parse_collate(self): + self._match(TokenType.EQ) + return self.expression( + exp.CollateProperty, + this=exp.Literal.string("COLLATE"), + value=self._parse_var_or_string(), + ) + + def _parse_schema_comment(self): + self._match(TokenType.EQ) + return self.expression( + exp.SchemaCommentProperty, + this=exp.Literal.string("COMMENT"), + value=self._parse_string(), + ) + + def _parse_character_set(self, default=False): + self._match(TokenType.EQ) + return self.expression( + exp.CharacterSetProperty, + this=exp.Literal.string("CHARACTER_SET"), + value=self._parse_var_or_string(), + default=default, + ) + + def _parse_table_format(self): + self._match(TokenType.EQ) + return self.expression( + exp.TableFormatProperty, + this=exp.Literal.string("TABLE_FORMAT"), + value=self._parse_var_or_string(), + ) + + def _parse_properties(self, schema=None): + """ + Schema is included since if the table schema is defined and we later get a partition by expression + then we will define those columns in the partition by section and not in with the rest of the + columns + """ + properties = [] + + while True: + if self._match(TokenType.WITH): + self._match_l_paren() + properties.extend(self._parse_csv(lambda: self._parse_property(schema))) + self._match_r_paren() + elif 
self._match(TokenType.PROPERTIES): + self._match_l_paren() + properties.extend( + self._parse_csv( + lambda: self.expression( + exp.AnonymousProperty, + this=self._parse_string(), + value=self._match(TokenType.EQ) and self._parse_string(), + ) + ) + ) + self._match_r_paren() + else: + identified_property = self._parse_property(schema) + if not identified_property: + break + properties.append(identified_property) + if properties: + return self.expression(exp.Properties, expressions=properties) + return None + + def _parse_insert(self): + overwrite = self._match(TokenType.OVERWRITE) + self._match(TokenType.INTO) + self._match(TokenType.TABLE) + return self.expression( + exp.Insert, + this=self._parse_table(schema=True), + exists=self._parse_exists(), + partition=self._parse_partition(), + expression=self._parse_select(), + overwrite=overwrite, + ) + + def _parse_delete(self): + self._match(TokenType.FROM) + + return self.expression( + exp.Delete, + this=self._parse_table(schema=True), + where=self._parse_where(), + ) + + def _parse_update(self): + return self.expression( + exp.Update, + **{ + "this": self._parse_table(schema=True), + "expressions": self._match(TokenType.SET) + and self._parse_csv(self._parse_equality), + "from": self._parse_from(), + "where": self._parse_where(), + }, + ) + + def _parse_uncache(self): + if not self._match(TokenType.TABLE): + self.raise_error("Expecting TABLE after UNCACHE") + return self.expression( + exp.Uncache, + exists=self._parse_exists(), + this=self._parse_table(schema=True), + ) + + def _parse_cache(self): + lazy = self._match(TokenType.LAZY) + self._match(TokenType.TABLE) + table = self._parse_table(schema=True) + options = [] + + if self._match(TokenType.OPTIONS): + self._match_l_paren() + k = self._parse_string() + self._match(TokenType.EQ) + v = self._parse_string() + options = [k, v] + self._match_r_paren() + + self._match(TokenType.ALIAS) + return self.expression( + exp.Cache, + this=table, + lazy=lazy, + options=options, + expression=self._parse_select(), + ) + + def _parse_partition(self): + if not self._match(TokenType.PARTITION): + return None + + def parse_values(): + k = self._parse_var() + if self._match(TokenType.EQ): + v = self._parse_string() + return (k, v) + return (k, None) + + self._match_l_paren() + values = self._parse_csv(parse_values) + self._match_r_paren() + + return self.expression( + exp.Partition, + this=values, + ) + + def _parse_value(self): + self._match_l_paren() + expressions = self._parse_csv(self._parse_conjunction) + self._match_r_paren() + return self.expression(exp.Tuple, expressions=expressions) + + def _parse_select(self, table=None): + index = self._index + + if self._match(TokenType.SELECT): + hint = self._parse_hint() + all_ = self._match(TokenType.ALL) + distinct = self._match(TokenType.DISTINCT) + + if distinct: + distinct = self.expression( + exp.Distinct, + on=self._parse_value() if self._match(TokenType.ON) else None, + ) + + if all_ and distinct: + self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") + + limit = self._parse_limit(top=True) + expressions = self._parse_csv( + lambda: self._parse_annotation(self._parse_expression()) + ) + + this = self.expression( + exp.Select, + hint=hint, + distinct=distinct, + expressions=expressions, + limit=limit, + ) + from_ = self._parse_from() + if from_: + this.set("from", from_) + self._parse_query_modifiers(this) + elif self._match(TokenType.WITH): + recursive = self._match(TokenType.RECURSIVE) + + expressions = [] + + while True: + 
expressions.append(self._parse_cte()) + + if not self._match(TokenType.COMMA): + break + + cte = self.expression( + exp.With, + expressions=expressions, + recursive=recursive, + ) + this = self._parse_statement() + + if not this: + self.raise_error("Failed to parse any statement following CTE") + return cte + + if "with" in this.arg_types: + this.set( + "with", + self.expression( + exp.With, + expressions=expressions, + recursive=recursive, + ), + ) + else: + self.raise_error(f"{this.key} does not support CTE") + elif self._match(TokenType.L_PAREN): + this = self._parse_table() if table else self._parse_select() + + if this: + self._parse_query_modifiers(this) + self._match_r_paren() + this = self._parse_subquery(this) + else: + self._retreat(index) + elif self._match(TokenType.VALUES): + this = self.expression( + exp.Values, expressions=self._parse_csv(self._parse_value) + ) + alias = self._parse_table_alias() + if alias: + this = self.expression(exp.Subquery, this=this, alias=alias) + else: + this = None + + return self._parse_set_operations(this) if this else None + + def _parse_cte(self): + alias = self._parse_table_alias() + if not alias or not alias.this: + self.raise_error("Expected CTE to have alias") + + if not self._match(TokenType.ALIAS): + self.raise_error("Expected AS in CTE") + + self._match_l_paren() + expression = self._parse_statement() + self._match_r_paren() + + return self.expression( + exp.CTE, + this=expression, + alias=alias, + ) + + def _parse_table_alias(self): + any_token = self._match(TokenType.ALIAS) + alias = self._parse_id_var(any_token) + columns = None + + if self._match(TokenType.L_PAREN): + columns = self._parse_csv(lambda: self._parse_id_var(any_token)) + self._match_r_paren() + + if not alias and not columns: + return None + + return self.expression( + exp.TableAlias, + this=alias, + columns=columns, + ) + + def _parse_subquery(self, this): + return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias()) + + def _parse_query_modifiers(self, this): + if not isinstance(this, (exp.Subquery, exp.Subqueryable)): + return + + for key, parser in self.QUERY_MODIFIER_PARSERS.items(): + expression = parser(self) + + if expression: + this.set(key, expression) + + def _parse_annotation(self, expression): + if self._match(TokenType.ANNOTATION): + return self.expression( + exp.Annotation, this=self._prev.text, expression=expression + ) + + return expression + + def _parse_hint(self): + if self._match(TokenType.HINT): + hints = self._parse_csv(self._parse_function) + if not self._match(TokenType.HINT): + self.raise_error("Expected */ after HINT") + return self.expression(exp.Hint, expressions=hints) + return None + + def _parse_from(self): + if not self._match(TokenType.FROM): + return None + + return self.expression(exp.From, expressions=self._parse_csv(self._parse_table)) + + def _parse_laterals(self): + return self._parse_all(self._parse_lateral) + + def _parse_lateral(self): + if not self._match(TokenType.LATERAL): + return None + + if not self._match(TokenType.VIEW): + self.raise_error("Expected VIEW after LATERAL") + + outer = self._match(TokenType.OUTER) + + return self.expression( + exp.Lateral, + this=self._parse_function(), + outer=outer, + alias=self.expression( + exp.TableAlias, + this=self._parse_id_var(any_token=False), + columns=( + self._parse_csv(self._parse_id_var) + if self._match(TokenType.ALIAS) + else None + ), + ), + ) + + def _parse_joins(self): + return self._parse_all(self._parse_join) + + def _parse_join_side_and_kind(self): 
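+        # e.g. "LEFT OUTER JOIN" yields (the LEFT token, the OUTER token);
+        # either element is falsy when that part of the join is absent.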
+ return ( + self._match_set(self.JOIN_SIDES) and self._prev, + self._match_set(self.JOIN_KINDS) and self._prev, + ) + + def _parse_join(self): + side, kind = self._parse_join_side_and_kind() + + if not self._match(TokenType.JOIN): + return None + + kwargs = {"this": self._parse_table()} + + if side: + kwargs["side"] = side.text + if kind: + kwargs["kind"] = kind.text + + if self._match(TokenType.ON): + kwargs["on"] = self._parse_conjunction() + elif self._match(TokenType.USING): + kwargs["using"] = self._parse_wrapped_id_vars() + + return self.expression(exp.Join, **kwargs) + + def _parse_index(self): + index = self._parse_id_var() + self._match(TokenType.ON) + self._match(TokenType.TABLE) # hive + return self.expression( + exp.Index, + this=index, + table=self.expression(exp.Table, this=self._parse_id_var()), + columns=self._parse_expression(), + ) + + def _parse_table(self, schema=False): + unnest = self._parse_unnest() + + if unnest: + return unnest + + subquery = self._parse_select(table=True) + + if subquery: + return subquery + + catalog = None + db = None + table = (not schema and self._parse_function()) or self._parse_id_var(False) + + while self._match(TokenType.DOT): + catalog = db + db = table + table = self._parse_id_var() + + if not table: + self.raise_error("Expected table name") + + this = self.expression(exp.Table, this=table, db=db, catalog=catalog) + + if schema: + return self._parse_schema(this=this) + + if self.alias_post_tablesample: + table_sample = self._parse_table_sample() + + alias = self._parse_table_alias() + + if alias: + this = self.expression(exp.Alias, this=this, alias=alias) + + if not self.alias_post_tablesample: + table_sample = self._parse_table_sample() + + if table_sample: + table_sample.set("this", this) + this = table_sample + + return this + + def _parse_unnest(self): + if not self._match(TokenType.UNNEST): + return None + + self._match_l_paren() + expressions = self._parse_csv(self._parse_column) + self._match_r_paren() + + ordinality = bool( + self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY) + ) + + alias = self._parse_table_alias() + + if alias and self.unnest_column_only: + if alias.args.get("columns"): + self.raise_error("Unexpected extra column alias in unnest.") + alias.set("columns", [alias.this]) + alias.set("this", None) + + return self.expression( + exp.Unnest, + expressions=expressions, + ordinality=ordinality, + alias=alias, + ) + + def _parse_table_sample(self): + if not self._match(TokenType.TABLE_SAMPLE): + return None + + method = self._parse_var() + bucket_numerator = None + bucket_denominator = None + bucket_field = None + percent = None + rows = None + size = None + + self._match_l_paren() + + if self._match(TokenType.BUCKET): + bucket_numerator = self._parse_number() + self._match(TokenType.OUT_OF) + bucket_denominator = bucket_denominator = self._parse_number() + self._match(TokenType.ON) + bucket_field = self._parse_field() + else: + num = self._parse_number() + + if self._match(TokenType.PERCENT): + percent = num + elif self._match(TokenType.ROWS): + rows = num + else: + size = num + + self._match_r_paren() + + return self.expression( + exp.TableSample, + method=method, + bucket_numerator=bucket_numerator, + bucket_denominator=bucket_denominator, + bucket_field=bucket_field, + percent=percent, + rows=rows, + size=size, + ) + + def _parse_where(self): + if not self._match(TokenType.WHERE): + return None + return self.expression(exp.Where, this=self._parse_conjunction()) + + def _parse_group(self): + if not 
self._match(TokenType.GROUP_BY): + return None + return self.expression( + exp.Group, + expressions=self._parse_csv(self._parse_conjunction), + grouping_sets=self._parse_grouping_sets(), + cube=self._match(TokenType.CUBE) and self._parse_wrapped_id_vars(), + rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(), + ) + + def _parse_grouping_sets(self): + if not self._match(TokenType.GROUPING_SETS): + return None + + self._match_l_paren() + grouping_sets = self._parse_csv(self._parse_grouping_set) + self._match_r_paren() + return grouping_sets + + def _parse_grouping_set(self): + if self._match(TokenType.L_PAREN): + grouping_set = self._parse_csv(self._parse_id_var) + self._match_r_paren() + return self.expression(exp.Tuple, expressions=grouping_set) + return self._parse_id_var() + + def _parse_having(self): + if not self._match(TokenType.HAVING): + return None + return self.expression(exp.Having, this=self._parse_conjunction()) + + def _parse_qualify(self): + if not self._match(TokenType.QUALIFY): + return None + return self.expression(exp.Qualify, this=self._parse_conjunction()) + + def _parse_order(self, this=None): + if not self._match(TokenType.ORDER_BY): + return this + + return self.expression( + exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) + ) + + def _parse_sort(self, token_type, exp_class): + if not self._match(token_type): + return None + + return self.expression( + exp_class, expressions=self._parse_csv(self._parse_ordered) + ) + + def _parse_ordered(self): + this = self._parse_conjunction() + self._match(TokenType.ASC) + is_desc = self._match(TokenType.DESC) + is_nulls_first = self._match(TokenType.NULLS_FIRST) + is_nulls_last = self._match(TokenType.NULLS_LAST) + desc = is_desc or False + asc = not desc + nulls_first = is_nulls_first or False + explicitly_null_ordered = is_nulls_first or is_nulls_last + if ( + not explicitly_null_ordered + and ( + (asc and self.null_ordering == "nulls_are_small") + or (desc and self.null_ordering != "nulls_are_small") + ) + and self.null_ordering != "nulls_are_last" + ): + nulls_first = True + + return self.expression( + exp.Ordered, this=this, desc=desc, nulls_first=nulls_first + ) + + def _parse_limit(self, this=None, top=False): + if self._match(TokenType.TOP if top else TokenType.LIMIT): + return self.expression( + exp.Limit, this=this, expression=self._parse_number() + ) + if self._match(TokenType.FETCH): + direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) + direction = self._prev.text if direction else "FIRST" + count = self._parse_number() + self._match_set((TokenType.ROW, TokenType.ROWS)) + self._match(TokenType.ONLY) + return self.expression(exp.Fetch, direction=direction, count=count) + return this + + def _parse_offset(self, this=None): + if not self._match(TokenType.OFFSET): + return this + count = self._parse_number() + self._match_set((TokenType.ROW, TokenType.ROWS)) + return self.expression(exp.Offset, this=this, expression=count) + + def _parse_set_operations(self, this): + if not self._match_set(self.SET_OPERATIONS): + return this + + token_type = self._prev.token_type + + if token_type == TokenType.UNION: + expression = exp.Union + elif token_type == TokenType.EXCEPT: + expression = exp.Except + else: + expression = exp.Intersect + + return self.expression( + expression, + this=this, + distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), + expression=self._parse_select(), + ) + + def _parse_expression(self): + return 
self._parse_alias(self._parse_conjunction()) + + def _parse_conjunction(self): + return self._parse_tokens(self._parse_equality, self.CONJUNCTION) + + def _parse_equality(self): + return self._parse_tokens(self._parse_comparison, self.EQUALITY) + + def _parse_comparison(self): + return self._parse_tokens(self._parse_range, self.COMPARISON) + + def _parse_range(self): + this = self._parse_bitwise() + negate = self._match(TokenType.NOT) + + if self._match_set(self.RANGE_PARSERS): + this = self.RANGE_PARSERS[self._prev.token_type](self, this) + + if negate: + this = self.expression(exp.Not, this=this) + + return this + + def _parse_is(self, this): + negate = self._match(TokenType.NOT) + this = self.expression( + exp.Is, + this=this, + expression=self._parse_null() or self._parse_boolean(), + ) + return self.expression(exp.Not, this=this) if negate else this + + def _parse_in(self, this): + unnest = self._parse_unnest() + if unnest: + this = self.expression(exp.In, this=this, unnest=unnest) + else: + self._match_l_paren() + expressions = self._parse_csv( + lambda: self._parse_select() or self._parse_expression() + ) + + if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): + this = self.expression(exp.In, this=this, query=expressions[0]) + else: + this = self.expression(exp.In, this=this, expressions=expressions) + + self._match_r_paren() + return this + + def _parse_between(self, this): + low = self._parse_bitwise() + self._match(TokenType.AND) + high = self._parse_bitwise() + return self.expression(exp.Between, this=this, low=low, high=high) + + def _parse_escape(self, this): + if not self._match(TokenType.ESCAPE): + return this + return self.expression(exp.Escape, this=this, expression=self._parse_string()) + + def _parse_bitwise(self): + this = self._parse_term() + + while True: + if self._match_set(self.BITWISE): + this = self.expression( + self.BITWISE[self._prev.token_type], + this=this, + expression=self._parse_term(), + ) + elif self._match_pair(TokenType.LT, TokenType.LT): + this = self.expression( + exp.BitwiseLeftShift, this=this, expression=self._parse_term() + ) + elif self._match_pair(TokenType.GT, TokenType.GT): + this = self.expression( + exp.BitwiseRightShift, this=this, expression=self._parse_term() + ) + else: + break + + return this + + def _parse_term(self): + return self._parse_tokens(self._parse_factor, self.TERM) + + def _parse_factor(self): + return self._parse_tokens(self._parse_unary, self.FACTOR) + + def _parse_unary(self): + if self._match(TokenType.NOT): + return self.expression(exp.Not, this=self._parse_equality()) + if self._match(TokenType.TILDA): + return self.expression(exp.BitwiseNot, this=self._parse_unary()) + if self._match(TokenType.DASH): + return self.expression(exp.Neg, this=self._parse_unary()) + return self._parse_at_time_zone(self._parse_type()) + + def _parse_type(self): + if self._match(TokenType.INTERVAL): + return self.expression( + exp.Interval, + this=self._parse_term(), + unit=self._parse_var(), + ) + + index = self._index + type_token = self._parse_types() + this = self._parse_column() + + if type_token: + if this: + return self.expression(exp.Cast, this=this, to=type_token) + if not type_token.args.get("expressions"): + self._retreat(index) + return self._parse_column() + return type_token + + while self._match(TokenType.DCOLON): + type_token = self._parse_types() + if not type_token: + self.raise_error("Expected type") + this = self.expression(exp.Cast, this=this, to=type_token) + + return this + + def 
_parse_types(self): + index = self._index + + if not self._match_set(self.TYPE_TOKENS): + return None + + type_token = self._prev.token_type + nested = type_token in self.NESTED_TYPE_TOKENS + is_struct = type_token == TokenType.STRUCT + expressions = None + + if self._match(TokenType.L_BRACKET): + self._retreat(index) + return None + + if self._match(TokenType.L_PAREN): + if is_struct: + expressions = self._parse_csv(self._parse_struct_kwargs) + elif nested: + expressions = self._parse_csv(self._parse_types) + else: + expressions = self._parse_csv(self._parse_number) + + if not expressions: + self._retreat(index) + return None + + self._match_r_paren() + + if nested and self._match(TokenType.LT): + if is_struct: + expressions = self._parse_csv(self._parse_struct_kwargs) + else: + expressions = self._parse_csv(self._parse_types) + + if not self._match(TokenType.GT): + self.raise_error("Expecting >") + + if type_token in self.TIMESTAMPS: + tz = self._match(TokenType.WITH_TIME_ZONE) + self._match(TokenType.WITHOUT_TIME_ZONE) + if tz: + return exp.DataType( + this=exp.DataType.Type.TIMESTAMPTZ, + expressions=expressions, + ) + return exp.DataType( + this=exp.DataType.Type.TIMESTAMP, + expressions=expressions, + ) + + return exp.DataType( + this=exp.DataType.Type[type_token.value.upper()], + expressions=expressions, + nested=nested, + ) + + def _parse_struct_kwargs(self): + this = self._parse_id_var() + self._match(TokenType.COLON) + data_type = self._parse_types() + if not data_type: + return None + return self.expression(exp.StructKwarg, this=this, expression=data_type) + + def _parse_at_time_zone(self, this): + if not self._match(TokenType.AT_TIME_ZONE): + return this + + return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) + + def _parse_column(self): + this = self._parse_field() + if isinstance(this, exp.Identifier): + this = self.expression(exp.Column, this=this) + elif not this: + return self._parse_bracket(this) + this = self._parse_bracket(this) + + while self._match_set(self.COLUMN_OPERATORS): + op = self.COLUMN_OPERATORS.get(self._prev.token_type) + field = self._parse_star() or self._parse_function() or self._parse_id_var() + + if isinstance(field, exp.Func): + # bigquery allows function calls like x.y.count(...) + # SAFE.SUBSTR(...) 
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-reference#function_call_rules + this = self._replace_columns_with_dots(this) + + if op: + this = op(self, this, exp.Literal.string(field.name)) + elif isinstance(this, exp.Column) and not this.table: + this = self.expression(exp.Column, this=field, table=this.this) + else: + this = self.expression(exp.Dot, this=this, expression=field) + this = self._parse_bracket(this) + + return this + + def _parse_primary(self): + if self._match_set(self.PRIMARY_PARSERS): + return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) + + if self._match(TokenType.L_PAREN): + query = self._parse_select() + + if query: + expressions = [query] + else: + expressions = self._parse_csv( + lambda: self._parse_alias(self._parse_conjunction(), explicit=True) + ) + + this = list_get(expressions, 0) + self._parse_query_modifiers(this) + self._match_r_paren() + + if isinstance(this, exp.Subqueryable): + return self._parse_subquery(this) + if len(expressions) > 1: + return self.expression(exp.Tuple, expressions=expressions) + return self.expression(exp.Paren, this=this) + + return None + + def _parse_field(self, any_token=False): + return ( + self._parse_primary() + or self._parse_function() + or self._parse_id_var(any_token) + ) + + def _parse_function(self): + if not self._curr: + return None + + token_type = self._curr.token_type + + if self._match_set(self.NO_PAREN_FUNCTION_PARSERS): + return self.NO_PAREN_FUNCTION_PARSERS[token_type](self) + + if not self._next or self._next.token_type != TokenType.L_PAREN: + if token_type in self.NO_PAREN_FUNCTIONS: + return self.expression( + self._advance() or self.NO_PAREN_FUNCTIONS[token_type] + ) + return None + + if token_type not in self.FUNC_TOKENS: + return None + + if self._match_set(self.FUNCTION_PARSERS): + self._advance() + this = self.FUNCTION_PARSERS[token_type](self, token_type) + else: + subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) + this = self._curr.text + self._advance(2) + + if subquery_predicate and self._curr.token_type in ( + TokenType.SELECT, + TokenType.WITH, + ): + this = self.expression(subquery_predicate, this=self._parse_select()) + self._match_r_paren() + return this + + function = self.FUNCTIONS.get(this.upper()) + args = self._parse_csv(self._parse_lambda) + + if function: + this = function(args) + self.validate_expression(this, args) + else: + this = self.expression(exp.Anonymous, this=this, expressions=args) + self._match_r_paren() + return self._parse_window(this) + + def _parse_lambda(self): + index = self._index + + if self._match(TokenType.L_PAREN): + expressions = self._parse_csv(self._parse_id_var) + self._match(TokenType.R_PAREN) + else: + expressions = [self._parse_id_var()] + + if not self._match(TokenType.ARROW): + self._retreat(index) + + distinct = self._match(TokenType.DISTINCT) + this = self._parse_conjunction() + + if distinct: + this = self.expression(exp.Distinct, this=this) + + if self._match(TokenType.IGNORE_NULLS): + this = self.expression(exp.IgnoreNulls, this=this) + else: + self._match(TokenType.RESPECT_NULLS) + + return self._parse_alias(self._parse_limit(self._parse_order(this))) + + return self.expression( + exp.Lambda, + this=self._parse_conjunction(), + expressions=expressions, + ) + + def _parse_schema(self, this=None): + index = self._index + if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT): + self._retreat(index) + return this + + args = self._parse_csv( + lambda: self._parse_constraint() + or 
self._parse_column_def(self._parse_field()) + ) + self._match_r_paren() + return self.expression(exp.Schema, this=this, expressions=args) + + def _parse_column_def(self, this): + kind = self._parse_types() + + if not kind: + return this + + constraints = [] + while True: + constraint = self._parse_column_constraint() + if not constraint: + break + constraints.append(constraint) + + return self.expression( + exp.ColumnDef, this=this, kind=kind, constraints=constraints + ) + + def _parse_column_constraint(self): + kind = None + this = None + + if self._match(TokenType.CONSTRAINT): + this = self._parse_id_var() + + if self._match(TokenType.AUTO_INCREMENT): + kind = exp.AutoIncrementColumnConstraint() + elif self._match(TokenType.CHECK): + self._match_l_paren() + kind = self.expression( + exp.CheckColumnConstraint, this=self._parse_conjunction() + ) + self._match_r_paren() + elif self._match(TokenType.COLLATE): + kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) + elif self._match(TokenType.DEFAULT): + kind = self.expression( + exp.DefaultColumnConstraint, this=self._parse_field() + ) + elif self._match(TokenType.NOT) and self._match(TokenType.NULL): + kind = exp.NotNullColumnConstraint() + elif self._match(TokenType.SCHEMA_COMMENT): + kind = self.expression( + exp.CommentColumnConstraint, this=self._parse_string() + ) + elif self._match(TokenType.PRIMARY_KEY): + kind = exp.PrimaryKeyColumnConstraint() + elif self._match(TokenType.UNIQUE): + kind = exp.UniqueColumnConstraint() + + if kind is None: + return None + + return self.expression(exp.ColumnConstraint, this=this, kind=kind) + + def _parse_constraint(self): + if not self._match(TokenType.CONSTRAINT): + return self._parse_unnamed_constraint() + + this = self._parse_id_var() + expressions = [] + + while True: + constraint = self._parse_unnamed_constraint() or self._parse_function() + if not constraint: + break + expressions.append(constraint) + + return self.expression(exp.Constraint, this=this, expressions=expressions) + + def _parse_unnamed_constraint(self): + if not self._match_set(self.CONSTRAINT_PARSERS): + return None + + return self.CONSTRAINT_PARSERS[self._prev.token_type](self) + + def _parse_check(self): + self._match(TokenType.CHECK) + self._match_l_paren() + expression = self._parse_conjunction() + self._match_r_paren() + + return self.expression(exp.Check, this=expression) + + def _parse_unique(self): + self._match(TokenType.UNIQUE) + columns = self._parse_wrapped_id_vars() + + return self.expression(exp.Unique, expressions=columns) + + def _parse_foreign_key(self): + self._match(TokenType.FOREIGN_KEY) + + expressions = self._parse_wrapped_id_vars() + reference = self._match(TokenType.REFERENCES) and self.expression( + exp.Reference, + this=self._parse_id_var(), + expressions=self._parse_wrapped_id_vars(), + ) + options = {} + + while self._match(TokenType.ON): + if not self._match_set((TokenType.DELETE, TokenType.UPDATE)): + self.raise_error("Expected DELETE or UPDATE") + kind = self._prev.text.lower() + + if self._match(TokenType.NO_ACTION): + action = "NO ACTION" + elif self._match(TokenType.SET): + self._match_set((TokenType.NULL, TokenType.DEFAULT)) + action = "SET " + self._prev.text.upper() + else: + self._advance() + action = self._prev.text.upper() + options[kind] = action + + return self.expression( + exp.ForeignKey, + expressions=expressions, + reference=reference, + **options, + ) + + def _parse_bracket(self, this): + if not self._match(TokenType.L_BRACKET): + return this + + expressions = 
self._parse_csv(self._parse_conjunction) + + if not this or this.name.upper() == "ARRAY": + this = self.expression(exp.Array, expressions=expressions) + else: + expressions = apply_index_offset(expressions, -self.index_offset) + this = self.expression(exp.Bracket, this=this, expressions=expressions) + + if not self._match(TokenType.R_BRACKET): + self.raise_error("Expected ]") + + return self._parse_bracket(this) + + def _parse_case(self): + ifs = [] + default = None + + expression = self._parse_conjunction() + + while self._match(TokenType.WHEN): + this = self._parse_conjunction() + self._match(TokenType.THEN) + then = self._parse_conjunction() + ifs.append(self.expression(exp.If, this=this, true=then)) + + if self._match(TokenType.ELSE): + default = self._parse_conjunction() + + if not self._match(TokenType.END): + self.raise_error("Expected END after CASE", self._prev) + + return self._parse_window( + self.expression(exp.Case, this=expression, ifs=ifs, default=default) + ) + + def _parse_if(self): + if self._match(TokenType.L_PAREN): + args = self._parse_csv(self._parse_conjunction) + this = exp.If.from_arg_list(args) + self.validate_expression(this, args) + self._match_r_paren() + else: + condition = self._parse_conjunction() + self._match(TokenType.THEN) + true = self._parse_conjunction() + false = self._parse_conjunction() if self._match(TokenType.ELSE) else None + self._match(TokenType.END) + this = self.expression(exp.If, this=condition, true=true, false=false) + return self._parse_window(this) + + def _parse_extract(self): + this = self._parse_var() or self._parse_type() + + if not self._match(TokenType.FROM): + self.raise_error("Expected FROM after EXTRACT", self._prev) + + return self.expression(exp.Extract, this=this, expression=self._parse_type()) + + def _parse_cast(self, strict): + this = self._parse_conjunction() + + if not self._match(TokenType.ALIAS): + self.raise_error("Expected AS after CAST") + + to = self._parse_types() + + if not to: + self.raise_error("Expected TYPE after CAST") + elif to.this == exp.DataType.Type.CHAR: + if self._match(TokenType.CHARACTER_SET): + to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) + + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + + def _parse_convert(self): + this = self._parse_field() + if self._match(TokenType.USING): + to = self.expression(exp.CharacterSet, this=self._parse_var()) + elif self._match(TokenType.COMMA): + to = self._parse_types() + else: + to = None + return self.expression(exp.Cast, this=this, to=to) + + def _parse_window(self, this, alias=False): + if self._match(TokenType.FILTER): + self._match_l_paren() + this = self.expression( + exp.Filter, this=this, expression=self._parse_where() + ) + self._match_r_paren() + + if self._match(TokenType.WITHIN_GROUP): + self._match_l_paren() + this = self.expression( + exp.WithinGroup, + this=this, + expression=self._parse_order(), + ) + self._match_r_paren() + return this + + # bigquery select from window x AS (partition by ...) 
+ if alias: + self._match(TokenType.ALIAS) + elif not self._match(TokenType.OVER): + return this + + if not self._match(TokenType.L_PAREN): + alias = self._parse_id_var(False) + + return self.expression( + exp.Window, + this=this, + alias=alias, + ) + + partition = None + + alias = self._parse_id_var(False) + + if self._match(TokenType.PARTITION_BY): + partition = self._parse_csv(self._parse_conjunction) + + order = self._parse_order() + + spec = None + kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text + + if kind: + self._match(TokenType.BETWEEN) + start = self._parse_window_spec() + self._match(TokenType.AND) + end = self._parse_window_spec() + + spec = self.expression( + exp.WindowSpec, + kind=kind, + start=start["value"], + start_side=start["side"], + end=end["value"], + end_side=end["side"], + ) + + self._match_r_paren() + + return self.expression( + exp.Window, + this=this, + partition_by=partition, + order=order, + spec=spec, + alias=alias, + ) + + def _parse_window_spec(self): + self._match(TokenType.BETWEEN) + + return { + "value": ( + self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) + and self._prev.text + ) + or self._parse_bitwise(), + "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) + and self._prev.text, + } + + def _parse_alias(self, this, explicit=False): + any_token = self._match(TokenType.ALIAS) + + if explicit and not any_token: + return this + + if self._match(TokenType.L_PAREN): + aliases = self.expression( + exp.Aliases, + this=this, + expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), + ) + self._match_r_paren() + return aliases + + alias = self._parse_id_var(any_token) + + if alias: + return self.expression(exp.Alias, this=this, alias=alias) + + return this + + def _parse_id_var(self, any_token=True): + identifier = self._parse_identifier() + + if identifier: + return identifier + + if ( + any_token + and self._curr + and self._curr.token_type not in self.RESERVED_KEYWORDS + ): + return self._advance() or exp.Identifier(this=self._prev.text, quoted=False) + + return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier( + this=self._prev.text, quoted=False + ) + + def _parse_string(self): + if self._match(TokenType.STRING): + return exp.Literal.string(self._prev.text) + return self._parse_placeholder() + + def _parse_number(self): + if self._match(TokenType.NUMBER): + return exp.Literal.number(self._prev.text) + return self._parse_placeholder() + + def _parse_identifier(self): + if self._match(TokenType.IDENTIFIER): + return exp.Identifier(this=self._prev.text, quoted=True) + return self._parse_placeholder() + + def _parse_var(self): + if self._match(TokenType.VAR): + return exp.Var(this=self._prev.text) + return self._parse_placeholder() + + def _parse_var_or_string(self): + return self._parse_var() or self._parse_string() + + def _parse_null(self): + if self._match(TokenType.NULL): + return exp.Null() + return None + + def _parse_boolean(self): + if self._match(TokenType.TRUE): + return exp.Boolean(this=True) + if self._match(TokenType.FALSE): + return exp.Boolean(this=False) + return None + + def _parse_star(self): + if self._match(TokenType.STAR): + return exp.Star( + **{"except": self._parse_except(), "replace": self._parse_replace()} + ) + return None + + def _parse_placeholder(self): + if self._match(TokenType.PLACEHOLDER): + return exp.Placeholder() + return None + + def _parse_except(self): + if not self._match(TokenType.EXCEPT): + return None + + return 
self._parse_wrapped_id_vars() + + def _parse_replace(self): + if not self._match(TokenType.REPLACE): + return None + + self._match_l_paren() + columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression())) + self._match_r_paren() + return columns + + def _parse_csv(self, parse): + parse_result = parse() + items = [parse_result] if parse_result is not None else [] + + while self._match(TokenType.COMMA): + parse_result = parse() + if parse_result is not None: + items.append(parse_result) + + return items + + def _parse_tokens(self, parse, expressions): + this = parse() + + while self._match_set(expressions): + this = self.expression( + expressions[self._prev.token_type], this=this, expression=parse() + ) + + return this + + def _parse_all(self, parse): + return list(iter(parse, None)) + + def _parse_wrapped_id_vars(self): + self._match_l_paren() + expressions = self._parse_csv(self._parse_id_var) + self._match_r_paren() + return expressions + + def _match(self, token_type): + if not self._curr: + return None + + if self._curr.token_type == token_type: + self._advance() + return True + + return None + + def _match_set(self, types): + if not self._curr: + return None + + if self._curr.token_type in types: + self._advance() + return True + + return None + + def _match_pair(self, token_type_a, token_type_b, advance=True): + if not self._curr or not self._next: + return None + + if ( + self._curr.token_type == token_type_a + and self._next.token_type == token_type_b + ): + if advance: + self._advance(2) + return True + + return None + + def _match_l_paren(self): + if not self._match(TokenType.L_PAREN): + self.raise_error("Expecting (") + + def _match_r_paren(self): + if not self._match(TokenType.R_PAREN): + self.raise_error("Expecting )") + + def _replace_columns_with_dots(self, this): + if isinstance(this, exp.Dot): + exp.replace_children(this, self._replace_columns_with_dots) + elif isinstance(this, exp.Column): + exp.replace_children(this, self._replace_columns_with_dots) + table = this.args.get("table") + this = ( + self.expression(exp.Dot, this=table, expression=this.this) + if table + else self.expression(exp.Var, this=this.name) + ) + elif isinstance(this, exp.Identifier): + this = self.expression(exp.Var, this=this.name) + return this diff --git a/sqlglot/planner.py b/sqlglot/planner.py new file mode 100644 index 0000000..2006a75 --- /dev/null +++ b/sqlglot/planner.py @@ -0,0 +1,340 @@ +import itertools +import math + +from sqlglot import alias, exp +from sqlglot.errors import UnsupportedError +from sqlglot.optimizer.simplify import simplify + + +class Plan: + def __init__(self, expression): + self.expression = expression + self.root = Step.from_expression(self.expression) + self._dag = {} + + @property + def dag(self): + if not self._dag: + dag = {} + nodes = {self.root} + + while nodes: + node = nodes.pop() + dag[node] = set() + for dep in node.dependencies: + dag[node].add(dep) + nodes.add(dep) + self._dag = dag + + return self._dag + + @property + def leaves(self): + return (node for node, deps in self.dag.items() if not deps) + + +class Step: + @classmethod + def from_expression(cls, expression, ctes=None): + """ + Build a DAG of Steps from a SQL expression. + + Given an expression like: + + SELECT x.a, SUM(x.b) + FROM x + JOIN y + ON x.a = y.a + GROUP BY x.a + + it is transformed into a DAG of the form: + + Aggregate(x.a, SUM(x.b)) + Join(y) + Scan(x) + Scan(y) + + This can then be executed more easily by an engine.
+ """ + ctes = ctes or {} + with_ = expression.args.get("with") + + # CTEs break the mold of scope and introduce themselves to all in the context. + if with_: + ctes = ctes.copy() + for cte in with_.expressions: + step = Step.from_expression(cte.this, ctes) + step.name = cte.alias + ctes[step.name] = step + + from_ = expression.args.get("from") + + if from_: + from_ = from_.expressions + if len(from_) > 1: + raise UnsupportedError( + "Multi-from statements are unsupported. Run it through the optimizer" + ) + + step = Scan.from_expression(from_[0], ctes) + else: + raise UnsupportedError("Static selects are unsupported.") + + joins = expression.args.get("joins") + + if joins: + join = Join.from_joins(joins, ctes) + join.name = step.name + join.add_dependency(step) + step = join + + projections = [] # final selects in this chain of steps representing a select + operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1) + aggregations = [] + sequence = itertools.count() + + for e in expression.expressions: + aggregation = e.find(exp.AggFunc) + + if aggregation: + projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) + aggregations.append(e) + for operand in aggregation.unnest_operands(): + if isinstance(operand, exp.Column): + continue + if operand not in operands: + operands[operand] = f"_a_{next(sequence)}" + operand.replace( + exp.column(operands[operand], step.name, quoted=True) + ) + else: + projections.append(e) + + where = expression.args.get("where") + + if where: + step.condition = where.this + + group = expression.args.get("group") + + if group: + aggregate = Aggregate() + aggregate.source = step.name + aggregate.name = step.name + aggregate.operands = tuple( + alias(operand, alias_) for operand, alias_ in operands.items() + ) + aggregate.aggregations = aggregations + aggregate.group = [ + exp.column(e.alias_or_name, step.name, quoted=True) + for e in group.expressions + ] + aggregate.add_dependency(step) + step = aggregate + + having = expression.args.get("having") + + if having: + step.condition = having.this + + order = expression.args.get("order") + + if order: + sort = Sort() + sort.name = step.name + sort.key = order.expressions + sort.add_dependency(step) + step = sort + for k in sort.key + projections: + for column in k.find_all(exp.Column): + column.set("table", exp.to_identifier(step.name, quoted=True)) + + step.projections = projections + + limit = expression.args.get("limit") + + if limit: + step.limit = int(limit.text("expression")) + + return step + + def __init__(self): + self.name = None + self.dependencies = set() + self.dependents = set() + self.projections = [] + self.limit = math.inf + self.condition = None + + def add_dependency(self, dependency): + self.dependencies.add(dependency) + dependency.dependents.add(self) + + def __repr__(self): + return self.to_s() + + def to_s(self, level=0): + indent = " " * level + nested = f"{indent} " + + context = self._to_s(f"{nested} ") + + if context: + context = [f"{nested}Context:"] + context + + lines = [ + f"{indent}- {self.__class__.__name__}: {self.name}", + *context, + f"{nested}Projections:", + ] + + for expression in self.projections: + lines.append(f"{nested} - {expression.sql()}") + + if self.condition: + lines.append(f"{nested}Condition: {self.condition.sql()}") + + if self.dependencies: + lines.append(f"{nested}Dependencies:") + for dependency in self.dependencies: + lines.append(" " + dependency.to_s(level + 1)) + + return "\n".join(lines) + + def _to_s(self, _indent): + 
return [] + + +class Scan(Step): + @classmethod + def from_expression(cls, expression, ctes=None): + table = expression.this + alias_ = expression.alias + + if not alias_: + raise UnsupportedError( + "Tables/Subqueries must be aliased. Run it through the optimizer" + ) + + if isinstance(expression, exp.Subquery): + step = Step.from_expression(table, ctes) + step.name = alias_ + return step + + step = Scan() + step.name = alias_ + step.source = expression + if table.name in ctes: + step.add_dependency(ctes[table.name]) + + return step + + def __init__(self): + super().__init__() + self.source = None + + def _to_s(self, indent): + return [f"{indent}Source: {self.source.sql()}"] + + +class Write(Step): + pass + + +class Join(Step): + @classmethod + def from_joins(cls, joins, ctes=None): + step = Join() + + for join in joins: + name = join.this.alias + on = join.args.get("on") or exp.TRUE + source_key = [] + join_key = [] + + # find the join keys + # SELECT + # FROM x + # JOIN y + # ON x.a = y.b AND y.b > 1 + # + # should pull y.b as the join key and x.a as the source key + for condition in on.flatten() if isinstance(on, exp.And) else [on]: + if isinstance(condition, exp.EQ): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.TRUE) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.TRUE) + + on = simplify(on) + + step.joins[name] = { + "side": join.side, + "join_key": join_key, + "source_key": source_key, + "condition": None if on == exp.TRUE else on, + } + + step.add_dependency(Scan.from_expression(join.this, ctes)) + + return step + + def __init__(self): + super().__init__() + self.joins = {} + + def _to_s(self, indent): + lines = [] + for name, join in self.joins.items(): + lines.append(f"{indent}{name}: {join['side']}") + if join.get("condition"): + lines.append(f"{indent}On: {join['condition'].sql()}") + return lines + + +class Aggregate(Step): + def __init__(self): + super().__init__() + self.aggregations = [] + self.operands = [] + self.group = [] + self.source = None + + def _to_s(self, indent): + lines = [f"{indent}Aggregations:"] + + for expression in self.aggregations: + lines.append(f"{indent} - {expression.sql()}") + + if self.group: + lines.append(f"{indent}Group:") + for expression in self.group: + lines.append(f"{indent} - {expression.sql()}") + if self.operands: + lines.append(f"{indent}Operands:") + for expression in self.operands: + lines.append(f"{indent} - {expression.sql()}") + + return lines + + +class Sort(Step): + def __init__(self): + super().__init__() + self.key = None + + def _to_s(self, indent): + lines = [f"{indent}Key:"] + + for expression in self.key: + lines.append(f"{indent} - {expression.sql()}") + + return lines diff --git a/sqlglot/time.py b/sqlglot/time.py new file mode 100644 index 0000000..16314c5 --- /dev/null +++ b/sqlglot/time.py @@ -0,0 +1,45 @@ +# the generic time format is based on python time.strftime +# https://docs.python.org/3/library/time.html#time.strftime +from sqlglot.trie import in_trie, new_trie + + +def format_time(string, mapping, trie=None): + """ + Converts a time string given a mapping. 
+ + Examples: + >>> format_time("%Y", {"%Y": "YYYY"}) + 'YYYY' + + mapping: Dictionary of time format to target time format + trie: Optional trie, can be passed in for performance + """ + start = 0 + end = 1 + size = len(string) + trie = trie or new_trie(mapping) + current = trie + chunks = [] + sym = None + + while end <= size: + chars = string[start:end] + result, current = in_trie(current, chars[-1]) + + if result == 0: + if sym: + end -= 1 + chars = sym + sym = None + start += len(chars) + chunks.append(chars) + current = trie + elif result == 2: + sym = chars + + end += 1 + + if result and end > size: + chunks.append(chars) + + return "".join(mapping.get(chars, chars) for chars in chunks) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py new file mode 100644 index 0000000..e4b754d --- /dev/null +++ b/sqlglot/tokens.py @@ -0,0 +1,853 @@ +from enum import auto + +from sqlglot.helper import AutoName +from sqlglot.trie import in_trie, new_trie + + +class TokenType(AutoName): + L_PAREN = auto() + R_PAREN = auto() + L_BRACKET = auto() + R_BRACKET = auto() + L_BRACE = auto() + R_BRACE = auto() + COMMA = auto() + DOT = auto() + DASH = auto() + PLUS = auto() + COLON = auto() + DCOLON = auto() + SEMICOLON = auto() + STAR = auto() + SLASH = auto() + LT = auto() + LTE = auto() + GT = auto() + GTE = auto() + NOT = auto() + EQ = auto() + NEQ = auto() + AND = auto() + OR = auto() + AMP = auto() + DPIPE = auto() + PIPE = auto() + CARET = auto() + TILDA = auto() + ARROW = auto() + DARROW = auto() + HASH_ARROW = auto() + DHASH_ARROW = auto() + ANNOTATION = auto() + DOLLAR = auto() + + SPACE = auto() + BREAK = auto() + + STRING = auto() + NUMBER = auto() + IDENTIFIER = auto() + COLUMN = auto() + COLUMN_DEF = auto() + SCHEMA = auto() + TABLE = auto() + VAR = auto() + BIT_STRING = auto() + + # types + BOOLEAN = auto() + TINYINT = auto() + SMALLINT = auto() + INT = auto() + BIGINT = auto() + FLOAT = auto() + DOUBLE = auto() + DECIMAL = auto() + CHAR = auto() + NCHAR = auto() + VARCHAR = auto() + NVARCHAR = auto() + TEXT = auto() + BINARY = auto() + BYTEA = auto() + JSON = auto() + TIMESTAMP = auto() + TIMESTAMPTZ = auto() + DATETIME = auto() + DATE = auto() + UUID = auto() + GEOGRAPHY = auto() + NULLABLE = auto() + + # keywords + ADD_FILE = auto() + ALIAS = auto() + ALL = auto() + ALTER = auto() + ANALYZE = auto() + ANY = auto() + ARRAY = auto() + ASC = auto() + AT_TIME_ZONE = auto() + AUTO_INCREMENT = auto() + BEGIN = auto() + BETWEEN = auto() + BUCKET = auto() + CACHE = auto() + CALL = auto() + CASE = auto() + CAST = auto() + CHARACTER_SET = auto() + CHECK = auto() + CLUSTER_BY = auto() + COLLATE = auto() + COMMENT = auto() + COMMIT = auto() + CONSTRAINT = auto() + CONVERT = auto() + CREATE = auto() + CROSS = auto() + CUBE = auto() + CURRENT_DATE = auto() + CURRENT_DATETIME = auto() + CURRENT_ROW = auto() + CURRENT_TIME = auto() + CURRENT_TIMESTAMP = auto() + DIV = auto() + DEFAULT = auto() + DELETE = auto() + DESC = auto() + DISTINCT = auto() + DISTRIBUTE_BY = auto() + DROP = auto() + ELSE = auto() + END = auto() + ENGINE = auto() + ESCAPE = auto() + EXCEPT = auto() + EXISTS = auto() + EXPLAIN = auto() + EXTRACT = auto() + FALSE = auto() + FETCH = auto() + FILTER = auto() + FINAL = auto() + FIRST = auto() + FOLLOWING = auto() + FOREIGN_KEY = auto() + FORMAT = auto() + FULL = auto() + FUNCTION = auto() + FROM = auto() + GROUP_BY = auto() + GROUPING_SETS = auto() + HAVING = auto() + HINT = auto() + IF = auto() + IGNORE_NULLS = auto() + ILIKE = auto() + IN = auto() + INDEX = auto() + INNER = auto() + 
INSERT = auto() + INTERSECT = auto() + INTERVAL = auto() + INTO = auto() + INTRODUCER = auto() + IS = auto() + ISNULL = auto() + JOIN = auto() + LATERAL = auto() + LAZY = auto() + LEFT = auto() + LIKE = auto() + LIMIT = auto() + LOCATION = auto() + MAP = auto() + MOD = auto() + NEXT = auto() + NO_ACTION = auto() + NULL = auto() + NULLS_FIRST = auto() + NULLS_LAST = auto() + OFFSET = auto() + ON = auto() + ONLY = auto() + OPTIMIZE = auto() + OPTIONS = auto() + ORDER_BY = auto() + ORDERED = auto() + ORDINALITY = auto() + OUTER = auto() + OUT_OF = auto() + OVER = auto() + OVERWRITE = auto() + PARTITION = auto() + PARTITION_BY = auto() + PARTITIONED_BY = auto() + PERCENT = auto() + PLACEHOLDER = auto() + PRECEDING = auto() + PRIMARY_KEY = auto() + PROPERTIES = auto() + QUALIFY = auto() + QUOTE = auto() + RANGE = auto() + RECURSIVE = auto() + REPLACE = auto() + RESPECT_NULLS = auto() + REFERENCES = auto() + RIGHT = auto() + RLIKE = auto() + ROLLUP = auto() + ROW = auto() + ROWS = auto() + SCHEMA_COMMENT = auto() + SELECT = auto() + SET = auto() + SHOW = auto() + SOME = auto() + SORT_BY = auto() + STORED = auto() + STRUCT = auto() + TABLE_FORMAT = auto() + TABLE_SAMPLE = auto() + TEMPORARY = auto() + TIME = auto() + TOP = auto() + THEN = auto() + TRUE = auto() + TRUNCATE = auto() + TRY_CAST = auto() + UNBOUNDED = auto() + UNCACHE = auto() + UNION = auto() + UNNEST = auto() + UPDATE = auto() + USE = auto() + USING = auto() + VALUES = auto() + VIEW = auto() + WHEN = auto() + WHERE = auto() + WINDOW = auto() + WITH = auto() + WITH_TIME_ZONE = auto() + WITHIN_GROUP = auto() + WITHOUT_TIME_ZONE = auto() + UNIQUE = auto() + + +class Token: + __slots__ = ("token_type", "text", "line", "col") + + @classmethod + def number(cls, number): + return cls(TokenType.NUMBER, str(number)) + + @classmethod + def string(cls, string): + return cls(TokenType.STRING, string) + + @classmethod + def identifier(cls, identifier): + return cls(TokenType.IDENTIFIER, identifier) + + @classmethod + def var(cls, var): + return cls(TokenType.VAR, var) + + def __init__(self, token_type, text, line=1, col=1): + self.token_type = token_type + self.text = text + self.line = line + self.col = max(col - len(text), 1) + + def __repr__(self): + attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) + return f"<Token {attributes}>" + + +class _Tokenizer(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + klass.QUOTES = dict( + (quote, quote) if isinstance(quote, str) else (quote[0], quote[1]) + for quote in klass.QUOTES + ) + + klass.IDENTIFIERS = dict( + (identifier, identifier) + if isinstance(identifier, str) + else (identifier[0], identifier[1]) + for identifier in klass.IDENTIFIERS + ) + + klass.COMMENTS = dict( + (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) + for comment in klass.COMMENTS + ) + + klass.KEYWORD_TRIE = new_trie( + key.upper() + for key, value in { + **klass.KEYWORDS, + **{comment: TokenType.COMMENT for comment in klass.COMMENTS}, + **{quote: TokenType.QUOTE for quote in klass.QUOTES}, + }.items() + if " " in key or any(single in key for single in klass.SINGLE_TOKENS) + ) + + return klass + + +class Tokenizer(metaclass=_Tokenizer): + SINGLE_TOKENS = { + "(": TokenType.L_PAREN, + ")": TokenType.R_PAREN, + "[": TokenType.L_BRACKET, + "]": TokenType.R_BRACKET, + "{": TokenType.L_BRACE, + "}": TokenType.R_BRACE, + "&": TokenType.AMP, + "^": TokenType.CARET, + ":": TokenType.COLON, + ",": TokenType.COMMA, + ".": TokenType.DOT, + "-":
TokenType.DASH, + "=": TokenType.EQ, + ">": TokenType.GT, + "<": TokenType.LT, + "%": TokenType.MOD, + "!": TokenType.NOT, + "|": TokenType.PIPE, + "+": TokenType.PLUS, + ";": TokenType.SEMICOLON, + "/": TokenType.SLASH, + "*": TokenType.STAR, + "~": TokenType.TILDA, + "?": TokenType.PLACEHOLDER, + "#": TokenType.ANNOTATION, + "$": TokenType.DOLLAR, + # used for breaking a var like x'y' but nothing else + # the token type doesn't matter + "'": TokenType.QUOTE, + "`": TokenType.IDENTIFIER, + '"': TokenType.IDENTIFIER, + } + + QUOTES = ["'"] + + IDENTIFIERS = ['"'] + + ESCAPE = "'" + + KEYWORDS = { + "/*+": TokenType.HINT, + "*/": TokenType.HINT, + "==": TokenType.EQ, + "::": TokenType.DCOLON, + "||": TokenType.DPIPE, + ">=": TokenType.GTE, + "<=": TokenType.LTE, + "<>": TokenType.NEQ, + "!=": TokenType.NEQ, + "->": TokenType.ARROW, + "->>": TokenType.DARROW, + "#>": TokenType.HASH_ARROW, + "#>>": TokenType.DHASH_ARROW, + "ADD ARCHIVE": TokenType.ADD_FILE, + "ADD ARCHIVES": TokenType.ADD_FILE, + "ADD FILE": TokenType.ADD_FILE, + "ADD FILES": TokenType.ADD_FILE, + "ADD JAR": TokenType.ADD_FILE, + "ADD JARS": TokenType.ADD_FILE, + "ALL": TokenType.ALL, + "ALTER": TokenType.ALTER, + "ANALYZE": TokenType.ANALYZE, + "AND": TokenType.AND, + "ANY": TokenType.ANY, + "ASC": TokenType.ASC, + "AS": TokenType.ALIAS, + "AT TIME ZONE": TokenType.AT_TIME_ZONE, + "AUTO_INCREMENT": TokenType.AUTO_INCREMENT, + "BEGIN": TokenType.BEGIN, + "BETWEEN": TokenType.BETWEEN, + "BUCKET": TokenType.BUCKET, + "CALL": TokenType.CALL, + "CACHE": TokenType.CACHE, + "UNCACHE": TokenType.UNCACHE, + "CASE": TokenType.CASE, + "CAST": TokenType.CAST, + "CHARACTER SET": TokenType.CHARACTER_SET, + "CHECK": TokenType.CHECK, + "CLUSTER BY": TokenType.CLUSTER_BY, + "COLLATE": TokenType.COLLATE, + "COMMENT": TokenType.SCHEMA_COMMENT, + "COMMIT": TokenType.COMMIT, + "CONSTRAINT": TokenType.CONSTRAINT, + "CONVERT": TokenType.CONVERT, + "CREATE": TokenType.CREATE, + "CROSS": TokenType.CROSS, + "CUBE": TokenType.CUBE, + "CURRENT_DATE": TokenType.CURRENT_DATE, + "CURRENT ROW": TokenType.CURRENT_ROW, + "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP, + "DIV": TokenType.DIV, + "DEFAULT": TokenType.DEFAULT, + "DELETE": TokenType.DELETE, + "DESC": TokenType.DESC, + "DISTINCT": TokenType.DISTINCT, + "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, + "DROP": TokenType.DROP, + "ELSE": TokenType.ELSE, + "END": TokenType.END, + "ENGINE": TokenType.ENGINE, + "ESCAPE": TokenType.ESCAPE, + "EXCEPT": TokenType.EXCEPT, + "EXISTS": TokenType.EXISTS, + "EXPLAIN": TokenType.EXPLAIN, + "EXTRACT": TokenType.EXTRACT, + "FALSE": TokenType.FALSE, + "FETCH": TokenType.FETCH, + "FILTER": TokenType.FILTER, + "FIRST": TokenType.FIRST, + "FULL": TokenType.FULL, + "FUNCTION": TokenType.FUNCTION, + "FOLLOWING": TokenType.FOLLOWING, + "FOREIGN KEY": TokenType.FOREIGN_KEY, + "FORMAT": TokenType.FORMAT, + "FROM": TokenType.FROM, + "GROUP BY": TokenType.GROUP_BY, + "GROUPING SETS": TokenType.GROUPING_SETS, + "HAVING": TokenType.HAVING, + "IF": TokenType.IF, + "ILIKE": TokenType.ILIKE, + "IGNORE NULLS": TokenType.IGNORE_NULLS, + "IN": TokenType.IN, + "INDEX": TokenType.INDEX, + "INNER": TokenType.INNER, + "INSERT": TokenType.INSERT, + "INTERVAL": TokenType.INTERVAL, + "INTERSECT": TokenType.INTERSECT, + "INTO": TokenType.INTO, + "IS": TokenType.IS, + "ISNULL": TokenType.ISNULL, + "JOIN": TokenType.JOIN, + "LATERAL": TokenType.LATERAL, + "LAZY": TokenType.LAZY, + "LEFT": TokenType.LEFT, + "LIKE": TokenType.LIKE, + "LIMIT": TokenType.LIMIT, + "LOCATION": TokenType.LOCATION, + 
"NEXT": TokenType.NEXT, + "NO ACTION": TokenType.NO_ACTION, + "NOT": TokenType.NOT, + "NULL": TokenType.NULL, + "NULLS FIRST": TokenType.NULLS_FIRST, + "NULLS LAST": TokenType.NULLS_LAST, + "OFFSET": TokenType.OFFSET, + "ON": TokenType.ON, + "ONLY": TokenType.ONLY, + "OPTIMIZE": TokenType.OPTIMIZE, + "OPTIONS": TokenType.OPTIONS, + "OR": TokenType.OR, + "ORDER BY": TokenType.ORDER_BY, + "ORDINALITY": TokenType.ORDINALITY, + "OUTER": TokenType.OUTER, + "OUT OF": TokenType.OUT_OF, + "OVER": TokenType.OVER, + "OVERWRITE": TokenType.OVERWRITE, + "PARTITION": TokenType.PARTITION, + "PARTITION BY": TokenType.PARTITION_BY, + "PARTITIONED BY": TokenType.PARTITIONED_BY, + "PERCENT": TokenType.PERCENT, + "PRECEDING": TokenType.PRECEDING, + "PRIMARY KEY": TokenType.PRIMARY_KEY, + "RANGE": TokenType.RANGE, + "RECURSIVE": TokenType.RECURSIVE, + "REGEXP": TokenType.RLIKE, + "REPLACE": TokenType.REPLACE, + "RESPECT NULLS": TokenType.RESPECT_NULLS, + "REFERENCES": TokenType.REFERENCES, + "RIGHT": TokenType.RIGHT, + "RLIKE": TokenType.RLIKE, + "ROLLUP": TokenType.ROLLUP, + "ROW": TokenType.ROW, + "ROWS": TokenType.ROWS, + "SELECT": TokenType.SELECT, + "SET": TokenType.SET, + "SHOW": TokenType.SHOW, + "SOME": TokenType.SOME, + "SORT BY": TokenType.SORT_BY, + "STORED": TokenType.STORED, + "TABLE": TokenType.TABLE, + "TABLE_FORMAT": TokenType.TABLE_FORMAT, + "TBLPROPERTIES": TokenType.PROPERTIES, + "TABLESAMPLE": TokenType.TABLE_SAMPLE, + "TEMP": TokenType.TEMPORARY, + "TEMPORARY": TokenType.TEMPORARY, + "THEN": TokenType.THEN, + "TRUE": TokenType.TRUE, + "TRUNCATE": TokenType.TRUNCATE, + "TRY_CAST": TokenType.TRY_CAST, + "UNBOUNDED": TokenType.UNBOUNDED, + "UNION": TokenType.UNION, + "UNNEST": TokenType.UNNEST, + "UPDATE": TokenType.UPDATE, + "USE": TokenType.USE, + "USING": TokenType.USING, + "VALUES": TokenType.VALUES, + "VIEW": TokenType.VIEW, + "WHEN": TokenType.WHEN, + "WHERE": TokenType.WHERE, + "WITH": TokenType.WITH, + "WITH TIME ZONE": TokenType.WITH_TIME_ZONE, + "WITHIN GROUP": TokenType.WITHIN_GROUP, + "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE, + "ARRAY": TokenType.ARRAY, + "BOOL": TokenType.BOOLEAN, + "BOOLEAN": TokenType.BOOLEAN, + "BYTE": TokenType.TINYINT, + "TINYINT": TokenType.TINYINT, + "SHORT": TokenType.SMALLINT, + "SMALLINT": TokenType.SMALLINT, + "INT2": TokenType.SMALLINT, + "INTEGER": TokenType.INT, + "INT": TokenType.INT, + "INT4": TokenType.INT, + "LONG": TokenType.BIGINT, + "BIGINT": TokenType.BIGINT, + "INT8": TokenType.BIGINT, + "DECIMAL": TokenType.DECIMAL, + "MAP": TokenType.MAP, + "NUMBER": TokenType.DECIMAL, + "NUMERIC": TokenType.DECIMAL, + "FIXED": TokenType.DECIMAL, + "REAL": TokenType.FLOAT, + "FLOAT": TokenType.FLOAT, + "FLOAT4": TokenType.FLOAT, + "FLOAT8": TokenType.DOUBLE, + "DOUBLE": TokenType.DOUBLE, + "JSON": TokenType.JSON, + "CHAR": TokenType.CHAR, + "NCHAR": TokenType.NCHAR, + "VARCHAR": TokenType.VARCHAR, + "VARCHAR2": TokenType.VARCHAR, + "NVARCHAR": TokenType.NVARCHAR, + "NVARCHAR2": TokenType.NVARCHAR, + "STRING": TokenType.TEXT, + "TEXT": TokenType.TEXT, + "CLOB": TokenType.TEXT, + "BINARY": TokenType.BINARY, + "BLOB": TokenType.BINARY, + "BYTEA": TokenType.BINARY, + "TIMESTAMP": TokenType.TIMESTAMP, + "TIMESTAMPTZ": TokenType.TIMESTAMPTZ, + "DATE": TokenType.DATE, + "DATETIME": TokenType.DATETIME, + "UNIQUE": TokenType.UNIQUE, + "STRUCT": TokenType.STRUCT, + } + + WHITE_SPACE = { + " ": TokenType.SPACE, + "\t": TokenType.SPACE, + "\n": TokenType.BREAK, + "\r": TokenType.BREAK, + "\r\n": TokenType.BREAK, + } + + COMMANDS = { + TokenType.ALTER, + 
TokenType.ADD_FILE, + TokenType.ANALYZE, + TokenType.BEGIN, + TokenType.CALL, + TokenType.COMMIT, + TokenType.EXPLAIN, + TokenType.OPTIMIZE, + TokenType.SET, + TokenType.SHOW, + TokenType.TRUNCATE, + TokenType.USE, + } + + # handle numeric literals like in hive (3L = BIGINT) + NUMERIC_LITERALS = {} + ENCODE = None + + COMMENTS = ["--", ("/*", "*/")] + KEYWORD_TRIE = None # autofilled + + __slots__ = ( + "sql", + "size", + "tokens", + "_start", + "_current", + "_line", + "_col", + "_char", + "_end", + "_peek", + ) + + def __init__(self): + """ + Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token` + """ + self.reset() + + def reset(self): + self.sql = "" + self.size = 0 + self.tokens = [] + self._start = 0 + self._current = 0 + self._line = 1 + self._col = 1 + + self._char = None + self._end = None + self._peek = None + + def tokenize(self, sql): + self.reset() + self.sql = sql + self.size = len(sql) + + while self.size and not self._end: + self._start = self._current + self._advance() + + if not self._char: + break + + white_space = self.WHITE_SPACE.get(self._char) + identifier_end = self.IDENTIFIERS.get(self._char) + + if white_space: + if white_space == TokenType.BREAK: + self._col = 1 + self._line += 1 + elif self._char == "0" and self._peek == "x": + self._scan_hex() + elif self._char.isdigit(): + self._scan_number() + elif identifier_end: + self._scan_identifier(identifier_end) + else: + self._scan_keywords() + return self.tokens + + def _chars(self, size): + if size == 1: + return self._char + start = self._current - 1 + end = start + size + if end <= self.size: + return self.sql[start:end] + return "" + + def _advance(self, i=1): + self._col += i + self._current += i + self._end = self._current >= self.size + self._char = self.sql[self._current - 1] + self._peek = self.sql[self._current] if self._current < self.size else "" + + @property + def _text(self): + return self.sql[self._start : self._current] + + def _add(self, token_type, text=None): + text = self._text if text is None else text + self.tokens.append(Token(token_type, text, self._line, self._col)) + + if token_type in self.COMMANDS and ( + len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON + ): + self._start = self._current + while not self._end and self._peek != ";": + self._advance() + if self._start < self._current: + self._add(TokenType.STRING) + + def _scan_keywords(self): + size = 0 + word = None + chars = self._text + char = chars + prev_space = False + skip = False + trie = self.KEYWORD_TRIE + + while chars: + if skip: + result = 1 + else: + result, trie = in_trie(trie, char.upper()) + + if result == 0: + break + if result == 2: + word = chars + size += 1 + end = self._current - 1 + size + + if end < self.size: + char = self.sql[end] + is_space = char in self.WHITE_SPACE + + if not is_space or not prev_space: + if is_space: + char = " " + chars += char + prev_space = is_space + skip = False + else: + skip = True + else: + chars = None + + if not word: + if self._char in self.SINGLE_TOKENS: + token = self.SINGLE_TOKENS[self._char] + if token == TokenType.ANNOTATION: + self._scan_annotation() + return + self._add(token) + return + self._scan_var() + return + + if self._scan_string(word): + return + if self._scan_comment(word): + return + + self._advance(size - 1) + self._add(self.KEYWORDS[word.upper()]) + + def _scan_comment(self, comment_start): + if comment_start not in self.COMMENTS: + return False + + comment_end = self.COMMENTS[comment_start] + + if 
comment_end: + comment_end_size = len(comment_end) + + while not self._end and self._chars(comment_end_size) != comment_end: + self._advance() + self._advance(comment_end_size - 1) + else: + while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: + self._advance() + return True + + def _scan_annotation(self): + while ( + not self._end + and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK + and self._peek != "," + ): + self._advance() + self._add(TokenType.ANNOTATION, self._text[1:]) + + def _scan_number(self): + decimal = False + scientific = 0 + + while True: + if self._peek.isdigit(): + self._advance() + elif self._peek == "." and not decimal: + decimal = True + self._advance() + elif self._peek in ("-", "+") and scientific == 1: + scientific += 1 + self._advance() + elif self._peek.upper() == "E" and not scientific: + scientific += 1 + self._advance() + elif self._peek.isalpha(): + self._add(TokenType.NUMBER) + literal = [] + while self._peek.isalpha(): + literal.append(self._peek.upper()) + self._advance() + literal = "".join(literal) + token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) + if token_type: + self._add(TokenType.DCOLON, "::") + return self._add(token_type, literal) + return self._advance(-len(literal)) + else: + return self._add(TokenType.NUMBER) + + def _scan_hex(self): + self._advance() + + while True: + char = self._peek.strip() + if char and char not in self.SINGLE_TOKENS: + self._advance() + else: + break + try: + self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}") + except ValueError: + self._add(TokenType.IDENTIFIER) + + def _scan_string(self, quote): + quote_end = self.QUOTES.get(quote) + if quote_end is None: + return False + + text = "" + self._advance(len(quote)) + quote_end_size = len(quote_end) + + while True: + if self._char == self.ESCAPE and self._peek == quote_end: + text += quote + self._advance(2) + else: + if self._chars(quote_end_size) == quote_end: + if quote_end_size > 1: + self._advance(quote_end_size - 1) + break + + if self._end: + raise RuntimeError( + f"Missing {quote} from {self._line}:{self._start}" + ) + text += self._char + self._advance() + + text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text + text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text + self._add(TokenType.STRING, text) + return True + + def _scan_identifier(self, identifier_end): + while self._peek != identifier_end: + if self._end: + raise RuntimeError( + f"Missing {identifier_end} from {self._line}:{self._start}" + ) + self._advance() + self._advance() + self._add(TokenType.IDENTIFIER, self._text[1:-1]) + + def _scan_var(self): + while True: + char = self._peek.strip() + if char and char not in self.SINGLE_TOKENS: + self._advance() + else: + break + self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR)) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py new file mode 100644 index 0000000..e7ccb8e --- /dev/null +++ b/sqlglot/transforms.py @@ -0,0 +1,68 @@ +from sqlglot import expressions as exp + + +def unalias_group(expression): + """ + Replace references to select aliases in GROUP BY clauses. 
+ + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() + 'SELECT a AS b FROM x GROUP BY 1' + """ + if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): + aliased_selects = { + e.alias: i + for i, e in enumerate(expression.parent.expressions, start=1) + if isinstance(e, exp.Alias) + } + + expression = expression.copy() + + for col in expression.find_all(exp.Column): + alias_index = aliased_selects.get(col.name) + if not col.table and alias_index: + col.replace(exp.Literal.number(alias_index)) + + return expression + + +def preprocess(transforms, to_sql): + """ + Create a new transform function that can be used as a value in `Generator.TRANSFORMS` + to convert expressions to SQL. + + Args: + transforms (list[(exp.Expression) -> exp.Expression]): + Sequence of transform functions. These will be called in order. + to_sql ((sqlglot.generator.Generator, exp.Expression) -> str): + Final transform that converts the resulting expression to a SQL string. + Returns: + (sqlglot.generator.Generator, exp.Expression) -> str: + Function that can be used as a generator transform. + """ + + def _to_sql(self, expression): + expression = transforms[0](expression) + for t in transforms[1:]: + expression = t(expression) + return to_sql(self, expression) + + return _to_sql + + +def delegate(attr): + """ + Create a new method that delegates to `attr`. + + This is useful for creating `Generator.TRANSFORMS` functions that delegate + to existing generator methods. + """ + + def _transform(self, *args, **kwargs): + return getattr(self, attr)(*args, **kwargs) + + return _transform + + +UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} diff --git a/sqlglot/trie.py b/sqlglot/trie.py new file mode 100644 index 0000000..a234107 --- /dev/null +++ b/sqlglot/trie.py @@ -0,0 +1,27 @@ +def new_trie(keywords): + trie = {} + + for key in keywords: + current = trie + + for char in key: + current = current.setdefault(char, {}) + current[0] = True + + return trie + + +def in_trie(trie, key): + if not key: + return (0, trie) + + current = trie + + for char in key: + if char not in current: + return (0, current) + current = current[char] + + if 0 in current: + return (2, current) + return (1, current) -- cgit v1.2.3
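
The planner's docstring sketches how a SELECT becomes a DAG of Steps; a minimal runnable sketch of the same idea, assuming the statement already satisfies the UnsupportedError branches above (single FROM, aliased tables). The table and alias names here are invented for illustration:

    import sqlglot
    from sqlglot.planner import Plan

    # A single aliased scan with no joins or aggregates collapses to one Scan step.
    plan = Plan(sqlglot.parse_one("SELECT n.a FROM numbers AS n"))

    print(plan.root)          # a Scan step named "n", with n.a under Projections
    print(list(plan.leaves))  # the Scan itself: it has no dependencies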
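The time formatter can likewise be exercised directly; a small sketch in which the mapping values (Spark-style tokens) are illustrative rather than taken from the patch:

    from sqlglot.time import format_time
    from sqlglot.trie import new_trie

    mapping = {"%Y": "yyyy", "%m": "MM", "%d": "dd"}

    format_time("%Y-%m-%d", mapping)        # 'yyyy-MM-dd'

    # A prebuilt trie can be passed in and reused across calls, as the
    # docstring's performance note suggests.
    trie = new_trie(mapping)
    format_time("%d/%m/%Y", mapping, trie)  # 'dd/MM/yyyy'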
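The tokenizer is also usable standalone; a quick sketch of its contract as implied by the code above (comments are consumed without emitting tokens, so only two tokens remain):

    from sqlglot.tokens import Tokenizer, TokenType

    tokens = Tokenizer().tokenize("SELECT 1 -- trailing comment")

    assert tokens[0].token_type == TokenType.SELECT
    assert tokens[1].token_type == TokenType.NUMBER and tokens[1].text == "1"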
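The UNALIAS_GROUP mapping at the end of transforms.py is meant to be merged into a generator's TRANSFORMS table; a hypothetical subclass for illustration (MyGenerator is not part of the patch):

    from sqlglot import transforms
    from sqlglot.generator import Generator

    class MyGenerator(Generator):
        # unalias_group rewrites GROUP BY alias references to ordinals before
        # the stock group_sql handler renders the clause.
        TRANSFORMS = {**Generator.TRANSFORMS, **transforms.UNALIAS_GROUP}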
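Finally, the status codes returned by in_trie (0 = no match, 1 = proper prefix of a key, 2 = a complete key, signalled by the 0 sentinel) are what drive the keyword scanner; a short sketch of the contract:

    from sqlglot.trie import in_trie, new_trie

    trie = new_trie(["GROUP BY", "GROUPING SETS"])

    in_trie(trie, "GROUP")     # (1, <subtrie>): prefix of both keys
    in_trie(trie, "GROUP BY")  # (2, <node>): a complete keyword
    in_trie(trie, "XYZ")       # (0, <node>): no key starts with "X"

    # The scanner resumes from the returned node, one character at a time.
    result, node = in_trie(trie, "GROUPING")  # (1, ...)
    result, node = in_trie(node, " ")         # still a prefix of "GROUPING SETS"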