diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-16 09:41:18 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-16 09:41:18 +0000 |
commit | 67578a7602a5be7eb51f324086c8d49bcf8b7498 (patch) | |
tree | 0b7515c922d1c383cea24af5175379cfc8edfd15 /sqlglot | |
parent | Releasing debian version 15.2.0-1. (diff) | |
download | sqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.tar.xz sqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.zip |
Merging upstream version 16.2.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
37 files changed, 1304 insertions, 1184 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 5b10852..2166e65 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -7,6 +7,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, datestrtodate_sql, + format_time_lambda, inline_array_sql, max_or_greatest, min_or_least, @@ -103,16 +104,26 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: class BigQuery(Dialect): - unnest_column_only = True - time_mapping = { - "%M": "%-M", - "%d": "%-d", - "%m": "%-m", - "%y": "%-y", - "%H": "%-H", - "%I": "%-I", - "%S": "%-S", - "%j": "%-j", + UNNEST_COLUMN_ONLY = True + + TIME_MAPPING = { + "%D": "%m/%d/%y", + } + + FORMAT_MAPPING = { + "DD": "%d", + "MM": "%m", + "MON": "%b", + "MONTH": "%B", + "YYYY": "%Y", + "YY": "%y", + "HH": "%I", + "HH12": "%I", + "HH24": "%H", + "MI": "%M", + "SS": "%S", + "SSSSS": "%f", + "TZH": "%z", } class Tokenizer(tokens.Tokenizer): @@ -142,6 +153,7 @@ class BigQuery(Dialect): "FLOAT64": TokenType.DOUBLE, "INT64": TokenType.BIGINT, "RECORD": TokenType.STRUCT, + "TIMESTAMP": TokenType.TIMESTAMPTZ, "NOT DETERMINISTIC": TokenType.VOLATILE, "UNKNOWN": TokenType.NULL, } @@ -155,13 +167,21 @@ class BigQuery(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), + "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(str(seq_get(args, 1))), this=seq_get(args, 0), ), - "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd), + "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), + "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")( + [seq_get(args, 1), seq_get(args, 0)] + ), + "PARSE_TIMESTAMP": lambda args: format_time_lambda(exp.StrToTime, "bigquery")( + [seq_get(args, 1), seq_get(args, 0)] + ), "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( this=seq_get(args, 0), @@ -172,15 +192,15 @@ class BigQuery(Dialect): if re.compile(str(seq_get(args, 1))).groups == 1 else None, ), + "SPLIT": lambda args: exp.Split( + # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split + this=seq_get(args, 0), + expression=seq_get(args, 1) or exp.Literal.string(","), + ), "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd), - "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd), - "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), - "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub), + "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd), "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub), - "PARSE_TIMESTAMP": lambda args: exp.StrToTime( - this=seq_get(args, 1), format=seq_get(args, 0) - ), } FUNCTION_PARSERS = { @@ -274,9 +294,18 @@ class BigQuery(Dialect): exp.IntDiv: rename_func("DIV"), exp.Max: max_or_greatest, exp.Min: min_or_least, + exp.RegexpExtract: lambda self, e: self.func( + "REGEXP_EXTRACT", + e.this, + e.expression, + e.args.get("position"), + e.args.get("occurrence"), + ), + exp.RegexpLike: rename_func("REGEXP_CONTAINS"), exp.Select: transforms.preprocess( [_unqualify_unnest, transforms.eliminate_distinct_on] ), + exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), @@ -295,7 +324,6 @@ class BigQuery(Dialect): exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", - exp.RegexpLike: rename_func("REGEXP_CONTAINS"), } TYPE_MAPPING = { @@ -315,6 +343,7 @@ class BigQuery(Dialect): exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.TIMESTAMP: "DATETIME", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.VARBINARY: "BYTES", exp.DataType.Type.VARCHAR: "STRING", diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index fc48379..cfa9a7e 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -21,8 +21,9 @@ def _lower_func(sql: str) -> str: class ClickHouse(Dialect): - normalize_functions = None - null_ordering = "nulls_are_last" + NORMALIZE_FUNCTIONS: bool | str = False + NULL_ORDERING = "nulls_are_last" + STRICT_STRING_CONCAT = True class Tokenizer(tokens.Tokenizer): COMMENTS = ["--", "#", "#!", ("/*", "*/")] @@ -163,11 +164,11 @@ class ClickHouse(Dialect): return this - def _parse_position(self, haystack_first: bool = False) -> exp.Expression: + def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: return super()._parse_position(haystack_first=True) # https://clickhouse.com/docs/en/sql-reference/statements/select/with/ - def _parse_cte(self) -> exp.Expression: + def _parse_cte(self) -> exp.CTE: index = self._index try: # WITH <identifier> AS <subquery expression> @@ -187,17 +188,19 @@ class ClickHouse(Dialect): ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: is_global = self._match(TokenType.GLOBAL) and self._prev kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev + if kind_pre: kind = self._match_set(self.JOIN_KINDS) and self._prev side = self._match_set(self.JOIN_SIDES) and self._prev return is_global, side, kind + return ( is_global, self._match_set(self.JOIN_SIDES) and self._prev, self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]: join = super()._parse_join(skip_join_token) if join: @@ -205,9 +208,14 @@ class ClickHouse(Dialect): return join def _parse_function( - self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False + self, + functions: t.Optional[t.Dict[str, t.Callable]] = None, + anonymous: bool = False, + optional_parens: bool = True, ) -> t.Optional[exp.Expression]: - func = super()._parse_function(functions, anonymous) + func = super()._parse_function( + functions=functions, anonymous=anonymous, optional_parens=optional_parens + ) if isinstance(func, exp.Anonymous): params = self._parse_func_params(func) @@ -227,10 +235,12 @@ class ClickHouse(Dialect): ) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): return self._parse_csv(self._parse_lambda) + if self._match(TokenType.L_PAREN): params = self._parse_csv(self._parse_lambda) self._match_r_paren(this) return params + return None def _parse_quantile(self) -> exp.Quantile: @@ -247,12 +257,12 @@ class ClickHouse(Dialect): def _parse_primary_key( self, wrapped_optional: bool = False, in_props: bool = False - ) -> exp.Expression: + ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey: return super()._parse_primary_key( wrapped_optional=wrapped_optional or in_props, in_props=in_props ) - def _parse_on_property(self) -> t.Optional[exp.Property]: + def _parse_on_property(self) -> t.Optional[exp.Expression]: index = self._index if self._match_text_seq("CLUSTER"): this = self._parse_id_var() @@ -329,6 +339,16 @@ class ClickHouse(Dialect): "NAMED COLLECTION", } + def safeconcat_sql(self, expression: exp.SafeConcat) -> str: + # Clickhouse errors out if we try to cast a NULL value to TEXT + return self.func( + "CONCAT", + *[ + exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text")) + for e in expression.expressions + ], + ) + def cte_sql(self, expression: exp.CTE) -> str: if isinstance(expression.this, exp.Alias): return self.sql(expression, "this") diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 4958bc6..f5d523b 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -25,6 +25,8 @@ class Dialects(str, Enum): BIGQUERY = "bigquery" CLICKHOUSE = "clickhouse" + DATABRICKS = "databricks" + DRILL = "drill" DUCKDB = "duckdb" HIVE = "hive" MYSQL = "mysql" @@ -38,11 +40,9 @@ class Dialects(str, Enum): SQLITE = "sqlite" STARROCKS = "starrocks" TABLEAU = "tableau" + TERADATA = "teradata" TRINO = "trino" TSQL = "tsql" - DATABRICKS = "databricks" - DRILL = "drill" - TERADATA = "teradata" class _Dialect(type): @@ -76,16 +76,19 @@ class _Dialect(type): enum = Dialects.__members__.get(clsname.upper()) cls.classes[enum.value if enum is not None else clsname.lower()] = klass - klass.time_trie = new_trie(klass.time_mapping) - klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()} - klass.inverse_time_trie = new_trie(klass.inverse_time_mapping) + klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) + klass.FORMAT_TRIE = ( + new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE + ) + klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} + klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) klass.parser_class = getattr(klass, "Parser", Parser) klass.generator_class = getattr(klass, "Generator", Generator) - klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0] - klass.identifier_start, klass.identifier_end = list( + klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] + klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( klass.tokenizer_class._IDENTIFIERS.items() )[0] @@ -99,43 +102,80 @@ class _Dialect(type): (None, None), ) - klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING) - klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING) - klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING) - klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING) + klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) + klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) + klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) + klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING) - klass.tokenizer_class.identifiers_can_start_with_digit = ( - klass.identifiers_can_start_with_digit - ) + dialect_properties = { + **{ + k: v + for k, v in vars(klass).items() + if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__") + }, + "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0], + "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0], + } + + # Pass required dialect properties to the tokenizer, parser and generator classes + for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class): + for name, value in dialect_properties.items(): + if hasattr(subclass, name): + setattr(subclass, name, value) + + if not klass.STRICT_STRING_CONCAT: + klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe return klass class Dialect(metaclass=_Dialect): - index_offset = 0 - unnest_column_only = False - alias_post_tablesample = False - identifiers_can_start_with_digit = False - normalize_functions: t.Optional[str] = "upper" - null_ordering = "nulls_are_small" - - date_format = "'%Y-%m-%d'" - dateint_format = "'%Y%m%d'" - time_format = "'%Y-%m-%d %H:%M:%S'" - time_mapping: t.Dict[str, str] = {} - - # autofilled - quote_start = None - quote_end = None - identifier_start = None - identifier_end = None - - time_trie = None - inverse_time_mapping = None - inverse_time_trie = None - tokenizer_class = None - parser_class = None - generator_class = None + # Determines the base index offset for arrays + INDEX_OFFSET = 0 + + # If true unnest table aliases are considered only as column aliases + UNNEST_COLUMN_ONLY = False + + # Determines whether or not the table alias comes after tablesample + ALIAS_POST_TABLESAMPLE = False + + # Determines whether or not an unquoted identifier can start with a digit + IDENTIFIERS_CAN_START_WITH_DIGIT = False + + # Determines whether or not CONCAT's arguments must be strings + STRICT_STRING_CONCAT = False + + # Determines how function names are going to be normalized + NORMALIZE_FUNCTIONS: bool | str = "upper" + + # Indicates the default null ordering method to use if not explicitly set + # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" + NULL_ORDERING = "nulls_are_small" + + DATE_FORMAT = "'%Y-%m-%d'" + DATEINT_FORMAT = "'%Y%m%d'" + TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" + + # Custom time mappings in which the key represents dialect time format + # and the value represents a python time format + TIME_MAPPING: t.Dict[str, str] = {} + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time + # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE + # special syntax cast(x as date format 'yyyy') defaults to time_mapping + FORMAT_MAPPING: t.Dict[str, str] = {} + + # Autofilled + tokenizer_class = Tokenizer + parser_class = Parser + generator_class = Generator + + # A trie of the time_mapping keys + TIME_TRIE: t.Dict = {} + FORMAT_TRIE: t.Dict = {} + + INVERSE_TIME_MAPPING: t.Dict[str, str] = {} + INVERSE_TIME_TRIE: t.Dict = {} def __eq__(self, other: t.Any) -> bool: return type(self) == other @@ -164,20 +204,13 @@ class Dialect(metaclass=_Dialect): ) -> t.Optional[exp.Expression]: if isinstance(expression, str): return exp.Literal.string( - format_time( - expression[1:-1], # the time formats are quoted - cls.time_mapping, - cls.time_trie, - ) + # the time formats are quoted + format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) ) + if expression and expression.is_string: - return exp.Literal.string( - format_time( - expression.this, - cls.time_mapping, - cls.time_trie, - ) - ) + return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) + return expression def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: @@ -200,48 +233,14 @@ class Dialect(metaclass=_Dialect): @property def tokenizer(self) -> Tokenizer: if not hasattr(self, "_tokenizer"): - self._tokenizer = self.tokenizer_class() # type: ignore + self._tokenizer = self.tokenizer_class() return self._tokenizer def parser(self, **opts) -> Parser: - return self.parser_class( # type: ignore - **{ - "index_offset": self.index_offset, - "unnest_column_only": self.unnest_column_only, - "alias_post_tablesample": self.alias_post_tablesample, - "null_ordering": self.null_ordering, - **opts, - }, - ) + return self.parser_class(**opts) def generator(self, **opts) -> Generator: - return self.generator_class( # type: ignore - **{ - "quote_start": self.quote_start, - "quote_end": self.quote_end, - "bit_start": self.bit_start, - "bit_end": self.bit_end, - "hex_start": self.hex_start, - "hex_end": self.hex_end, - "byte_start": self.byte_start, - "byte_end": self.byte_end, - "raw_start": self.raw_start, - "raw_end": self.raw_end, - "identifier_start": self.identifier_start, - "identifier_end": self.identifier_end, - "string_escape": self.tokenizer_class.STRING_ESCAPES[0], - "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], - "index_offset": self.index_offset, - "time_mapping": self.inverse_time_mapping, - "time_trie": self.inverse_time_trie, - "unnest_column_only": self.unnest_column_only, - "alias_post_tablesample": self.alias_post_tablesample, - "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit, - "normalize_functions": self.normalize_functions, - "null_ordering": self.null_ordering, - **opts, - } - ) + return self.generator_class(**opts) DialectType = t.Union[str, Dialect, t.Type[Dialect], None] @@ -279,10 +278,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str: def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: return self.like_sql( - exp.Like( - this=exp.Lower(this=expression.this), - expression=expression.args["expression"], - ) + exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) ) @@ -359,6 +355,7 @@ def var_map_sql( for key, value in zip(keys.expressions, values.expressions): args.append(self.sql(key)) args.append(self.sql(value)) + return self.func(map_func_name, *args) @@ -381,7 +378,7 @@ def format_time_lambda( this=seq_get(args, 0), format=Dialect[dialect].format_time( seq_get(args, 1) - or (Dialect[dialect].time_format if default is True else default or None) + or (Dialect[dialect].TIME_FORMAT if default is True else default or None) ), ) @@ -437,9 +434,7 @@ def parse_date_delta_with_interval( expression = exp.Literal.number(expression.this) return expression_class( - this=args[0], - expression=expression, - unit=exp.Literal.string(interval.text("unit")), + this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) ) return func @@ -462,9 +457,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: def locate_to_strposition(args: t.List) -> exp.Expression: return exp.StrPosition( - this=seq_get(args, 1), - substr=seq_get(args, 0), - position=seq_get(args, 2), + this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) ) @@ -546,13 +539,21 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: _dialect = Dialect.get_or_raise(dialect) time_format = self.format_time(expression) - if time_format and time_format not in (_dialect.time_format, _dialect.date_format): + if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): return f"CAST({str_to_time_sql(self, expression)} AS DATE)" return f"CAST({self.sql(expression, 'this')} AS DATE)" return _ts_or_ds_to_date_sql +def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: + this, *rest_args = expression.expressions + for arg in rest_args: + this = exp.DPipe(this=this, expression=arg) + + return self.sql(this) + + # Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: names = [] diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 924b979..3cca986 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -16,21 +16,10 @@ from sqlglot.dialects.dialect import ( ) -def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: - return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" - - -def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: - time_format = self.format_time(expression) - if time_format and time_format not in (Drill.time_format, Drill.date_format): - return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" - return f"CAST({self.sql(expression, 'this')} AS DATE)" - - def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") - unit = exp.Var(this=expression.text("unit").upper() or "DAY") + unit = exp.var(expression.text("unit").upper() or "DAY") return ( f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" ) @@ -41,19 +30,19 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) - if time_format == Drill.date_format: + if time_format == Drill.DATE_FORMAT: return f"CAST({this} AS DATE)" return f"TO_DATE({this}, {time_format})" class Drill(Dialect): - normalize_functions = None - null_ordering = "nulls_are_last" - date_format = "'yyyy-MM-dd'" - dateint_format = "'yyyyMMdd'" - time_format = "'yyyy-MM-dd HH:mm:ss'" + NORMALIZE_FUNCTIONS: bool | str = False + NULL_ORDERING = "nulls_are_last" + DATE_FORMAT = "'yyyy-MM-dd'" + DATEINT_FORMAT = "'yyyyMMdd'" + TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'" - time_mapping = { + TIME_MAPPING = { "y": "%Y", "Y": "%Y", "YYYY": "%Y", @@ -93,6 +82,7 @@ class Drill(Dialect): class Parser(parser.Parser): STRICT_CAST = False + CONCAT_NULL_OUTPUTS_STRING = True FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -135,8 +125,8 @@ class Drill(Dialect): exp.DateAdd: _date_add_sql("ADD"), exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("SUB"), - exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)", - exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})", + exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)", + exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})", exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})", exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), @@ -154,7 +144,7 @@ class Drill(Dialect): exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})", + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index f31da73..f0c1820 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -56,11 +56,7 @@ def _sort_array_reverse(args: t.List) -> exp.Expression: def _parse_date_diff(args: t.List) -> exp.Expression: - return exp.DateDiff( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=seq_get(args, 0), - ) + return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str: @@ -90,7 +86,7 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract class DuckDB(Dialect): - null_ordering = "nulls_are_last" + NULL_ORDERING = "nulls_are_last" class Tokenizer(tokens.Tokenizer): KEYWORDS = { @@ -118,6 +114,8 @@ class DuckDB(Dialect): } class Parser(parser.Parser): + CONCAT_NULL_OUTPUTS_STRING = True + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAY_LENGTH": exp.ArraySize.from_arg_list, @@ -127,10 +125,7 @@ class DuckDB(Dialect): "DATE_DIFF": _parse_date_diff, "EPOCH": exp.TimeToUnix.from_arg_list, "EPOCH_MS": lambda args: exp.UnixToTime( - this=exp.Div( - this=seq_get(args, 0), - expression=exp.Literal.number(1000), - ) + this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000)) ), "LIST_REVERSE_SORT": _sort_array_reverse, "LIST_SORT": exp.SortArray.from_arg_list, @@ -191,8 +186,8 @@ class DuckDB(Dialect): "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, - exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", - exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", + exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)", + exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)", exp.Explode: rename_func("UNNEST"), exp.IntDiv: lambda self, e: self.binary(e, "//"), exp.JSONExtract: arrow_json_extract_sql, @@ -242,11 +237,27 @@ class DuckDB(Dialect): STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"} + UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren) + PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def interval_sql(self, expression: exp.Interval) -> str: + multiplier: t.Optional[int] = None + unit = expression.text("unit").lower() + + if unit.startswith("week"): + multiplier = 7 + if unit.startswith("quarter"): + multiplier = 90 + + if multiplier: + return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})" + + return super().interval_sql(expression) + def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS " ) -> str: diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 650a1e1..8847119 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -80,12 +80,12 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})" + return f"{diff_sql}{multiplier_sql}" def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: this = expression.this - if not this.type: from sqlglot.optimizer.annotate_types import annotate_types @@ -113,7 +113,7 @@ def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> st def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) - if time_format not in (Hive.time_format, Hive.date_format): + if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" return f"CAST({this} AS DATE)" @@ -121,7 +121,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) - if time_format not in (Hive.time_format, Hive.date_format): + if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" return f"CAST({this} AS TIMESTAMP)" @@ -130,7 +130,7 @@ def _time_format( self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix ) -> t.Optional[str]: time_format = self.format_time(expression) - if time_format == Hive.time_format: + if time_format == Hive.TIME_FORMAT: return None return time_format @@ -144,16 +144,16 @@ def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str: def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) - if time_format and time_format not in (Hive.time_format, Hive.date_format): + if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): return f"TO_DATE({this}, {time_format})" return f"TO_DATE({this})" class Hive(Dialect): - alias_post_tablesample = True - identifiers_can_start_with_digit = True + ALIAS_POST_TABLESAMPLE = True + IDENTIFIERS_CAN_START_WITH_DIGIT = True - time_mapping = { + TIME_MAPPING = { "y": "%Y", "Y": "%Y", "YYYY": "%Y", @@ -184,9 +184,9 @@ class Hive(Dialect): "EEEE": "%A", } - date_format = "'yyyy-MM-dd'" - dateint_format = "'yyyyMMdd'" - time_format = "'yyyy-MM-dd HH:mm:ss'" + DATE_FORMAT = "'yyyy-MM-dd'" + DATEINT_FORMAT = "'yyyyMMdd'" + TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'" class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"'] @@ -224,9 +224,7 @@ class Hive(Dialect): "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( - this=seq_get(args, 0), - expression=seq_get(args, 1), - unit=exp.Literal.string("DAY"), + this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), "DATEDIFF": lambda args: exp.DateDiff( this=exp.TsOrDsToDate(this=seq_get(args, 0)), @@ -234,10 +232,7 @@ class Hive(Dialect): ), "DATE_SUB": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), - expression=exp.Mul( - this=seq_get(args, 1), - expression=exp.Literal.number(-1), - ), + expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)), unit=exp.Literal.string("DAY"), ), "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( @@ -349,8 +344,8 @@ class Hive(Dialect): exp.DateDiff: _date_diff_sql, exp.DateStrToDate: rename_func("TO_DATE"), exp.DateSub: _add_date_sql, - exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", - exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", + exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)", + exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})", exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", exp.FromBase64: rename_func("UNBASE64"), exp.If: if_sql, @@ -415,10 +410,7 @@ class Hive(Dialect): ) def with_properties(self, properties: exp.Properties) -> str: - return self.properties( - properties, - prefix=self.seg("TBLPROPERTIES"), - ) + return self.properties(properties, prefix=self.seg("TBLPROPERTIES")) def datatype_sql(self, expression: exp.DataType) -> str: if ( diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 75023ff..d2462e1 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -94,10 +94,10 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e class MySQL(Dialect): - time_format = "'%Y-%m-%d %T'" + TIME_FORMAT = "'%Y-%m-%d %T'" # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions - time_mapping = { + TIME_MAPPING = { "%M": "%B", "%c": "%-m", "%e": "%-d", @@ -128,6 +128,7 @@ class MySQL(Dialect): "MEDIUMBLOB": TokenType.MEDIUMBLOB, "MEDIUMTEXT": TokenType.MEDIUMTEXT, "SEPARATOR": TokenType.SEPARATOR, + "ENUM": TokenType.ENUM, "START": TokenType.BEGIN, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, @@ -279,6 +280,16 @@ class MySQL(Dialect): "SWAPS", } + TYPE_TOKENS = { + *parser.Parser.TYPE_TOKENS, + TokenType.SET, + } + + ENUM_TYPE_TOKENS = { + *parser.Parser.ENUM_TYPE_TOKENS, + TokenType.SET, + } + LOG_DEFAULTS_TO_LN = True def _parse_show_mysql( @@ -372,12 +383,7 @@ class MySQL(Dialect): else: collate = None - return self.expression( - exp.SetItem, - this=charset, - collate=collate, - kind="NAMES", - ) + return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES") class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True @@ -472,9 +478,7 @@ class MySQL(Dialect): def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str: sql = self.sql(expression, arg) - if not sql: - return "" - return f" {prefix} {sql}" + return f" {prefix} {sql}" if sql else "" def _oldstyle_limit_sql(self, expression: exp.Show) -> str: limit = self.sql(expression, "limit") diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 7722753..8d35e92 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -24,21 +24,15 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable: if self._match_text_seq("COLUMNS"): columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True))) - return self.expression( - exp.XMLTable, - this=this, - passing=passing, - columns=columns, - by_ref=by_ref, - ) + return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref) class Oracle(Dialect): - alias_post_tablesample = True + ALIAS_POST_TABLESAMPLE = True # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes - time_mapping = { + TIME_MAPPING = { "AM": "%p", # Meridian indicator with or without periods "A.M.": "%p", # Meridian indicator with or without periods "PM": "%p", # Meridian indicator with or without periods @@ -87,7 +81,7 @@ class Oracle(Dialect): column.set("join_mark", self._match(TokenType.JOIN_MARKER)) return column - def _parse_hint(self) -> t.Optional[exp.Expression]: + def _parse_hint(self) -> t.Optional[exp.Hint]: if self._match(TokenType.HINT): start = self._curr while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH): @@ -129,7 +123,7 @@ class Oracle(Dialect): exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, - exp.IfNull: rename_func("NVL"), + exp.Coalesce: rename_func("NVL"), exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), @@ -179,7 +173,6 @@ class Oracle(Dialect): "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, "NVARCHAR2": TokenType.NVARCHAR, - "RETURNING": TokenType.RETURNING, "SAMPLE": TokenType.TABLE_SAMPLE, "START": TokenType.BEGIN, "TOP": TokenType.TOP, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 8d84024..8c2a4ab 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -183,9 +183,10 @@ def _to_timestamp(args: t.List) -> exp.Expression: class Postgres(Dialect): - null_ordering = "nulls_are_large" - time_format = "'YYYY-MM-DD HH24:MI:SS'" - time_mapping = { + INDEX_OFFSET = 1 + NULL_ORDERING = "nulls_are_large" + TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" + TIME_MAPPING = { "AM": "%p", "PM": "%p", "D": "%u", # 1-based day of week @@ -241,7 +242,6 @@ class Postgres(Dialect): "REFRESH": TokenType.COMMAND, "REINDEX": TokenType.COMMAND, "RESET": TokenType.COMMAND, - "RETURNING": TokenType.RETURNING, "REVOKE": TokenType.COMMAND, "SERIAL": TokenType.SERIAL, "SMALLSERIAL": TokenType.SMALLSERIAL, @@ -258,6 +258,7 @@ class Postgres(Dialect): class Parser(parser.Parser): STRICT_CAST = False + CONCAT_NULL_OUTPUTS_STRING = True FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -268,6 +269,7 @@ class Postgres(Dialect): "NOW": exp.CurrentTimestamp.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), "TO_TIMESTAMP": _to_timestamp, + "UNNEST": exp.Explode.from_arg_list, } FUNCTION_PARSERS = { @@ -303,7 +305,7 @@ class Postgres(Dialect): value = self._parse_bitwise() if part and part.is_string: - part = exp.Var(this=part.name) + part = exp.var(part.name) return self.expression(exp.Extract, this=part, expression=value) @@ -328,6 +330,7 @@ class Postgres(Dialect): **generator.Generator.TRANSFORMS, exp.BitwiseXor: lambda self, e: self.binary(e, "#"), exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), + exp.Explode: rename_func("UNNEST"), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index d839864..a8a9884 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -102,7 +102,7 @@ def _str_to_time_sql( def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) - if time_format and time_format not in (Presto.time_format, Presto.date_format): + if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" @@ -119,7 +119,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s exp.Literal.number(1), exp.Literal.number(10), ), - Presto.date_format, + Presto.DATE_FORMAT, ) return self.func( @@ -145,9 +145,7 @@ def _approx_percentile(args: t.List) -> exp.Expression: ) if len(args) == 3: return exp.ApproxQuantile( - this=seq_get(args, 0), - quantile=seq_get(args, 1), - accuracy=seq_get(args, 2), + this=seq_get(args, 0), quantile=seq_get(args, 1), accuracy=seq_get(args, 2) ) return exp.ApproxQuantile.from_arg_list(args) @@ -160,10 +158,8 @@ def _from_unixtime(args: t.List) -> exp.Expression: minutes=seq_get(args, 2), ) if len(args) == 2: - return exp.UnixToTime( - this=seq_get(args, 0), - zone=seq_get(args, 1), - ) + return exp.UnixToTime(this=seq_get(args, 0), zone=seq_get(args, 1)) + return exp.UnixToTime.from_arg_list(args) @@ -173,21 +169,17 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression: unnest = exp.Unnest(expressions=[expression.this]) if expression.alias: - return exp.alias_( - unnest, - alias="_u", - table=[expression.alias], - copy=False, - ) + return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False) return unnest return expression class Presto(Dialect): - index_offset = 1 - null_ordering = "nulls_are_last" - time_format = MySQL.time_format - time_mapping = MySQL.time_mapping + INDEX_OFFSET = 1 + NULL_ORDERING = "nulls_are_last" + TIME_FORMAT = MySQL.TIME_FORMAT + TIME_MAPPING = MySQL.TIME_MAPPING + STRICT_STRING_CONCAT = True class Tokenizer(tokens.Tokenizer): KEYWORDS = { @@ -205,14 +197,10 @@ class Presto(Dialect): "CARDINALITY": exp.ArraySize.from_arg_list, "CONTAINS": exp.ArrayContains.from_arg_list, "DATE_ADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=seq_get(args, 0), + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), "DATE_DIFF": lambda args: exp.DateDiff( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=seq_get(args, 0), + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), @@ -225,9 +213,7 @@ class Presto(Dialect): "NOW": exp.CurrentTimestamp.from_arg_list, "SEQUENCE": exp.GenerateSeries.from_arg_list, "STRPOS": lambda args: exp.StrPosition( - this=seq_get(args, 0), - substr=seq_get(args, 1), - instance=seq_get(args, 2), + this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) ), "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, "TO_HEX": exp.Hex.from_arg_list, @@ -242,7 +228,7 @@ class Presto(Dialect): INTERVAL_ALLOWS_PLURAL_FORM = False JOIN_HINTS = False TABLE_HINTS = False - IS_BOOL = False + IS_BOOL_ALLOWED = False STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { @@ -284,10 +270,10 @@ class Presto(Dialect): exp.DateDiff: lambda self, e: self.func( "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this ), - exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", - exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", + exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)", + exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", exp.Decode: _decode_sql, - exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", + exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", exp.Encode: _encode_sql, exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.Group: transforms.preprocess([transforms.unalias_group]), @@ -322,7 +308,7 @@ class Presto(Dialect): exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: timestrtotime_sql, exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", + exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))", exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("TO_UNIXTIME"), exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]), @@ -367,8 +353,16 @@ class Presto(Dialect): to = target_type.copy() if target_type is start.to: - end = exp.Cast(this=end, to=to) + end = exp.cast(end, to) else: - start = exp.Cast(this=start, to=to) + start = exp.cast(start, to) return self.func("SEQUENCE", start, end, step) + + def offset_limit_modifiers( + self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit] + ) -> t.List[str]: + return [ + self.sql(expression, "offset"), + self.sql(limit), + ] diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index b0a6774..a7e25fa 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import exp, transforms -from sqlglot.dialects.dialect import rename_func +from sqlglot.dialects.dialect import concat_to_dpipe_sql, rename_func from sqlglot.dialects.postgres import Postgres from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -14,9 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx class Redshift(Postgres): - time_format = "'YYYY-MM-DD HH:MI:SS'" - time_mapping = { - **Postgres.time_mapping, + TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'" + TIME_MAPPING = { + **Postgres.TIME_MAPPING, "MON": "%b", "HH": "%H", } @@ -51,7 +51,7 @@ class Redshift(Postgres): and this.expressions and this.expressions[0].this == exp.column("MAX") ): - this.set("expressions", [exp.Var(this="MAX")]) + this.set("expressions", [exp.var("MAX")]) return this @@ -94,6 +94,7 @@ class Redshift(Postgres): TRANSFORMS = { **Postgres.Generator.TRANSFORMS, + exp.Concat: concat_to_dpipe_sql, exp.CurrentTimestamp: lambda self, e: "SYSDATE", exp.DateAdd: lambda self, e: self.func( "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this @@ -106,6 +107,7 @@ class Redshift(Postgres): exp.FromBase: rename_func("STRTOL"), exp.JSONExtract: _json_sql, exp.JSONExtractScalar: _json_sql, + exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.TsOrDsToDate: lambda self, e: self.sql(e.this), @@ -170,6 +172,6 @@ class Redshift(Postgres): precision = expression.args.get("expressions") if not precision: - expression.append("expressions", exp.Var(this="MAX")) + expression.append("expressions", exp.var("MAX")) return super().datatype_sql(expression) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 821d991..148b6d8 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -167,10 +167,10 @@ def _parse_convert_timezone(args: t.List) -> exp.Expression: class Snowflake(Dialect): - null_ordering = "nulls_are_large" - time_format = "'yyyy-mm-dd hh24:mi:ss'" + NULL_ORDERING = "nulls_are_large" + TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" - time_mapping = { + TIME_MAPPING = { "YYYY": "%Y", "yyyy": "%Y", "YY": "%y", @@ -210,14 +210,10 @@ class Snowflake(Dialect): "CONVERT_TIMEZONE": _parse_convert_timezone, "DATE_TRUNC": date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=seq_get(args, 0), + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), "DATEDIFF": lambda args: exp.DateDiff( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=seq_get(args, 0), + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, @@ -246,9 +242,7 @@ class Snowflake(Dialect): COLUMN_OPERATORS = { **parser.Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( - exp.Bracket, - this=this, - expressions=[path], + exp.Bracket, this=this, expressions=[path] ), } @@ -275,6 +269,7 @@ class Snowflake(Dialect): QUOTES = ["'", "$$"] STRING_ESCAPES = ["\\", "'"] HEX_STRINGS = [("x'", "'"), ("X'", "'")] + COMMENTS = ["--", "//", ("/*", "*/")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index bf24240..ed6992d 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -38,7 +38,7 @@ def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]: def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) - if time_format == Hive.date_format: + if time_format == Hive.DATE_FORMAT: return f"TO_DATE({this})" return f"TO_DATE({this}, {time_format})" @@ -133,13 +133,13 @@ class Spark2(Hive): "WEEKOFYEAR": lambda args: exp.WeekOfYear( this=exp.TsOrDsToDate(this=seq_get(args, 0)), ), - "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")), "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=exp.var(seq_get(args, 0)), ), "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), "BOOLEAN": _parse_as_cast("boolean"), + "DATE": _parse_as_cast("date"), "DOUBLE": _parse_as_cast("double"), "FLOAT": _parse_as_cast("float"), "INT": _parse_as_cast("int"), @@ -162,11 +162,9 @@ class Spark2(Hive): def _parse_add_column(self) -> t.Optional[exp.Expression]: return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema() - def _parse_drop_column(self) -> t.Optional[exp.Expression]: + def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]: return self._match_text_seq("DROP", "COLUMNS") and self.expression( - exp.Drop, - this=self._parse_schema(), - kind="COLUMNS", + exp.Drop, this=self._parse_schema(), kind="COLUMNS" ) def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 4e800b0..3b837ea 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + concat_to_dpipe_sql, count_if_to_sum, no_ilike_sql, no_pivot_sql, @@ -62,10 +63,6 @@ class SQLite(Dialect): IDENTIFIERS = ['"', ("[", "]"), "`"] HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")] - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - } - class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -100,6 +97,7 @@ class SQLite(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.Concat: concat_to_dpipe_sql, exp.CountIf: count_if_to_sum, exp.Create: transforms.preprocess([_transform_create]), exp.CurrentDate: lambda *_: "CURRENT_DATE", @@ -116,6 +114,7 @@ class SQLite(Dialect): exp.LogicalOr: rename_func("MAX"), exp.LogicalAnd: rename_func("MIN"), exp.Pivot: no_pivot_sql, + exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_qualify] ), diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index d5fba17..67ef76b 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, generator, parser, transforms -from sqlglot.dialects.dialect import Dialect +from sqlglot.dialects.dialect import Dialect, rename_func class Tableau(Dialect): @@ -11,6 +11,7 @@ class Tableau(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.Coalesce: rename_func("IFNULL"), exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), } @@ -25,9 +26,6 @@ class Tableau(Dialect): false = self.sql(expression, "false") return f"IF {this} THEN {true} ELSE {false} END" - def coalesce_sql(self, expression: exp.Coalesce) -> str: - return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})" - def count_sql(self, expression: exp.Count) -> str: this = expression.this if isinstance(this, exp.Distinct): diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 514aecb..d5e5dd8 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,18 +1,32 @@ from __future__ import annotations -import typing as t - from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - Dialect, - format_time_lambda, - max_or_greatest, - min_or_least, -) +from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least from sqlglot.tokens import TokenType class Teradata(Dialect): + TIME_MAPPING = { + "Y": "%Y", + "YYYY": "%Y", + "YY": "%y", + "MMMM": "%B", + "MMM": "%b", + "DD": "%d", + "D": "%-d", + "HH": "%H", + "H": "%-H", + "MM": "%M", + "M": "%-M", + "SS": "%S", + "S": "%-S", + "SSSSSS": "%f", + "E": "%a", + "EE": "%a", + "EEE": "%a", + "EEEE": "%A", + } + class Tokenizer(tokens.Tokenizer): # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance KEYWORDS = { @@ -31,7 +45,7 @@ class Teradata(Dialect): "ST_GEOMETRY": TokenType.GEOMETRY, } - # teradata does not support % for modulus + # Teradata does not support % as a modulo operator SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS} SINGLE_TOKENS.pop("%") @@ -101,7 +115,7 @@ class Teradata(Dialect): # FROM before SET in Teradata UPDATE syntax # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause - def _parse_update(self) -> exp.Expression: + def _parse_update(self) -> exp.Update: return self.expression( exp.Update, **{ # type: ignore @@ -122,14 +136,6 @@ class Teradata(Dialect): return self.expression(exp.RangeN, this=this, expressions=expressions, each=each) - def _parse_cast(self, strict: bool) -> exp.Expression: - cast = t.cast(exp.Cast, super()._parse_cast(strict)) - if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT): - return format_time_lambda(exp.TimeToStr, "teradata")( - [cast.this, self._parse_string()] - ) - return cast - class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False @@ -151,7 +157,7 @@ class Teradata(Dialect): exp.Max: max_or_greatest, exp.Min: min_or_least, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), - exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", + exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index f6ad888..6d674f5 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -64,9 +64,9 @@ def _format_time_lambda( format=exp.Literal.string( format_time( args[0].name, - {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} + {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping - else TSQL.time_mapping, + else TSQL.TIME_MAPPING, ) ), ) @@ -86,9 +86,9 @@ def _parse_format(args: t.List) -> exp.Expression: return exp.TimeToStr( this=args[0], format=exp.Literal.string( - format_time(fmt.name, TSQL.format_time_mapping) + format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING) if len(fmt.name) == 1 - else format_time(fmt.name, TSQL.time_mapping) + else format_time(fmt.name, TSQL.TIME_MAPPING) ), ) @@ -138,7 +138,7 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim if isinstance(expression, exp.NumberToStr) else exp.Literal.string( format_time( - expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping) + expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING) ) ) ) @@ -166,10 +166,10 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s class TSQL(Dialect): - null_ordering = "nulls_are_small" - time_format = "'yyyy-mm-dd hh:mm:ss'" + NULL_ORDERING = "nulls_are_small" + TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" - time_mapping = { + TIME_MAPPING = { "year": "%Y", "qq": "%q", "q": "%q", @@ -213,7 +213,7 @@ class TSQL(Dialect): "yy": "%y", } - convert_format_mapping = { + CONVERT_FORMAT_MAPPING = { "0": "%b %d %Y %-I:%M%p", "1": "%m/%d/%y", "2": "%y.%m.%d", @@ -253,8 +253,8 @@ class TSQL(Dialect): "120": "%Y-%m-%d %H:%M:%S", "121": "%Y-%m-%d %H:%M:%S.%f", } - # not sure if complete - format_time_mapping = { + + FORMAT_TIME_MAPPING = { "y": "%B %Y", "d": "%m/%d/%Y", "H": "%-H", @@ -312,9 +312,7 @@ class TSQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "CHARINDEX": lambda args: exp.StrPosition( - this=seq_get(args, 1), - substr=seq_get(args, 0), - position=seq_get(args, 2), + this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) ), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), @@ -363,6 +361,8 @@ class TSQL(Dialect): LOG_BASE_FIRST = False LOG_DEFAULTS_TO_LN = True + CONCAT_NULL_OUTPUTS_STRING = True + def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"): return None @@ -400,7 +400,7 @@ class TSQL(Dialect): table.set("system_time", self._parse_system_time()) return table - def _parse_returns(self) -> exp.Expression: + def _parse_returns(self) -> exp.ReturnsProperty: table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS) returns = super()._parse_returns() returns.set("table", table) @@ -423,12 +423,12 @@ class TSQL(Dialect): format_val = self._parse_number() format_val_name = format_val.name if format_val else "" - if format_val_name not in TSQL.convert_format_mapping: + if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING: raise ValueError( f"CONVERT function at T-SQL does not support format style {format_val_name}" ) - format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name]) + format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name]) # Check whether the convert entails a string to date format if to.this == DataType.Type.DATE: diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 51cffbd..d2c4e72 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -151,6 +151,7 @@ ENV = { "CAST": cast, "COALESCE": lambda *args: next((a for a in args if a is not None), None), "CONCAT": null_if_any(lambda *args: "".join(args)), + "SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)), "CONCATWS": null_if_any(lambda this, *args: this.join(args)), "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)), "DIV": null_if_any(lambda e, this: e / this), @@ -159,7 +160,6 @@ ENV = { "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), "GT": null_if_any(lambda this, e: this > e), "GTE": null_if_any(lambda this, e: this >= e), - "IFNULL": lambda e, alt: alt if e is None else e, "IF": lambda predicate, true, false: true if predicate else false, "INTDIV": null_if_any(lambda e, this: e // this), "INTERVAL": interval, diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index f114e5c..3f96f90 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -394,7 +394,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str: names = {e.name.lower() for e in e.expressions} e = e.transform( - lambda n: exp.Var(this=n.name) + lambda n: exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n ) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index da4a4ed..c7d4664 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1500,6 +1500,7 @@ class Index(Expression): arg_types = { "this": False, "table": False, + "using": False, "where": False, "columns": False, "unique": False, @@ -1623,7 +1624,7 @@ class Lambda(Expression): class Limit(Expression): - arg_types = {"this": False, "expression": True} + arg_types = {"this": False, "expression": True, "offset": False} class Literal(Condition): @@ -1869,6 +1870,10 @@ class EngineProperty(Property): arg_types = {"this": True} +class ToTableProperty(Property): + arg_types = {"this": True} + + class ExecuteAsProperty(Property): arg_types = {"this": True} @@ -3072,12 +3077,35 @@ class Select(Subqueryable): Returns: The modified expression. """ - inst = _maybe_copy(self, copy) inst.set("locks", [Lock(update=update)]) return inst + def hint(self, *hints: ExpOrStr, dialect: DialectType = None, copy: bool = True) -> Select: + """ + Set hints for this expression. + + Examples: + >>> Select().select("x").from_("tbl").hint("BROADCAST(y)").sql(dialect="spark") + 'SELECT /*+ BROADCAST(y) */ x FROM tbl' + + Args: + hints: The SQL code strings to parse as the hints. + If an `Expression` instance is passed, it will be used as-is. + dialect: The dialect used to parse the hints. + copy: If `False`, modify this expression instance in-place. + + Returns: + The modified expression. + """ + inst = _maybe_copy(self, copy) + inst.set( + "hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints]) + ) + + return inst + @property def named_selects(self) -> t.List[str]: return [e.output_name for e in self.expressions if e.alias_or_name] @@ -3244,6 +3272,7 @@ class DataType(Expression): DATE = auto() DATETIME = auto() DATETIME64 = auto() + ENUM = auto() INT4RANGE = auto() INT4MULTIRANGE = auto() INT8RANGE = auto() @@ -3284,6 +3313,7 @@ class DataType(Expression): OBJECT = auto() ROWVERSION = auto() SERIAL = auto() + SET = auto() SMALLINT = auto() SMALLMONEY = auto() SMALLSERIAL = auto() @@ -3334,6 +3364,7 @@ class DataType(Expression): NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES} TEMPORAL_TYPES = { + Type.TIME, Type.TIMESTAMP, Type.TIMESTAMPTZ, Type.TIMESTAMPLTZ, @@ -3342,6 +3373,8 @@ class DataType(Expression): Type.DATETIME64, } + META_TYPES = {"UNKNOWN", "NULL"} + @classmethod def build( cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs @@ -3349,8 +3382,9 @@ class DataType(Expression): from sqlglot import parse_one if isinstance(dtype, str): - if dtype.upper() in cls.Type.__members__: - data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()]) + upper = dtype.upper() + if upper in DataType.META_TYPES: + data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper]) else: data_type_exp = parse_one(dtype, read=dialect, into=DataType) @@ -3483,6 +3517,10 @@ class Dot(Binary): def name(self) -> str: return self.expression.name + @property + def output_name(self) -> str: + return self.name + @classmethod def build(self, expressions: t.Sequence[Expression]) -> Dot: """Build a Dot object with a sequence of expressions.""" @@ -3502,6 +3540,10 @@ class DPipe(Binary): pass +class SafeDPipe(DPipe): + pass + + class EQ(Binary, Predicate): pass @@ -3615,6 +3657,10 @@ class Not(Unary): class Paren(Unary): arg_types = {"this": True, "with": False} + @property + def output_name(self) -> str: + return self.this.name + class Neg(Unary): pass @@ -3904,6 +3950,7 @@ class Ceil(Func): class Coalesce(Func): arg_types = {"this": True, "expressions": False} is_var_len_args = True + _sql_names = ["COALESCE", "IFNULL", "NVL"] class Concat(Func): @@ -3911,12 +3958,17 @@ class Concat(Func): is_var_len_args = True +class SafeConcat(Concat): + pass + + class ConcatWs(Concat): _sql_names = ["CONCAT_WS"] class Count(AggFunc): - arg_types = {"this": False} + arg_types = {"this": False, "expressions": False} + is_var_len_args = True class CountIf(AggFunc): @@ -4049,6 +4101,11 @@ class DateToDi(Func): pass +class Date(Func): + arg_types = {"expressions": True} + is_var_len_args = True + + class Day(Func): pass @@ -4102,11 +4159,6 @@ class If(Func): arg_types = {"this": True, "true": True, "false": False} -class IfNull(Func): - arg_types = {"this": True, "expression": False} - _sql_names = ["IFNULL", "NVL"] - - class Initcap(Func): arg_types = {"this": True, "expression": False} @@ -5608,22 +5660,27 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) - expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0) -def column_table_names(expression: Expression) -> t.List[str]: +def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]: """ Return all table names referenced through columns in an expression. Example: >>> import sqlglot - >>> column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e")) - ['c', 'a'] + >>> sorted(column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"))) + ['a', 'c'] Args: expression: expression to find table names. + exclude: a table name to exclude Returns: A list of unique names. """ - return list(dict.fromkeys(column.table for column in expression.find_all(Column))) + return { + table + for table in (column.table for column in expression.find_all(Column)) + if table and table != exclude + } def table_name(table: Table | str) -> str: @@ -5649,12 +5706,13 @@ def table_name(table: Table | str) -> str: return ".".join(part for part in (table.text("catalog"), table.text("db"), table.name) if part) -def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E: +def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E: """Replace all tables in expression according to the mapping. Args: expression: expression node to be transformed and replaced. mapping: mapping of table names. + copy: whether or not to copy the expression. Examples: >>> from sqlglot import exp, parse_one @@ -5675,7 +5733,7 @@ def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E: ) return node - return expression.transform(_replace_tables) + return expression.transform(_replace_tables, copy=copy) def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 97cbe15..d3cf9f0 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -14,47 +14,32 @@ logger = logging.getLogger("sqlglot") class Generator: """ - Generator interprets the given syntax tree and produces a SQL string as an output. + Generator converts a given syntax tree to the corresponding SQL string. Args: - time_mapping (dict): the dictionary of custom time mappings in which the key - represents a python time format and the output the target time format - time_trie (trie): a trie of the time_mapping keys - pretty (bool): if set to True the returned string will be formatted. Default: False. - quote_start (str): specifies which starting character to use to delimit quotes. Default: '. - quote_end (str): specifies which ending character to use to delimit quotes. Default: '. - identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ". - identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ". - bit_start (str): specifies which starting character to use to delimit bit literals. Default: None. - bit_end (str): specifies which ending character to use to delimit bit literals. Default: None. - hex_start (str): specifies which starting character to use to delimit hex literals. Default: None. - hex_end (str): specifies which ending character to use to delimit hex literals. Default: None. - byte_start (str): specifies which starting character to use to delimit byte literals. Default: None. - byte_end (str): specifies which ending character to use to delimit byte literals. Default: None. - raw_start (str): specifies which starting character to use to delimit raw literals. Default: None. - raw_end (str): specifies which ending character to use to delimit raw literals. Default: None. - identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always. - normalize (bool): if set to True all identifiers will lower cased - string_escape (str): specifies a string escape character. Default: '. - identifier_escape (str): specifies an identifier escape character. Default: ". - pad (int): determines padding in a formatted string. Default: 2. - indent (int): determines the size of indentation in a formatted string. Default: 4. - unnest_column_only (bool): if true unnest table aliases are considered only as column aliases - normalize_functions (str): normalize function names, "upper", "lower", or None - Default: "upper" - alias_post_tablesample (bool): if the table alias comes after tablesample - Default: False - identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit - Default: False - unsupported_level (ErrorLevel): determines the generator's behavior when it encounters - unsupported expressions. Default ErrorLevel.WARN. - null_ordering (str): Indicates the default null ordering method to use if not explicitly set. - Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". - Default: "nulls_are_small" - max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError. + pretty: Whether or not to format the produced SQL string. + Default: False. + identify: Determines when an identifier should be quoted. Possible values are: + False (default): Never quote, except in cases where it's mandatory by the dialect. + True or 'always': Always quote. + 'safe': Only quote identifiers that are case insensitive. + normalize: Whether or not to normalize identifiers to lowercase. + Default: False. + pad: Determines the pad size in a formatted string. + Default: 2. + indent: Determines the indentation size in a formatted string. + Default: 2. + normalize_functions: Whether or not to normalize all function names. Possible values are: + "upper" or True (default): Convert names to uppercase. + "lower": Convert names to lowercase. + False: Disables function name normalization. + unsupported_level: Determines the generator's behavior when it encounters unsupported expressions. + Default ErrorLevel.WARN. + max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError. This is only relevant if unsupported_level is ErrorLevel.RAISE. Default: 3 - leading_comma (bool): if the the comma is leading or trailing in select statements + leading_comma: Determines whether or not the comma is leading or trailing in select expressions. + This is only relevant when generating in pretty mode. Default: False max_text_width: The max number of characters in a segment before creating new lines in pretty mode. The default is on the smaller end because the length only represents a segment and not the true @@ -86,6 +71,7 @@ class Generator: exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", exp.TemporaryProperty: lambda self, e: f"TEMPORARY", + exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}", exp.TransientProperty: lambda self, e: "TRANSIENT", exp.StabilityProperty: lambda self, e: e.name, exp.VolatileProperty: lambda self, e: "VOLATILE", @@ -138,15 +124,24 @@ class Generator: # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" - # Whether a table is allowed to be renamed with a db + # Whether or not a table is allowed to be renamed with a db RENAME_TABLE_WITH_DB = True # The separator for grouping sets and rollups GROUPINGS_SEP = "," - # The string used for creating index on a table + # The string used for creating an index on a table INDEX_ON = "ON" + # Whether or not join hints should be generated + JOIN_HINTS = True + + # Whether or not table hints should be generated + TABLE_HINTS = True + + # Whether or not comparing against booleans (e.g. x IS TRUE) is supported + IS_BOOL_ALLOWED = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -228,6 +223,7 @@ class Generator: exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA, exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, + exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA, exp.TransientProperty: exp.Properties.Location.POST_CREATE, exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.POST_CREATE, @@ -235,128 +231,110 @@ class Generator: exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, } - JOIN_HINTS = True - TABLE_HINTS = True - IS_BOOL = True - + # Keywords that can't be used as unquoted identifier names RESERVED_KEYWORDS: t.Set[str] = set() - WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With) - UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren) + + # Expressions whose comments are separated from them for better formatting + WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Select, + exp.From, + exp.Where, + exp.With, + ) + + # Expressions that can remain unwrapped when appearing in the context of an INTERVAL + UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Column, + exp.Literal, + exp.Neg, + exp.Paren, + ) SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" + # Autofilled + INVERSE_TIME_MAPPING: t.Dict[str, str] = {} + INVERSE_TIME_TRIE: t.Dict = {} + INDEX_OFFSET = 0 + UNNEST_COLUMN_ONLY = False + ALIAS_POST_TABLESAMPLE = False + IDENTIFIERS_CAN_START_WITH_DIGIT = False + STRICT_STRING_CONCAT = False + NORMALIZE_FUNCTIONS: bool | str = "upper" + NULL_ORDERING = "nulls_are_small" + + # Delimiters for quotes, identifiers and the corresponding escape characters + QUOTE_START = "'" + QUOTE_END = "'" + IDENTIFIER_START = '"' + IDENTIFIER_END = '"' + STRING_ESCAPE = "'" + IDENTIFIER_ESCAPE = '"' + + # Delimiters for bit, hex, byte and raw literals + BIT_START: t.Optional[str] = None + BIT_END: t.Optional[str] = None + HEX_START: t.Optional[str] = None + HEX_END: t.Optional[str] = None + BYTE_START: t.Optional[str] = None + BYTE_END: t.Optional[str] = None + RAW_START: t.Optional[str] = None + RAW_END: t.Optional[str] = None + __slots__ = ( - "time_mapping", - "time_trie", "pretty", - "quote_start", - "quote_end", - "identifier_start", - "identifier_end", - "bit_start", - "bit_end", - "hex_start", - "hex_end", - "byte_start", - "byte_end", - "raw_start", - "raw_end", "identify", "normalize", - "string_escape", - "identifier_escape", "pad", - "index_offset", - "unnest_column_only", - "alias_post_tablesample", - "identifiers_can_start_with_digit", + "_indent", "normalize_functions", "unsupported_level", - "unsupported_messages", - "null_ordering", "max_unsupported", - "_indent", + "leading_comma", + "max_text_width", + "comments", + "unsupported_messages", "_escaped_quote_end", "_escaped_identifier_end", - "_leading_comma", - "_max_text_width", - "_comments", "_cache", ) def __init__( self, - time_mapping=None, - time_trie=None, - pretty=None, - quote_start=None, - quote_end=None, - identifier_start=None, - identifier_end=None, - bit_start=None, - bit_end=None, - hex_start=None, - hex_end=None, - byte_start=None, - byte_end=None, - raw_start=None, - raw_end=None, - identify=False, - normalize=False, - string_escape=None, - identifier_escape=None, - pad=2, - indent=2, - index_offset=0, - unnest_column_only=False, - alias_post_tablesample=False, - identifiers_can_start_with_digit=False, - normalize_functions="upper", - unsupported_level=ErrorLevel.WARN, - null_ordering=None, - max_unsupported=3, - leading_comma=False, - max_text_width=80, - comments=True, + pretty: t.Optional[bool] = None, + identify: str | bool = False, + normalize: bool = False, + pad: int = 2, + indent: int = 2, + normalize_functions: t.Optional[str | bool] = None, + unsupported_level: ErrorLevel = ErrorLevel.WARN, + max_unsupported: int = 3, + leading_comma: bool = False, + max_text_width: int = 80, + comments: bool = True, ): import sqlglot - self.time_mapping = time_mapping or {} - self.time_trie = time_trie self.pretty = pretty if pretty is not None else sqlglot.pretty - self.quote_start = quote_start or "'" - self.quote_end = quote_end or "'" - self.identifier_start = identifier_start or '"' - self.identifier_end = identifier_end or '"' - self.bit_start = bit_start - self.bit_end = bit_end - self.hex_start = hex_start - self.hex_end = hex_end - self.byte_start = byte_start - self.byte_end = byte_end - self.raw_start = raw_start - self.raw_end = raw_end self.identify = identify self.normalize = normalize - self.string_escape = string_escape or "'" - self.identifier_escape = identifier_escape or '"' self.pad = pad - self.index_offset = index_offset - self.unnest_column_only = unnest_column_only - self.alias_post_tablesample = alias_post_tablesample - self.identifiers_can_start_with_digit = identifiers_can_start_with_digit - self.normalize_functions = normalize_functions + self._indent = indent self.unsupported_level = unsupported_level - self.unsupported_messages = [] self.max_unsupported = max_unsupported - self.null_ordering = null_ordering - self._indent = indent - self._escaped_quote_end = self.string_escape + self.quote_end - self._escaped_identifier_end = self.identifier_escape + self.identifier_end - self._leading_comma = leading_comma - self._max_text_width = max_text_width - self._comments = comments - self._cache = None + self.leading_comma = leading_comma + self.max_text_width = max_text_width + self.comments = comments + + # This is both a Dialect property and a Generator argument, so we prioritize the latter + self.normalize_functions = ( + self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions + ) + + self.unsupported_messages: t.List[str] = [] + self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END + self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END + self._cache: t.Optional[t.Dict[int, str]] = None def generate( self, @@ -364,17 +342,19 @@ class Generator: cache: t.Optional[t.Dict[int, str]] = None, ) -> str: """ - Generates a SQL string by interpreting the given syntax tree. + Generates the SQL string corresponding to the given syntax tree. - Args - expression: the syntax tree. - cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node. + Args: + expression: The syntax tree. + cache: An optional sql string cache. This leverages the hash of an Expression + which can be slow to compute, so only use it if you set _hash on each node. - Returns - the SQL string. + Returns: + The SQL string corresponding to `expression`. """ if cache is not None: self._cache = cache + self.unsupported_messages = [] sql = self.sql(expression).strip() self._cache = None @@ -414,7 +394,11 @@ class Generator: expression: t.Optional[exp.Expression] = None, comments: t.Optional[t.List[str]] = None, ) -> str: - comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None # type: ignore + comments = ( + ((expression and expression.comments) if comments is None else comments) # type: ignore + if self.comments + else None + ) if not comments or isinstance(expression, exp.Binary): return sql @@ -454,7 +438,7 @@ class Generator: return result def normalize_func(self, name: str) -> str: - if self.normalize_functions == "upper": + if self.normalize_functions == "upper" or self.normalize_functions is True: return name.upper() if self.normalize_functions == "lower": return name.lower() @@ -522,7 +506,7 @@ class Generator: else: raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") - sql = self.maybe_comment(sql, expression) if self._comments and comment else sql + sql = self.maybe_comment(sql, expression) if self.comments and comment else sql if self._cache is not None: self._cache[expression_id] = sql @@ -770,25 +754,25 @@ class Generator: def bitstring_sql(self, expression: exp.BitString) -> str: this = self.sql(expression, "this") - if self.bit_start: - return f"{self.bit_start}{this}{self.bit_end}" + if self.BIT_START: + return f"{self.BIT_START}{this}{self.BIT_END}" return f"{int(this, 2)}" def hexstring_sql(self, expression: exp.HexString) -> str: this = self.sql(expression, "this") - if self.hex_start: - return f"{self.hex_start}{this}{self.hex_end}" + if self.HEX_START: + return f"{self.HEX_START}{this}{self.HEX_END}" return f"{int(this, 16)}" def bytestring_sql(self, expression: exp.ByteString) -> str: this = self.sql(expression, "this") - if self.byte_start: - return f"{self.byte_start}{this}{self.byte_end}" + if self.BYTE_START: + return f"{self.BYTE_START}{this}{self.BYTE_END}" return this def rawstring_sql(self, expression: exp.RawString) -> str: - if self.raw_start: - return f"{self.raw_start}{expression.name}{self.raw_end}" + if self.RAW_START: + return f"{self.RAW_START}{expression.name}{self.RAW_END}" return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\"))) def datatypesize_sql(self, expression: exp.DataTypeSize) -> str: @@ -883,24 +867,27 @@ class Generator: name = f"{expression.name} " if expression.name else "" table = self.sql(expression, "table") table = f"{self.INDEX_ON} {table} " if table else "" + using = self.sql(expression, "using") + using = f"USING {using} " if using else "" index = "INDEX " if not table else "" columns = self.expressions(expression, key="columns", flat=True) + columns = f"({columns})" if columns else "" partition_by = self.expressions(expression, key="partition_by", flat=True) partition_by = f" PARTITION BY {partition_by}" if partition_by else "" - return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}" + return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}" def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name lower = text.lower() text = lower if self.normalize and not expression.quoted else text - text = text.replace(self.identifier_end, self._escaped_identifier_end) + text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end) if ( expression.quoted or should_identify(text, self.identify) or lower in self.RESERVED_KEYWORDS - or (not self.identifiers_can_start_with_digit and text[:1].isdigit()) + or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit()) ): - text = f"{self.identifier_start}{text}{self.identifier_end}" + text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}" return text def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str: @@ -1197,7 +1184,7 @@ class Generator: def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " ) -> str: - if self.alias_post_tablesample and expression.this.alias: + if self.ALIAS_POST_TABLESAMPLE and expression.this.alias: table = expression.this.copy() table.set("alias", None) this = self.sql(table) @@ -1372,7 +1359,15 @@ class Generator: def limit_sql(self, expression: exp.Limit) -> str: this = self.sql(expression, "this") - return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}" + args = ", ".join( + sql + for sql in ( + self.sql(expression, "offset"), + self.sql(expression, "expression"), + ) + if sql + ) + return f"{this}{self.seg('LIMIT')} {args}" def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") @@ -1418,10 +1413,10 @@ class Generator: def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: - text = text.replace(self.quote_end, self._escaped_quote_end) + text = text.replace(self.QUOTE_END, self._escaped_quote_end) if self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) - text = f"{self.quote_start}{text}{self.quote_end}" + text = f"{self.QUOTE_START}{text}{self.QUOTE_END}" return text def loaddata_sql(self, expression: exp.LoadData) -> str: @@ -1463,9 +1458,9 @@ class Generator: nulls_first = expression.args.get("nulls_first") nulls_last = not nulls_first - nulls_are_large = self.null_ordering == "nulls_are_large" - nulls_are_small = self.null_ordering == "nulls_are_small" - nulls_are_last = self.null_ordering == "nulls_are_last" + nulls_are_large = self.NULL_ORDERING == "nulls_are_large" + nulls_are_small = self.NULL_ORDERING == "nulls_are_small" + nulls_are_last = self.NULL_ORDERING == "nulls_are_last" sort_order = " DESC" if desc else "" nulls_sort_change = "" @@ -1521,7 +1516,7 @@ class Generator: return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}" def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: - limit = expression.args.get("limit") + limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit") if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch): limit = exp.Limit(expression=limit.args.get("count")) @@ -1540,12 +1535,19 @@ class Generator: self.sql(expression, "having"), *self.after_having_modifiers(expression), self.sql(expression, "order"), - self.sql(expression, "offset") if fetch else self.sql(limit), - self.sql(limit) if fetch else self.sql(expression, "offset"), + *self.offset_limit_modifiers(expression, fetch, limit), *self.after_limit_modifiers(expression), sep="", ) + def offset_limit_modifiers( + self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit] + ) -> t.List[str]: + return [ + self.sql(expression, "offset") if fetch else self.sql(limit), + self.sql(limit) if fetch else self.sql(expression, "offset"), + ] + def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]: return [ self.sql(expression, "qualify"), @@ -1634,7 +1636,7 @@ class Generator: def unnest_sql(self, expression: exp.Unnest) -> str: args = self.expressions(expression, flat=True) alias = expression.args.get("alias") - if alias and self.unnest_column_only: + if alias and self.UNNEST_COLUMN_ONLY: columns = alias.columns alias = self.sql(columns[0]) if columns else "" else: @@ -1697,7 +1699,7 @@ class Generator: return f"{this} BETWEEN {low} AND {high}" def bracket_sql(self, expression: exp.Bracket) -> str: - expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset) + expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET) expressions_sql = ", ".join(self.sql(e) for e in expressions) return f"{self.sql(expression, 'this')}[{expressions_sql}]" @@ -1729,7 +1731,7 @@ class Generator: statements.append("END") - if self.pretty and self.text_width(statements) > self._max_text_width: + if self.pretty and self.text_width(statements) > self.max_text_width: return self.indent("\n".join(statements), skip_first=True, skip_last=True) return " ".join(statements) @@ -1759,10 +1761,11 @@ class Generator: else: return self.func("TRIM", expression.this, expression.expression) - def concat_sql(self, expression: exp.Concat) -> str: - if len(expression.expressions) == 1: - return self.sql(expression.expressions[0]) - return self.function_fallback_sql(expression) + def safeconcat_sql(self, expression: exp.SafeConcat) -> str: + expressions = expression.expressions + if self.STRICT_STRING_CONCAT: + expressions = (exp.cast(e, "text") for e in expressions) + return self.func("CONCAT", *expressions) def check_sql(self, expression: exp.Check) -> str: this = self.sql(expression, key="this") @@ -1785,9 +1788,7 @@ class Generator: return f"PRIMARY KEY ({expressions}){options}" def if_sql(self, expression: exp.If) -> str: - return self.case_sql( - exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) - ) + return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: modifier = expression.args.get("modifier") @@ -1798,7 +1799,6 @@ class Generator: return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}" def jsonobject_sql(self, expression: exp.JSONObject) -> str: - expressions = self.expressions(expression) null_handling = expression.args.get("null_handling") null_handling = f" {null_handling}" if null_handling else "" unique_keys = expression.args.get("unique_keys") @@ -1811,7 +1811,11 @@ class Generator: format_json = " FORMAT JSON" if expression.args.get("format_json") else "" encoding = self.sql(expression, "encoding") encoding = f" ENCODING {encoding}" if encoding else "" - return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})" + return self.func( + "JSON_OBJECT", + *expression.expressions, + suffix=f"{null_handling}{unique_keys}{return_type}{format_json}{encoding})", + ) def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str: this = self.sql(expression, "this") @@ -1930,7 +1934,7 @@ class Generator: for i, e in enumerate(expression.flatten(unnest=False)) ) - sep = "\n" if self.text_width(sqls) > self._max_text_width else " " + sep = "\n" if self.text_width(sqls) > self.max_text_width else " " return f"{sep}{op} ".join(sqls) def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str: @@ -2093,6 +2097,11 @@ class Generator: def dpipe_sql(self, expression: exp.DPipe) -> str: return self.binary(expression, "||") + def safedpipe_sql(self, expression: exp.SafeDPipe) -> str: + if self.STRICT_STRING_CONCAT: + return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten())) + return self.dpipe_sql(expression) + def div_sql(self, expression: exp.Div) -> str: return self.binary(expression, "/") @@ -2127,7 +2136,7 @@ class Generator: return self.binary(expression, "ILIKE ANY") def is_sql(self, expression: exp.Is) -> str: - if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean): + if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean): return self.sql( expression.this if expression.expression.this else exp.not_(expression.this) ) @@ -2197,12 +2206,18 @@ class Generator: return self.func(expression.sql_name(), *args) - def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str: - return f"{self.normalize_func(name)}({self.format_args(*args)})" + def func( + self, + name: str, + *args: t.Optional[exp.Expression | str], + prefix: str = "(", + suffix: str = ")", + ) -> str: + return f"{self.normalize_func(name)}{prefix}{self.format_args(*args)}{suffix}" def format_args(self, *args: t.Optional[str | exp.Expression]) -> str: arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None) - if self.pretty and self.text_width(arg_sqls) > self._max_text_width: + if self.pretty and self.text_width(arg_sqls) > self.max_text_width: return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True) return ", ".join(arg_sqls) @@ -2210,7 +2225,9 @@ class Generator: return sum(len(arg) for arg in args) def format_time(self, expression: exp.Expression) -> t.Optional[str]: - return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) + return format_time( + self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE + ) def expressions( self, @@ -2242,7 +2259,7 @@ class Generator: comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else "" if self.pretty: - if self._leading_comma: + if self.leading_comma: result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}") else: result_sqls.append( diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 4215fee..2f48ab5 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -208,7 +208,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> return expression -def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]: +def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]: """ Sorts a given directed acyclic graph in topological order. @@ -220,22 +220,24 @@ def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]: """ result = [] - def visit(node: T, visited: t.Set[T]) -> None: - if node in result: - return - if node in visited: - raise ValueError("Cycle error") + for node, deps in tuple(dag.items()): + for dep in deps: + if not dep in dag: + dag[dep] = set() + + while dag: + current = {node for node, deps in dag.items() if not deps} - visited.add(node) + if not current: + raise ValueError("Cycle error") - for dep in dag.get(node, []): - visit(dep, visited) + for node in current: + dag.pop(node) - visited.remove(node) - result.append(node) + for deps in dag.values(): + deps -= current - for node in dag: - visit(node, set()) + result.extend(sorted(current)) # type: ignore return result diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 6238759..39e2c53 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,13 +1,25 @@ +from __future__ import annotations + +import typing as t + from sqlglot import exp +from sqlglot._typing import E from sqlglot.helper import ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.schema import ensure_schema +from sqlglot.schema import Schema, ensure_schema + +if t.TYPE_CHECKING: + B = t.TypeVar("B", bound=exp.Binary) -def annotate_types(expression, schema=None, annotators=None, coerces_to=None): +def annotate_types( + expression: E, + schema: t.Optional[t.Dict | Schema] = None, + annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, + coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, +) -> E: """ - Recursively infer & annotate types in an expression syntax tree against a schema. - Assumes that we've already executed the optimizer's qualify_columns step. + Infers the types of an expression, annotating its AST accordingly. Example: >>> import sqlglot @@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): <Type.DOUBLE: 'DOUBLE'> Args: - expression (sqlglot.Expression): Expression to annotate. - schema (dict|sqlglot.optimizer.Schema): Database schema. - annotators (dict): Maps expression type to corresponding annotation function. - coerces_to (dict): Maps expression type to set of types that it can be coerced into. + expression: Expression to annotate. + schema: Database schema. + annotators: Maps expression type to corresponding annotation function. + coerces_to: Maps expression type to set of types that it can be coerced into. + Returns: - sqlglot.Expression: expression annotated with types + The expression annotated with types. """ schema = ensure_schema(schema) @@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) -class TypeAnnotator: - ANNOTATORS = { - **{ - expr_type: lambda self, expr: self._annotate_unary(expr) - for expr_type in subclasses(exp.__name__, exp.Unary) - }, - **{ - expr_type: lambda self, expr: self._annotate_binary(expr) - for expr_type in subclasses(exp.__name__, exp.Binary) - }, - exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), - exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()), - exp.Alias: lambda self, expr: self._annotate_unary(expr), - exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), - exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), - exp.Literal: lambda self, expr: self._annotate_literal(expr), - exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), - exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), - exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), - exp.ApproxDistinct: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.BIGINT - ), - exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), - exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), - exp.Sum: lambda self, expr: self._annotate_by_args( - expr, "this", "expressions", promote=True - ), - exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), - exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.CurrentDatetime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATETIME - ), - exp.CurrentTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DatetimeAdd: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATETIME - ), - exp.DatetimeSub: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATETIME - ), - exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimestampAdd: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.TimestampSub: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DateStrToDate: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATE - ), - exp.DateToDateStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"), - exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), - exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), - exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), - exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.GroupConcat: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.ArrayConcat: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), - exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), - exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), - exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL), - exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"), - exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), - exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.ApproxQuantile: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DOUBLE - ), - exp.RegexpLike: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.BOOLEAN - ), - exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.StrToTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.TimeStrToDate: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATE - ), - exp.TimeStrToTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.UnixToTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.VariancePop: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DOUBLE - ), - exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - } +def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: + return lambda self, e: self._annotate_with_type(e, data_type) - # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html - COERCES_TO = { - # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT - exp.DataType.Type.TEXT: set(), - exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, - exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, - exp.DataType.Type.NCHAR: { - exp.DataType.Type.VARCHAR, - exp.DataType.Type.NVARCHAR, + +class _TypeAnnotator(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI): + # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html + text_precedence = ( exp.DataType.Type.TEXT, - }, - exp.DataType.Type.CHAR: { - exp.DataType.Type.NCHAR, - exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, - exp.DataType.Type.TEXT, - }, - # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE - exp.DataType.Type.DOUBLE: set(), - exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, - exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, - exp.DataType.Type.BIGINT: { - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NCHAR, + exp.DataType.Type.CHAR, + ) + numeric_precedence = ( exp.DataType.Type.DOUBLE, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.BIGINT, + exp.DataType.Type.INT, + exp.DataType.Type.SMALLINT, + exp.DataType.Type.TINYINT, + ) + timelike_precedence = ( + exp.DataType.Type.TIMESTAMPLTZ, + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.DATETIME, + exp.DataType.Type.DATE, + ) + + for type_precedence in (text_precedence, numeric_precedence, timelike_precedence): + coerces_to = set() + for data_type in type_precedence: + klass.COERCES_TO[data_type] = coerces_to.copy() + coerces_to |= {data_type} + + return klass + + +class TypeAnnotator(metaclass=_TypeAnnotator): + TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { + exp.DataType.Type.BIGINT: { + exp.ApproxDistinct, + exp.ArraySize, + exp.Count, + exp.Length, + }, + exp.DataType.Type.BOOLEAN: { + exp.Between, + exp.Boolean, + exp.In, + exp.RegexpLike, + }, + exp.DataType.Type.DATE: { + exp.CurrentDate, + exp.Date, + exp.DateAdd, + exp.DateStrToDate, + exp.DateSub, + exp.DateTrunc, + exp.DiToDate, + exp.StrToDate, + exp.TimeStrToDate, + exp.TsOrDsToDate, + }, + exp.DataType.Type.DATETIME: { + exp.CurrentDatetime, + exp.DatetimeAdd, + exp.DatetimeSub, + }, + exp.DataType.Type.DOUBLE: { + exp.ApproxQuantile, + exp.Avg, + exp.Exp, + exp.Ln, + exp.Log, + exp.Log2, + exp.Log10, + exp.Pow, + exp.Quantile, + exp.Round, + exp.SafeDivide, + exp.Sqrt, + exp.Stddev, + exp.StddevPop, + exp.StddevSamp, + exp.Variance, + exp.VariancePop, }, exp.DataType.Type.INT: { - exp.DataType.Type.BIGINT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DOUBLE, + exp.Ceil, + exp.DateDiff, + exp.DatetimeDiff, + exp.Extract, + exp.TimestampDiff, + exp.TimeDiff, + exp.DateToDi, + exp.Floor, + exp.Levenshtein, + exp.StrPosition, + exp.TsOrDiToDi, }, - exp.DataType.Type.SMALLINT: { - exp.DataType.Type.INT, - exp.DataType.Type.BIGINT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DOUBLE, + exp.DataType.Type.TIMESTAMP: { + exp.CurrentTime, + exp.CurrentTimestamp, + exp.StrToTime, + exp.TimeAdd, + exp.TimeStrToTime, + exp.TimeSub, + exp.TimestampAdd, + exp.TimestampSub, + exp.UnixToTime, }, exp.DataType.Type.TINYINT: { - exp.DataType.Type.SMALLINT, - exp.DataType.Type.INT, - exp.DataType.Type.BIGINT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DOUBLE, + exp.Day, + exp.Month, + exp.Week, + exp.Year, }, - # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ - exp.DataType.Type.TIMESTAMPLTZ: set(), - exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ}, - exp.DataType.Type.TIMESTAMP: { - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, + exp.DataType.Type.VARCHAR: { + exp.ArrayConcat, + exp.Concat, + exp.ConcatWs, + exp.DateToDateStr, + exp.GroupConcat, + exp.Initcap, + exp.Lower, + exp.SafeConcat, + exp.Substring, + exp.TimeToStr, + exp.TimeToTimeStr, + exp.Trim, + exp.TsOrDsToDateStr, + exp.UnixToStr, + exp.UnixToTimeStr, + exp.Upper, }, - exp.DataType.Type.DATETIME: { - exp.DataType.Type.TIMESTAMP, - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, + } + + ANNOTATORS = { + **{ + expr_type: lambda self, e: self._annotate_unary(e) + for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) }, - exp.DataType.Type.DATE: { - exp.DataType.Type.DATETIME, - exp.DataType.Type.TIMESTAMP, - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, + **{ + expr_type: lambda self, e: self._annotate_binary(e) + for expr_type in subclasses(exp.__name__, exp.Binary) + }, + **{ + expr_type: _annotate_with_type_lambda(data_type) + for data_type, expressions in TYPE_TO_EXPRESSIONS.items() + for expr_type in expressions }, + exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), + exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), + exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), + exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), + exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), + exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), + exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), + exp.Literal: lambda self, e: self._annotate_literal(e), + exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), + exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), + exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), + exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), + exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), } - TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery) + # Specifies what types a given type can be coerced into (autofilled) + COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} - def __init__(self, schema=None, annotators=None, coerces_to=None): + def __init__( + self, + schema: Schema, + annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, + coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, + ) -> None: self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO - def annotate(self, expression): - if isinstance(expression, self.TRAVERSABLES): - for scope in traverse_scope(expression): - selects = {} - for name, source in scope.sources.items(): - if not isinstance(source, Scope): - continue - if isinstance(source.expression, exp.UDTF): - values = [] - - if isinstance(source.expression, exp.Lateral): - if isinstance(source.expression.this, exp.Explode): - values = [source.expression.this.this] - else: - values = source.expression.expressions[0].expressions - - if not values: - continue - - selects[name] = { - alias: column - for alias, column in zip( - source.expression.alias_column_names, - values, - ) - } + def annotate(self, expression: E) -> E: + for scope in traverse_scope(expression): + selects = {} + for name, source in scope.sources.items(): + if not isinstance(source, Scope): + continue + if isinstance(source.expression, exp.UDTF): + values = [] + + if isinstance(source.expression, exp.Lateral): + if isinstance(source.expression.this, exp.Explode): + values = [source.expression.this.this] else: - selects[name] = { - select.alias_or_name: select for select in source.expression.selects - } - # First annotate the current scope's column references - for col in scope.columns: - if not col.table: + values = source.expression.expressions[0].expressions + + if not values: continue - source = scope.sources.get(col.table) - if isinstance(source, exp.Table): - col.type = self.schema.get_column_type(source, col) - elif source and col.table in selects and col.name in selects[col.table]: - col.type = selects[col.table][col.name].type - # Then (possibly) annotate the remaining expressions in the scope - self._maybe_annotate(scope.expression) + selects[name] = { + alias: column + for alias, column in zip( + source.expression.alias_column_names, + values, + ) + } + else: + selects[name] = { + select.alias_or_name: select for select in source.expression.selects + } + + # First annotate the current scope's column references + for col in scope.columns: + if not col.table: + continue + + source = scope.sources.get(col.table) + if isinstance(source, exp.Table): + col.type = self.schema.get_column_type(source, col) + elif source and col.table in selects and col.name in selects[col.table]: + col.type = selects[col.table][col.name].type + + # Then (possibly) annotate the remaining expressions in the scope + self._maybe_annotate(scope.expression) + return self._maybe_annotate(expression) # This takes care of non-traversable expressions - def _maybe_annotate(self, expression): + def _maybe_annotate(self, expression: E) -> E: if expression.type: return expression # We've already inferred the expression's type @@ -312,13 +290,15 @@ class TypeAnnotator: else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) ) - def _annotate_args(self, expression): + def _annotate_args(self, expression: E) -> E: for _, value in expression.iter_expressions(): self._maybe_annotate(value) return expression - def _maybe_coerce(self, type1, type2): + def _maybe_coerce( + self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type + ) -> exp.DataType.Type: # We propagate the NULL / UNKNOWN types upwards if found if isinstance(type1, exp.DataType): type1 = type1.this @@ -330,9 +310,14 @@ class TypeAnnotator: if exp.DataType.Type.UNKNOWN in (type1, type2): return exp.DataType.Type.UNKNOWN - return type2 if type2 in self.coerces_to.get(type1, {}) else type1 + return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore - def _annotate_binary(self, expression): + # Note: the following "no_type_check" decorators were added because mypy was yelling due + # to assigning Type values to expression.type (since its getter returns Optional[DataType]). + # This is a known mypy issue: https://github.com/python/mypy/issues/3004 + + @t.no_type_check + def _annotate_binary(self, expression: B) -> B: self._annotate_args(expression) left_type = expression.left.type.this @@ -354,7 +339,8 @@ class TypeAnnotator: return expression - def _annotate_unary(self, expression): + @t.no_type_check + def _annotate_unary(self, expression: E) -> E: self._annotate_args(expression) if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): @@ -364,7 +350,8 @@ class TypeAnnotator: return expression - def _annotate_literal(self, expression): + @t.no_type_check + def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: if expression.is_string: expression.type = exp.DataType.Type.VARCHAR elif expression.is_int: @@ -374,13 +361,16 @@ class TypeAnnotator: return expression - def _annotate_with_type(self, expression, target_type): + @t.no_type_check + def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E: expression.type = target_type return self._annotate_args(expression) - def _annotate_by_args(self, expression, *args, promote=False): + @t.no_type_check + def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E: self._annotate_args(expression) - expressions = [] + + expressions: t.List[exp.Expression] = [] for arg in args: arg_expr = expression.args.get(arg) expressions.extend(expr for expr in ensure_list(arg_expr) if expr) diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index da2fce8..015b06a 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -26,7 +26,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: def add_text_to_concat(node: exp.Expression) -> exp.Expression: if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: - node = exp.Concat(this=node.this, expression=node.expression) + node = exp.Concat(expressions=[node.left, node.right]) return node diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 27de9c7..cd8ba3b 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -32,7 +32,7 @@ def eliminate_joins(expression): # Reverse the joins so we can remove chains of unused joins for join in reversed(joins): - alias = join.this.alias_or_name + alias = join.alias_or_name if _should_eliminate_join(scope, join, alias): join.pop() scope.remove_source(alias) @@ -126,7 +126,7 @@ def join_condition(join): tuple[list[str], list[str], exp.Expression]: Tuple of (source key, join key, remaining predicate) """ - name = join.this.alias_or_name + name = join.alias_or_name on = (join.args.get("on") or exp.true()).copy() source_key = [] join_key = [] diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index 5dfa4aa..79e3ed5 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None): source.replace( exp.select("*") .from_( - alias(source, source.name or source.alias, table=True), + alias(source, source.alias_or_name, table=True), copy=False, ) .subquery(source.alias, copy=False) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index f9c9664..fefe96e 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -145,7 +145,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): if not isinstance(from_or_join, exp.Join): return False - alias = from_or_join.this.alias_or_name + alias = from_or_join.alias_or_name on = from_or_join.args.get("on") if not on: @@ -253,10 +253,6 @@ def _merge_joins(outer_scope, inner_scope, from_or_join): """ new_joins = [] - comma_joins = inner_scope.expression.args.get("from").expressions[1:] - for subquery in comma_joins: - new_joins.append(exp.Join(this=subquery, kind="CROSS")) - outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name]) joins = inner_scope.expression.args.get("joins") or [] for join in joins: @@ -328,13 +324,12 @@ def _merge_where(outer_scope, inner_scope, from_or_join): if source == from_or_join.alias_or_name: break - if set(exp.column_table_names(where.this)) <= sources: + if exp.column_table_names(where.this) <= sources: from_or_join.on(where.this, copy=False) from_or_join.set("on", from_or_join.args.get("on")) return expression.where(where.this, copy=False) - expression.set("where", expression.args.get("where")) def _merge_order(outer_scope, inner_scope): diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 4e0c3a1..d51276f 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import typing as t + from sqlglot import exp from sqlglot.helper import tsort @@ -13,25 +17,28 @@ def optimize_joins(expression): >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql() 'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a' """ + for select in expression.find_all(exp.Select): references = {} cross_joins = [] for join in select.args.get("joins", []): - name = join.this.alias_or_name - tables = other_table_names(join, name) + tables = other_table_names(join) if tables: for table in tables: references[table] = references.get(table, []) + [join] else: - cross_joins.append((name, join)) + cross_joins.append((join.alias_or_name, join)) for name, join in cross_joins: for dep in references.get(name, []): on = dep.args["on"] if isinstance(on, exp.Connector): + if len(other_table_names(dep)) < 2: + continue + for predicate in on.flatten(): if name in exp.column_table_names(predicate): predicate.replace(exp.true()) @@ -47,17 +54,12 @@ def reorder_joins(expression): Reorder joins by topological sort order based on predicate references. """ for from_ in expression.find_all(exp.From): - head = from_.this parent = from_.parent - joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])} - dag = {head.alias_or_name: []} - - for name, join in joins.items(): - dag[name] = other_table_names(join, name) - + joins = {join.alias_or_name: join for join in parent.args.get("joins", [])} + dag = {name: other_table_names(join) for name, join in joins.items()} parent.set( "joins", - [joins[name] for name in tsort(dag) if name != head.alias_or_name], + [joins[name] for name in tsort(dag) if name != from_.alias_or_name], ) return expression @@ -75,9 +77,6 @@ def normalize(expression): return expression -def other_table_names(join, exclude): - return [ - name - for name in (exp.column_table_names(join.args.get("on") or exp.true())) - if name != exclude - ] +def other_table_names(join: exp.Join) -> t.Set[str]: + on = join.args.get("on") + return exp.column_table_names(on, join.alias_or_name) if on else set() diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index dbe33a2..abac63b 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -78,7 +78,7 @@ def optimize( "schema": schema, "dialect": dialect, "isolate_tables": True, # needed for other optimizations to perform well - "quote_identifiers": False, # this happens in canonicalize + "quote_identifiers": False, **kwargs, } diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index b89a82b..fb1662d 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -41,7 +41,7 @@ def pushdown_predicates(expression): # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself for join in select.args.get("joins") or []: - name = join.this.alias_or_name + name = join.alias_or_name pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) return expression @@ -93,10 +93,10 @@ def pushdown_dnf(predicates, scope, scope_ref_count): pushdown_tables = set() for a in predicates: - a_tables = set(exp.column_table_names(a)) + a_tables = exp.column_table_names(a) for b in predicates: - a_tables &= set(exp.column_table_names(b)) + a_tables &= exp.column_table_names(b) pushdown_tables.update(a_tables) @@ -147,7 +147,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): tables = exp.column_table_names(predicate) where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) - for table in tables: + for table in sorted(tables): node, source = sources.get(table) or (None, None) # if the predicate is in a where statement we can try to push it down diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 4a31171..aba9a7e 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -14,7 +14,7 @@ from sqlglot.schema import Schema, ensure_schema def qualify_columns( expression: exp.Expression, - schema: dict | Schema, + schema: t.Dict | Schema, expand_alias_refs: bool = True, infer_schema: t.Optional[bool] = None, ) -> exp.Expression: @@ -93,7 +93,7 @@ def _pop_table_column_aliases(derived_tables): def _expand_using(scope, resolver): joins = list(scope.find_all(exp.Join)) - names = {join.this.alias for join in joins} + names = {join.alias_or_name for join in joins} ordered = [key for key in scope.selected_sources if key not in names] # Mapping of automatically joined column names to an ordered set of source names (dict). @@ -105,7 +105,7 @@ def _expand_using(scope, resolver): if not using: continue - join_table = join.this.alias_or_name + join_table = join.alias_or_name columns = {} diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index fcc5f26..9c931d6 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -91,11 +91,13 @@ def qualify_tables( ) elif isinstance(source, Scope) and source.is_udtf: udtf = source.expression - table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name()) + table_alias = udtf.args.get("alias") or exp.TableAlias( + this=exp.to_identifier(next_alias_name()) + ) udtf.set("alias", table_alias) if not table_alias.name: - table_alias.set("this", next_alias_name()) + table_alias.set("this", exp.to_identifier(next_alias_name())) if isinstance(udtf, exp.Values) and not table_alias.columns: for i, e in enumerate(udtf.expressions[0].expressions): table_alias.append("columns", exp.to_identifier(f"_col_{i}")) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 9ffb4d6..aa56b83 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -620,7 +620,7 @@ def _traverse_tables(scope): table_name = expression.name source_name = expression.alias_or_name - if table_name in scope.sources: + if table_name in scope.sources and not expression.db: # This is a reference to a parent source (e.g. a CTE), not an actual table, unless # it is pivoted, because then we get back a new table and hence a new source. pivots = expression.args.get("pivots") diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 96bd6e3..d6888c7 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -6,7 +6,8 @@ from collections import defaultdict from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get +from sqlglot.helper import apply_index_offset, ensure_list, seq_get +from sqlglot.time import format_time from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -25,13 +26,14 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap: for i in range(0, len(args), 2): keys.append(args[i]) values.append(args[i + 1]) + return exp.VarMap( keys=exp.Array(expressions=keys), values=exp.Array(expressions=values), ) -def parse_like(args: t.List) -> exp.Expression: +def parse_like(args: t.List) -> exp.Escape | exp.Like: like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0)) return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like @@ -47,33 +49,26 @@ def binary_range_parser( class _Parser(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) - klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) - klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS) + + klass.SHOW_TRIE = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) + klass.SET_TRIE = new_trie(key.split(" ") for key in klass.SET_PARSERS) return klass class Parser(metaclass=_Parser): """ - Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces - a parsed syntax tree. + Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree. Args: - error_level: the desired error level. + error_level: The desired error level. Default: ErrorLevel.IMMEDIATE - error_message_context: determines the amount of context to capture from a + error_message_context: Determines the amount of context to capture from a query string when displaying the error message (in number of characters). - Default: 50. - index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list. - Default: 0 - alias_post_tablesample: If the table alias comes after tablesample. - Default: False + Default: 100 max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3 - null_ordering: Indicates the default null ordering method to use if not explicitly set. - Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". - Default: "nulls_are_small" """ FUNCTIONS: t.Dict[str, t.Callable] = { @@ -83,7 +78,6 @@ class Parser(metaclass=_Parser): to=exp.DataType(this=exp.DataType.Type.TEXT), ), "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)), - "IFNULL": exp.Coalesce.from_arg_list, "LIKE": parse_like, "TIME_TO_TIME_STR": lambda args: exp.Cast( this=seq_get(args, 0), @@ -108,8 +102,6 @@ class Parser(metaclass=_Parser): TokenType.CURRENT_USER: exp.CurrentUser, } - JOIN_HINTS: t.Set[str] = set() - NESTED_TYPE_TOKENS = { TokenType.ARRAY, TokenType.MAP, @@ -117,6 +109,10 @@ class Parser(metaclass=_Parser): TokenType.STRUCT, } + ENUM_TYPE_TOKENS = { + TokenType.ENUM, + } + TYPE_TOKENS = { TokenType.BIT, TokenType.BOOLEAN, @@ -188,6 +184,7 @@ class Parser(metaclass=_Parser): TokenType.VARIANT, TokenType.OBJECT, TokenType.INET, + TokenType.ENUM, *NESTED_TYPE_TOKENS, } @@ -198,7 +195,10 @@ class Parser(metaclass=_Parser): TokenType.SOME: exp.Any, } - RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT} + RESERVED_KEYWORDS = { + *Tokenizer.SINGLE_TOKENS.values(), + TokenType.SELECT, + } DB_CREATABLES = { TokenType.DATABASE, @@ -216,6 +216,7 @@ class Parser(metaclass=_Parser): *DB_CREATABLES, } + # Tokens that can represent identifiers ID_VAR_TOKENS = { TokenType.VAR, TokenType.ANTI, @@ -224,6 +225,7 @@ class Parser(metaclass=_Parser): TokenType.AUTO_INCREMENT, TokenType.BEGIN, TokenType.CACHE, + TokenType.CASE, TokenType.COLLATE, TokenType.COMMAND, TokenType.COMMENT, @@ -274,6 +276,7 @@ class Parser(metaclass=_Parser): TokenType.TRUE, TokenType.UNIQUE, TokenType.UNPIVOT, + TokenType.UPDATE, TokenType.VOLATILE, TokenType.WINDOW, *CREATABLES, @@ -409,6 +412,8 @@ class Parser(metaclass=_Parser): TokenType.ANTI, } + JOIN_HINTS: t.Set[str] = set() + LAMBDAS = { TokenType.ARROW: lambda self, expressions: self.expression( exp.Lambda, @@ -420,7 +425,7 @@ class Parser(metaclass=_Parser): ), TokenType.FARROW: lambda self, expressions: self.expression( exp.Kwarg, - this=exp.Var(this=expressions[0].name), + this=exp.var(expressions[0].name), expression=self._parse_conjunction(), ), } @@ -515,7 +520,7 @@ class Parser(metaclass=_Parser): TokenType.USE: lambda self: self.expression( exp.Use, kind=self._match_texts(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA")) - and exp.Var(this=self._prev.text), + and exp.var(self._prev.text), this=self._parse_table(schema=False), ), } @@ -634,6 +639,7 @@ class Parser(metaclass=_Parser): "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), "TEMP": lambda self: self.expression(exp.TemporaryProperty), "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty), + "TO": lambda self: self._parse_to_table(), "TRANSIENT": lambda self: self.expression(exp.TransientProperty), "TTL": lambda self: self._parse_ttl(), "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty), @@ -710,6 +716,7 @@ class Parser(metaclass=_Parser): FUNCTION_PARSERS: t.Dict[str, t.Callable] = { "CAST": lambda self: self._parse_cast(self.STRICT_CAST), + "CONCAT": lambda self: self._parse_concat(), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), "DECODE": lambda self: self._parse_decode(), "EXTRACT": lambda self: self._parse_extract(), @@ -755,8 +762,11 @@ class Parser(metaclass=_Parser): MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) - TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN} + PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE} + + TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} TRANSACTION_CHARACTERISTICS = { "ISOLATION LEVEL REPEATABLE READ", "ISOLATION LEVEL READ COMMITTED", @@ -778,6 +788,8 @@ class Parser(metaclass=_Parser): STRICT_CAST = True + CONCAT_NULL_OUTPUTS_STRING = False # A NULL arg in CONCAT yields NULL by default + CONVERT_TYPE_FIRST = False PREFIXED_PIVOT_COLUMNS = False @@ -789,40 +801,39 @@ class Parser(metaclass=_Parser): __slots__ = ( "error_level", "error_message_context", + "max_errors", "sql", "errors", - "index_offset", - "unnest_column_only", - "alias_post_tablesample", - "max_errors", - "null_ordering", "_tokens", "_index", "_curr", "_next", "_prev", "_prev_comments", - "_show_trie", - "_set_trie", ) + # Autofilled + INDEX_OFFSET: int = 0 + UNNEST_COLUMN_ONLY: bool = False + ALIAS_POST_TABLESAMPLE: bool = False + STRICT_STRING_CONCAT = False + NULL_ORDERING: str = "nulls_are_small" + SHOW_TRIE: t.Dict = {} + SET_TRIE: t.Dict = {} + FORMAT_MAPPING: t.Dict[str, str] = {} + FORMAT_TRIE: t.Dict = {} + TIME_MAPPING: t.Dict[str, str] = {} + TIME_TRIE: t.Dict = {} + def __init__( self, error_level: t.Optional[ErrorLevel] = None, error_message_context: int = 100, - index_offset: int = 0, - unnest_column_only: bool = False, - alias_post_tablesample: bool = False, max_errors: int = 3, - null_ordering: t.Optional[str] = None, ): self.error_level = error_level or ErrorLevel.IMMEDIATE self.error_message_context = error_message_context - self.index_offset = index_offset - self.unnest_column_only = unnest_column_only - self.alias_post_tablesample = alias_post_tablesample self.max_errors = max_errors - self.null_ordering = null_ordering self.reset() def reset(self): @@ -843,11 +854,11 @@ class Parser(metaclass=_Parser): per parsed SQL statement. Args: - raw_tokens: the list of tokens. - sql: the original SQL string, used to produce helpful debug messages. + raw_tokens: The list of tokens. + sql: The original SQL string, used to produce helpful debug messages. Returns: - The list of syntax trees. + The list of the produced syntax trees. """ return self._parse( parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql @@ -865,23 +876,25 @@ class Parser(metaclass=_Parser): of them, stopping at the first for which the parsing succeeds. Args: - expression_types: the expression type(s) to try and parse the token list into. - raw_tokens: the list of tokens. - sql: the original SQL string, used to produce helpful debug messages. + expression_types: The expression type(s) to try and parse the token list into. + raw_tokens: The list of tokens. + sql: The original SQL string, used to produce helpful debug messages. Returns: The target Expression. """ errors = [] - for expression_type in ensure_collection(expression_types): + for expression_type in ensure_list(expression_types): parser = self.EXPRESSION_PARSERS.get(expression_type) if not parser: raise TypeError(f"No parser registered for {expression_type}") + try: return self._parse(parser, raw_tokens, sql) except ParseError as e: e.errors[0]["into_expression"] = expression_type errors.append(e) + raise ParseError( f"Failed to parse '{sql or raw_tokens}' into {expression_types}", errors=merge_errors(errors), @@ -895,6 +908,7 @@ class Parser(metaclass=_Parser): ) -> t.List[t.Optional[exp.Expression]]: self.reset() self.sql = sql or "" + total = len(raw_tokens) chunks: t.List[t.List[Token]] = [[]] @@ -922,9 +936,7 @@ class Parser(metaclass=_Parser): return expressions def check_errors(self) -> None: - """ - Logs or raises any found errors, depending on the chosen error level setting. - """ + """Logs or raises any found errors, depending on the chosen error level setting.""" if self.error_level == ErrorLevel.WARN: for error in self.errors: logger.error(str(error)) @@ -969,39 +981,38 @@ class Parser(metaclass=_Parser): Creates a new, validated Expression. Args: - exp_class: the expression class to instantiate. - comments: an optional list of comments to attach to the expression. - kwargs: the arguments to set for the expression along with their respective values. + exp_class: The expression class to instantiate. + comments: An optional list of comments to attach to the expression. + kwargs: The arguments to set for the expression along with their respective values. Returns: The target expression. """ instance = exp_class(**kwargs) instance.add_comments(comments) if comments else self._add_comments(instance) - self.validate_expression(instance) - return instance + return self.validate_expression(instance) def _add_comments(self, expression: t.Optional[exp.Expression]) -> None: if expression and self._prev_comments: expression.add_comments(self._prev_comments) self._prev_comments = None - def validate_expression( - self, expression: exp.Expression, args: t.Optional[t.List] = None - ) -> None: + def validate_expression(self, expression: E, args: t.Optional[t.List] = None) -> E: """ - Validates an already instantiated expression, making sure that all its mandatory arguments - are set. + Validates an Expression, making sure that all its mandatory arguments are set. Args: - expression: the expression to validate. - args: an optional list of items that was used to instantiate the expression, if it's a Func. + expression: The expression to validate. + args: An optional list of items that was used to instantiate the expression, if it's a Func. + + Returns: + The validated expression. """ - if self.error_level == ErrorLevel.IGNORE: - return + if self.error_level != ErrorLevel.IGNORE: + for error_message in expression.error_messages(args): + self.raise_error(error_message) - for error_message in expression.error_messages(args): - self.raise_error(error_message) + return expression def _find_sql(self, start: Token, end: Token) -> str: return self.sql[start.start : end.end + 1] @@ -1010,6 +1021,7 @@ class Parser(metaclass=_Parser): self._index += times self._curr = seq_get(self._tokens, self._index) self._next = seq_get(self._tokens, self._index + 1) + if self._index > 0: self._prev = self._tokens[self._index - 1] self._prev_comments = self._prev.comments @@ -1031,7 +1043,6 @@ class Parser(metaclass=_Parser): self._match(TokenType.ON) kind = self._match_set(self.CREATABLES) and self._prev - if not kind: return self._parse_as_command(start) @@ -1050,6 +1061,12 @@ class Parser(metaclass=_Parser): exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists ) + def _parse_to_table( + self, + ) -> exp.ToTableProperty: + table = self._parse_table_parts(schema=True) + return self.expression(exp.ToTableProperty, this=table) + # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl def _parse_ttl(self) -> exp.Expression: def _parse_ttl_action() -> t.Optional[exp.Expression]: @@ -1102,10 +1119,11 @@ class Parser(metaclass=_Parser): expression = self._parse_set_operations(expression) if expression else self._parse_select() return self._parse_query_modifiers(expression) - def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]: + def _parse_drop(self) -> exp.Drop | exp.Command: start = self._prev temporary = self._match(TokenType.TEMPORARY) materialized = self._match_text_seq("MATERIALIZED") + kind = self._match_set(self.CREATABLES) and self._prev.text if not kind: return self._parse_as_command(start) @@ -1129,21 +1147,23 @@ class Parser(metaclass=_Parser): and self._match(TokenType.EXISTS) ) - def _parse_create(self) -> t.Optional[exp.Expression]: + def _parse_create(self) -> exp.Create | exp.Command: + # Note: this can't be None because we've matched a statement parser start = self._prev - replace = self._prev.text.upper() == "REPLACE" or self._match_pair( + replace = start.text.upper() == "REPLACE" or self._match_pair( TokenType.OR, TokenType.REPLACE ) unique = self._match(TokenType.UNIQUE) if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): - self._match(TokenType.TABLE) + self._advance() properties = None create_token = self._match_set(self.CREATABLES) and self._prev if not create_token: - properties = self._parse_properties() # exp.Properties.Location.POST_CREATE + # exp.Properties.Location.POST_CREATE + properties = self._parse_properties() create_token = self._match_set(self.CREATABLES) and self._prev if not properties or not create_token: @@ -1157,7 +1177,7 @@ class Parser(metaclass=_Parser): begin = None clone = None - def extend_props(temp_props: t.Optional[exp.Expression]) -> None: + def extend_props(temp_props: t.Optional[exp.Properties]) -> None: nonlocal properties if properties and temp_props: properties.expressions.extend(temp_props.expressions) @@ -1166,6 +1186,8 @@ class Parser(metaclass=_Parser): if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function(kind=create_token.token_type) + + # exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature) extend_props(self._parse_properties()) self._match(TokenType.ALIAS) @@ -1190,13 +1212,8 @@ class Parser(metaclass=_Parser): extend_props(self._parse_properties()) self._match(TokenType.ALIAS) - - # exp.Properties.Location.POST_ALIAS - if not ( - self._match(TokenType.SELECT, advance=False) - or self._match(TokenType.WITH, advance=False) - or self._match(TokenType.L_PAREN, advance=False) - ): + if not self._match_set(self.DDL_SELECT_TOKENS, advance=False): + # exp.Properties.Location.POST_ALIAS extend_props(self._parse_properties()) expression = self._parse_ddl_select() @@ -1206,7 +1223,7 @@ class Parser(metaclass=_Parser): while True: index = self._parse_index() - # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX + # exp.Properties.Location.POST_EXPRESSION and POST_INDEX extend_props(self._parse_properties()) if not index: @@ -1296,7 +1313,7 @@ class Parser(metaclass=_Parser): return None - def _parse_stored(self) -> exp.Expression: + def _parse_stored(self) -> exp.FileFormatProperty: self._match(TokenType.ALIAS) input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None @@ -1311,14 +1328,13 @@ class Parser(metaclass=_Parser): else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), ) - def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression: + def _parse_property_assignment(self, exp_class: t.Type[E]) -> E: self._match(TokenType.EQ) self._match(TokenType.ALIAS) return self.expression(exp_class, this=self._parse_field()) - def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Expression]: + def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]: properties = [] - while True: if before: prop = self._parse_property_before() @@ -1335,29 +1351,25 @@ class Parser(metaclass=_Parser): return None - def _parse_fallback(self, no: bool = False) -> exp.Expression: + def _parse_fallback(self, no: bool = False) -> exp.FallbackProperty: return self.expression( exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION") ) - def _parse_volatile_property(self) -> exp.Expression: + def _parse_volatile_property(self) -> exp.VolatileProperty | exp.StabilityProperty: if self._index >= 2: pre_volatile_token = self._tokens[self._index - 2] else: pre_volatile_token = None - if pre_volatile_token and pre_volatile_token.token_type in ( - TokenType.CREATE, - TokenType.REPLACE, - TokenType.UNIQUE, - ): + if pre_volatile_token and pre_volatile_token.token_type in self.PRE_VOLATILE_TOKENS: return exp.VolatileProperty() return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE")) def _parse_with_property( self, - ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]: + ) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]: self._match(TokenType.WITH) if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_property) @@ -1376,7 +1388,7 @@ class Parser(metaclass=_Parser): return self._parse_withisolatedloading() # https://dev.mysql.com/doc/refman/8.0/en/create-view.html - def _parse_definer(self) -> t.Optional[exp.Expression]: + def _parse_definer(self) -> t.Optional[exp.DefinerProperty]: self._match(TokenType.EQ) user = self._parse_id_var() @@ -1388,18 +1400,18 @@ class Parser(metaclass=_Parser): return exp.DefinerProperty(this=f"{user}@{host}") - def _parse_withjournaltable(self) -> exp.Expression: + def _parse_withjournaltable(self) -> exp.WithJournalTableProperty: self._match(TokenType.TABLE) self._match(TokenType.EQ) return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts()) - def _parse_log(self, no: bool = False) -> exp.Expression: + def _parse_log(self, no: bool = False) -> exp.LogProperty: return self.expression(exp.LogProperty, no=no) - def _parse_journal(self, **kwargs) -> exp.Expression: + def _parse_journal(self, **kwargs) -> exp.JournalProperty: return self.expression(exp.JournalProperty, **kwargs) - def _parse_checksum(self) -> exp.Expression: + def _parse_checksum(self) -> exp.ChecksumProperty: self._match(TokenType.EQ) on = None @@ -1407,53 +1419,47 @@ class Parser(metaclass=_Parser): on = True elif self._match_text_seq("OFF"): on = False - default = self._match(TokenType.DEFAULT) - return self.expression( - exp.ChecksumProperty, - on=on, - default=default, - ) + return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT)) - def _parse_cluster(self) -> t.Optional[exp.Expression]: + def _parse_cluster(self) -> t.Optional[exp.Cluster]: if not self._match_text_seq("BY"): self._retreat(self._index - 1) return None - return self.expression( - exp.Cluster, - expressions=self._parse_csv(self._parse_ordered), - ) - def _parse_freespace(self) -> exp.Expression: + return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered)) + + def _parse_freespace(self) -> exp.FreespaceProperty: self._match(TokenType.EQ) return self.expression( exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT) ) - def _parse_mergeblockratio(self, no: bool = False, default: bool = False) -> exp.Expression: + def _parse_mergeblockratio( + self, no: bool = False, default: bool = False + ) -> exp.MergeBlockRatioProperty: if self._match(TokenType.EQ): return self.expression( exp.MergeBlockRatioProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT), ) - return self.expression( - exp.MergeBlockRatioProperty, - no=no, - default=default, - ) + + return self.expression(exp.MergeBlockRatioProperty, no=no, default=default) def _parse_datablocksize( self, default: t.Optional[bool] = None, minimum: t.Optional[bool] = None, maximum: t.Optional[bool] = None, - ) -> exp.Expression: + ) -> exp.DataBlocksizeProperty: self._match(TokenType.EQ) size = self._parse_number() + units = None if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")): units = self._prev.text + return self.expression( exp.DataBlocksizeProperty, size=size, @@ -1463,12 +1469,13 @@ class Parser(metaclass=_Parser): maximum=maximum, ) - def _parse_blockcompression(self) -> exp.Expression: + def _parse_blockcompression(self) -> exp.BlockCompressionProperty: self._match(TokenType.EQ) always = self._match_text_seq("ALWAYS") manual = self._match_text_seq("MANUAL") never = self._match_text_seq("NEVER") default = self._match_text_seq("DEFAULT") + autotemp = None if self._match_text_seq("AUTOTEMP"): autotemp = self._parse_schema() @@ -1482,7 +1489,7 @@ class Parser(metaclass=_Parser): autotemp=autotemp, ) - def _parse_withisolatedloading(self) -> exp.Expression: + def _parse_withisolatedloading(self) -> exp.IsolatedLoadingProperty: no = self._match_text_seq("NO") concurrent = self._match_text_seq("CONCURRENT") self._match_text_seq("ISOLATED", "LOADING") @@ -1498,7 +1505,7 @@ class Parser(metaclass=_Parser): for_none=for_none, ) - def _parse_locking(self) -> exp.Expression: + def _parse_locking(self) -> exp.LockingProperty: if self._match(TokenType.TABLE): kind = "TABLE" elif self._match(TokenType.VIEW): @@ -1553,14 +1560,14 @@ class Parser(metaclass=_Parser): return self._parse_csv(self._parse_conjunction) return [] - def _parse_partitioned_by(self) -> exp.Expression: + def _parse_partitioned_by(self) -> exp.PartitionedByProperty: self._match(TokenType.EQ) return self.expression( exp.PartitionedByProperty, this=self._parse_schema() or self._parse_bracket(self._parse_field()), ) - def _parse_withdata(self, no: bool = False) -> exp.Expression: + def _parse_withdata(self, no: bool = False) -> exp.WithDataProperty: if self._match_text_seq("AND", "STATISTICS"): statistics = True elif self._match_text_seq("AND", "NO", "STATISTICS"): @@ -1570,52 +1577,50 @@ class Parser(metaclass=_Parser): return self.expression(exp.WithDataProperty, no=no, statistics=statistics) - def _parse_no_property(self) -> t.Optional[exp.Property]: + def _parse_no_property(self) -> t.Optional[exp.NoPrimaryIndexProperty]: if self._match_text_seq("PRIMARY", "INDEX"): return exp.NoPrimaryIndexProperty() return None - def _parse_on_property(self) -> t.Optional[exp.Property]: + def _parse_on_property(self) -> t.Optional[exp.Expression]: if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"): return exp.OnCommitProperty() elif self._match_text_seq("COMMIT", "DELETE", "ROWS"): return exp.OnCommitProperty(delete=True) return None - def _parse_distkey(self) -> exp.Expression: + def _parse_distkey(self) -> exp.DistKeyProperty: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) - def _parse_create_like(self) -> t.Optional[exp.Expression]: + def _parse_create_like(self) -> t.Optional[exp.LikeProperty]: table = self._parse_table(schema=True) + options = [] while self._match_texts(("INCLUDING", "EXCLUDING")): this = self._prev.text.upper() - id_var = self._parse_id_var() + id_var = self._parse_id_var() if not id_var: return None options.append( - self.expression( - exp.Property, - this=this, - value=exp.Var(this=id_var.this.upper()), - ) + self.expression(exp.Property, this=this, value=exp.var(id_var.this.upper())) ) + return self.expression(exp.LikeProperty, this=table, expressions=options) - def _parse_sortkey(self, compound: bool = False) -> exp.Expression: + def _parse_sortkey(self, compound: bool = False) -> exp.SortKeyProperty: return self.expression( - exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound + exp.SortKeyProperty, this=self._parse_wrapped_id_vars(), compound=compound ) - def _parse_character_set(self, default: bool = False) -> exp.Expression: + def _parse_character_set(self, default: bool = False) -> exp.CharacterSetProperty: self._match(TokenType.EQ) return self.expression( exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default ) - def _parse_returns(self) -> exp.Expression: + def _parse_returns(self) -> exp.ReturnsProperty: value: t.Optional[exp.Expression] is_table = self._match(TokenType.TABLE) @@ -1629,19 +1634,18 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") else: - value = self._parse_schema(exp.Var(this="TABLE")) + value = self._parse_schema(exp.var("TABLE")) else: value = self._parse_types() return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) - def _parse_describe(self) -> exp.Expression: + def _parse_describe(self) -> exp.Describe: kind = self._match_set(self.CREATABLES) and self._prev.text this = self._parse_table() - return self.expression(exp.Describe, this=this, kind=kind) - def _parse_insert(self) -> exp.Expression: + def _parse_insert(self) -> exp.Insert: overwrite = self._match(TokenType.OVERWRITE) local = self._match_text_seq("LOCAL") alternative = None @@ -1673,11 +1677,11 @@ class Parser(metaclass=_Parser): alternative=alternative, ) - def _parse_on_conflict(self) -> t.Optional[exp.Expression]: + def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]: conflict = self._match_text_seq("ON", "CONFLICT") duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY") - if not (conflict or duplicate): + if not conflict and not duplicate: return None nothing = None @@ -1707,18 +1711,20 @@ class Parser(metaclass=_Parser): constraint=constraint, ) - def _parse_returning(self) -> t.Optional[exp.Expression]: + def _parse_returning(self) -> t.Optional[exp.Returning]: if not self._match(TokenType.RETURNING): return None return self.expression(exp.Returning, expressions=self._parse_csv(self._parse_column)) - def _parse_row(self) -> t.Optional[exp.Expression]: + def _parse_row(self) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]: if not self._match(TokenType.FORMAT): return None return self._parse_row_format() - def _parse_row_format(self, match_row: bool = False) -> t.Optional[exp.Expression]: + def _parse_row_format( + self, match_row: bool = False + ) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]: if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None @@ -1744,7 +1750,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore - def _parse_load(self) -> exp.Expression: + def _parse_load(self) -> exp.LoadData | exp.Command: if self._match_text_seq("DATA"): local = self._match_text_seq("LOCAL") self._match_text_seq("INPATH") @@ -1764,7 +1770,7 @@ class Parser(metaclass=_Parser): ) return self._parse_as_command(self._prev) - def _parse_delete(self) -> exp.Expression: + def _parse_delete(self) -> exp.Delete: self._match(TokenType.FROM) return self.expression( @@ -1775,7 +1781,7 @@ class Parser(metaclass=_Parser): returning=self._parse_returning(), ) - def _parse_update(self) -> exp.Expression: + def _parse_update(self) -> exp.Update: return self.expression( exp.Update, **{ # type: ignore @@ -1787,22 +1793,20 @@ class Parser(metaclass=_Parser): }, ) - def _parse_uncache(self) -> exp.Expression: + def _parse_uncache(self) -> exp.Uncache: if not self._match(TokenType.TABLE): self.raise_error("Expecting TABLE after UNCACHE") return self.expression( - exp.Uncache, - exists=self._parse_exists(), - this=self._parse_table(schema=True), + exp.Uncache, exists=self._parse_exists(), this=self._parse_table(schema=True) ) - def _parse_cache(self) -> exp.Expression: + def _parse_cache(self) -> exp.Cache: lazy = self._match_text_seq("LAZY") self._match(TokenType.TABLE) table = self._parse_table(schema=True) - options = [] + options = [] if self._match_text_seq("OPTIONS"): self._match_l_paren() k = self._parse_string() @@ -1820,7 +1824,7 @@ class Parser(metaclass=_Parser): expression=self._parse_select(nested=True), ) - def _parse_partition(self) -> t.Optional[exp.Expression]: + def _parse_partition(self) -> t.Optional[exp.Partition]: if not self._match(TokenType.PARTITION): return None @@ -1828,7 +1832,7 @@ class Parser(metaclass=_Parser): exp.Partition, expressions=self._parse_wrapped_csv(self._parse_conjunction) ) - def _parse_value(self) -> exp.Expression: + def _parse_value(self) -> exp.Tuple: if self._match(TokenType.L_PAREN): expressions = self._parse_csv(self._parse_conjunction) self._match_r_paren() @@ -1926,7 +1930,7 @@ class Parser(metaclass=_Parser): return self._parse_set_operations(this) - def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]: if not skip_with_token and not self._match(TokenType.WITH): return None @@ -1946,22 +1950,19 @@ class Parser(metaclass=_Parser): exp.With, comments=comments, expressions=expressions, recursive=recursive ) - def _parse_cte(self) -> exp.Expression: + def _parse_cte(self) -> exp.CTE: alias = self._parse_table_alias() if not alias or not alias.this: self.raise_error("Expected CTE to have alias") self._match(TokenType.ALIAS) - return self.expression( - exp.CTE, - this=self._parse_wrapped(self._parse_statement), - alias=alias, + exp.CTE, this=self._parse_wrapped(self._parse_statement), alias=alias ) def _parse_table_alias( self, alias_tokens: t.Optional[t.Collection[TokenType]] = None - ) -> t.Optional[exp.Expression]: + ) -> t.Optional[exp.TableAlias]: any_token = self._match(TokenType.ALIAS) alias = ( self._parse_id_var(any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) @@ -1982,9 +1983,10 @@ class Parser(metaclass=_Parser): def _parse_subquery( self, this: t.Optional[exp.Expression], parse_alias: bool = True - ) -> t.Optional[exp.Expression]: + ) -> t.Optional[exp.Subquery]: if not this: return None + return self.expression( exp.Subquery, this=this, @@ -2000,19 +2002,25 @@ class Parser(metaclass=_Parser): expression = parser(self) if expression: + if key == "limit": + offset = expression.args.pop("offset", None) + if offset: + this.set("offset", exp.Offset(expression=offset)) this.set(key, expression) return this - def _parse_hint(self) -> t.Optional[exp.Expression]: + def _parse_hint(self) -> t.Optional[exp.Hint]: if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) + if not self._match_pair(TokenType.STAR, TokenType.SLASH): self.raise_error("Expected */ after HINT") + return self.expression(exp.Hint, expressions=hints) return None - def _parse_into(self) -> t.Optional[exp.Expression]: + def _parse_into(self) -> t.Optional[exp.Into]: if not self._match(TokenType.INTO): return None @@ -2039,7 +2047,7 @@ class Parser(metaclass=_Parser): this=self._parse_query_modifiers(this) if modifiers else this, ) - def _parse_match_recognize(self) -> t.Optional[exp.Expression]: + def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]: if not self._match(TokenType.MATCH_RECOGNIZE): return None @@ -2052,7 +2060,7 @@ class Parser(metaclass=_Parser): ) if self._match_text_seq("ONE", "ROW", "PER", "MATCH"): - rows = exp.Var(this="ONE ROW PER MATCH") + rows = exp.var("ONE ROW PER MATCH") elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"): text = "ALL ROWS PER MATCH" if self._match_text_seq("SHOW", "EMPTY", "MATCHES"): @@ -2061,7 +2069,7 @@ class Parser(metaclass=_Parser): text += f" OMIT EMPTY MATCHES" elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"): text += f" WITH UNMATCHED ROWS" - rows = exp.Var(this=text) + rows = exp.var(text) else: rows = None @@ -2075,7 +2083,7 @@ class Parser(metaclass=_Parser): text += f" TO FIRST {self._advance_any().text}" # type: ignore elif self._match_text_seq("TO", "LAST"): text += f" TO LAST {self._advance_any().text}" # type: ignore - after = exp.Var(this=text) + after = exp.var(text) else: after = None @@ -2093,11 +2101,14 @@ class Parser(metaclass=_Parser): paren += 1 if self._curr.token_type == TokenType.R_PAREN: paren -= 1 + end = self._prev self._advance() + if paren > 0: self.raise_error("Expecting )", self._curr) - pattern = exp.Var(this=self._find_sql(start, end)) + + pattern = exp.var(self._find_sql(start, end)) else: pattern = None @@ -2127,7 +2138,7 @@ class Parser(metaclass=_Parser): alias=self._parse_table_alias(), ) - def _parse_lateral(self) -> t.Optional[exp.Expression]: + def _parse_lateral(self) -> t.Optional[exp.Lateral]: outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) @@ -2150,24 +2161,19 @@ class Parser(metaclass=_Parser): expression=self._parse_function() or self._parse_id_var(any_token=False), ) - table_alias: t.Optional[exp.Expression] - if view: table = self._parse_id_var(any_token=False) columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else [] - table_alias = self.expression(exp.TableAlias, this=table, columns=columns) + table_alias: t.Optional[exp.TableAlias] = self.expression( + exp.TableAlias, this=table, columns=columns + ) + elif isinstance(this, exp.Subquery) and this.alias: + # Ensures parity between the Subquery's and the Lateral's "alias" args + table_alias = this.args["alias"].copy() else: table_alias = self._parse_table_alias() - expression = self.expression( - exp.Lateral, - this=this, - view=view, - outer=outer, - alias=table_alias, - ) - - return expression + return self.expression(exp.Lateral, this=this, view=view, outer=outer, alias=table_alias) def _parse_join_parts( self, @@ -2178,7 +2184,7 @@ class Parser(metaclass=_Parser): self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]: if self._match(TokenType.COMMA): return self.expression(exp.Join, this=self._parse_table()) @@ -2223,7 +2229,7 @@ class Parser(metaclass=_Parser): def _parse_index( self, index: t.Optional[exp.Expression] = None, - ) -> t.Optional[exp.Expression]: + ) -> t.Optional[exp.Index]: if index: unique = None primary = None @@ -2236,11 +2242,15 @@ class Parser(metaclass=_Parser): unique = self._match(TokenType.UNIQUE) primary = self._match_text_seq("PRIMARY") amp = self._match_text_seq("AMP") + if not self._match(TokenType.INDEX): return None + index = self._parse_id_var() table = None + using = self._parse_field() if self._match(TokenType.USING) else None + if self._match(TokenType.L_PAREN, advance=False): columns = self._parse_wrapped_csv(self._parse_ordered) else: @@ -2250,6 +2260,7 @@ class Parser(metaclass=_Parser): exp.Index, this=index, table=table, + using=using, columns=columns, unique=unique, primary=primary, @@ -2259,7 +2270,7 @@ class Parser(metaclass=_Parser): def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: return ( - (not schema and self._parse_function()) + (not schema and self._parse_function(optional_parens=False)) or self._parse_id_var(any_token=False) or self._parse_string_as_identifier() or self._parse_placeholder() @@ -2314,7 +2325,7 @@ class Parser(metaclass=_Parser): if schema: return self._parse_schema(this=this) - if self.alias_post_tablesample: + if self.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) @@ -2331,7 +2342,7 @@ class Parser(metaclass=_Parser): ) self._match_r_paren() - if not self.alias_post_tablesample: + if not self.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() if table_sample: @@ -2340,46 +2351,47 @@ class Parser(metaclass=_Parser): return this - def _parse_unnest(self) -> t.Optional[exp.Expression]: + def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: if not self._match(TokenType.UNNEST): return None expressions = self._parse_wrapped_csv(self._parse_type) ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) - alias = self._parse_table_alias() - if alias and self.unnest_column_only: + alias = self._parse_table_alias() if with_alias else None + + if alias and self.UNNEST_COLUMN_ONLY: if alias.args.get("columns"): self.raise_error("Unexpected extra column alias in unnest.") + alias.set("columns", [alias.this]) alias.set("this", None) offset = None if self._match_pair(TokenType.WITH, TokenType.OFFSET): self._match(TokenType.ALIAS) - offset = self._parse_id_var() or exp.Identifier(this="offset") + offset = self._parse_id_var() or exp.to_identifier("offset") return self.expression( - exp.Unnest, - expressions=expressions, - ordinality=ordinality, - alias=alias, - offset=offset, + exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias, offset=offset ) - def _parse_derived_table_values(self) -> t.Optional[exp.Expression]: + def _parse_derived_table_values(self) -> t.Optional[exp.Values]: is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES) if not is_derived and not self._match(TokenType.VALUES): return None expressions = self._parse_csv(self._parse_value) + alias = self._parse_table_alias() if is_derived: self._match_r_paren() - return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias()) + return self.expression( + exp.Values, expressions=expressions, alias=alias or self._parse_table_alias() + ) - def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Expression]: + def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]: if not self._match(TokenType.TABLE_SAMPLE) and not ( as_modifier and self._match_text_seq("USING", "SAMPLE") ): @@ -2456,7 +2468,7 @@ class Parser(metaclass=_Parser): exp.Pivot, this=this, expressions=expressions, using=using, group=group ) - def _parse_pivot(self) -> t.Optional[exp.Expression]: + def _parse_pivot(self) -> t.Optional[exp.Pivot]: index = self._index if self._match(TokenType.PIVOT): @@ -2519,7 +2531,7 @@ class Parser(metaclass=_Parser): def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: return [agg.alias for agg in aggregations] - def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]: if not skip_where_token and not self._match(TokenType.WHERE): return None @@ -2527,7 +2539,7 @@ class Parser(metaclass=_Parser): exp.Where, comments=self._prev_comments, this=self._parse_conjunction() ) - def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Group]: if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None @@ -2578,12 +2590,12 @@ class Parser(metaclass=_Parser): return self._parse_column() - def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Having]: if not skip_having_token and not self._match(TokenType.HAVING): return None return self.expression(exp.Having, this=self._parse_conjunction()) - def _parse_qualify(self) -> t.Optional[exp.Expression]: + def _parse_qualify(self) -> t.Optional[exp.Qualify]: if not self._match(TokenType.QUALIFY): return None return self.expression(exp.Qualify, this=self._parse_conjunction()) @@ -2598,16 +2610,15 @@ class Parser(metaclass=_Parser): exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) ) - def _parse_sort( - self, exp_class: t.Type[exp.Expression], *texts: str - ) -> t.Optional[exp.Expression]: + def _parse_sort(self, exp_class: t.Type[E], *texts: str) -> t.Optional[E]: if not self._match_text_seq(*texts): return None return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) - def _parse_ordered(self) -> exp.Expression: + def _parse_ordered(self) -> exp.Ordered: this = self._parse_conjunction() self._match(TokenType.ASC) + is_desc = self._match(TokenType.DESC) is_nulls_first = self._match_text_seq("NULLS", "FIRST") is_nulls_last = self._match_text_seq("NULLS", "LAST") @@ -2615,13 +2626,14 @@ class Parser(metaclass=_Parser): asc = not desc nulls_first = is_nulls_first or False explicitly_null_ordered = is_nulls_first or is_nulls_last + if ( not explicitly_null_ordered and ( - (asc and self.null_ordering == "nulls_are_small") - or (desc and self.null_ordering != "nulls_are_small") + (asc and self.NULL_ORDERING == "nulls_are_small") + or (desc and self.NULL_ORDERING != "nulls_are_small") ) - and self.null_ordering != "nulls_are_last" + and self.NULL_ORDERING != "nulls_are_last" ): nulls_first = True @@ -2632,9 +2644,15 @@ class Parser(metaclass=_Parser): ) -> t.Optional[exp.Expression]: if self._match(TokenType.TOP if top else TokenType.LIMIT): limit_paren = self._match(TokenType.L_PAREN) - limit_exp = self.expression( - exp.Limit, this=this, expression=self._parse_number() if top else self._parse_term() - ) + expression = self._parse_number() if top else self._parse_term() + + if self._match(TokenType.COMMA): + offset = expression + expression = self._parse_term() + else: + offset = None + + limit_exp = self.expression(exp.Limit, this=this, expression=expression, offset=offset) if limit_paren: self._match_r_paren() @@ -2667,17 +2685,15 @@ class Parser(metaclass=_Parser): return this def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: - if not self._match_set((TokenType.OFFSET, TokenType.COMMA)): + if not self._match(TokenType.OFFSET): return this count = self._parse_number() self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) - def _parse_locks(self) -> t.List[exp.Expression]: - # Lists are invariant, so we need to use a type hint here - locks: t.List[exp.Expression] = [] - + def _parse_locks(self) -> t.List[exp.Lock]: + locks = [] while True: if self._match_text_seq("FOR", "UPDATE"): update = True @@ -2768,6 +2784,7 @@ class Parser(metaclass=_Parser): def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: index = self._index - 1 negate = self._match(TokenType.NOT) + if self._match_text_seq("DISTINCT", "FROM"): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ return self.expression(klass, this=this, expression=self._parse_expression()) @@ -2781,7 +2798,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Not, this=this) if negate else this def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In: - unnest = self._parse_unnest() + unnest = self._parse_unnest(with_alias=False) if unnest: this = self.expression(exp.In, this=this, unnest=unnest) elif self._match(TokenType.L_PAREN): @@ -2798,7 +2815,7 @@ class Parser(metaclass=_Parser): return this - def _parse_between(self, this: exp.Expression) -> exp.Expression: + def _parse_between(self, this: exp.Expression) -> exp.Between: low = self._parse_bitwise() self._match(TokenType.AND) high = self._parse_bitwise() @@ -2809,7 +2826,7 @@ class Parser(metaclass=_Parser): return this return self.expression(exp.Escape, this=this, expression=self._parse_string()) - def _parse_interval(self) -> t.Optional[exp.Expression]: + def _parse_interval(self) -> t.Optional[exp.Interval]: if not self._match(TokenType.INTERVAL): return None @@ -2840,9 +2857,7 @@ class Parser(metaclass=_Parser): while True: if self._match_set(self.BITWISE): this = self.expression( - self.BITWISE[self._prev.token_type], - this=this, - expression=self._parse_term(), + self.BITWISE[self._prev.token_type], this=this, expression=self._parse_term() ) elif self._match_pair(TokenType.LT, TokenType.LT): this = self.expression( @@ -2890,7 +2905,7 @@ class Parser(metaclass=_Parser): return this - def _parse_type_size(self) -> t.Optional[exp.Expression]: + def _parse_type_size(self) -> t.Optional[exp.DataTypeSize]: this = self._parse_type() if not this: return None @@ -2926,6 +2941,8 @@ class Parser(metaclass=_Parser): expressions = self._parse_csv( lambda: self._parse_types(check_func=check_func, schema=schema) ) + elif type_token in self.ENUM_TYPE_TOKENS: + expressions = self._parse_csv(self._parse_primary) else: expressions = self._parse_csv(self._parse_type_size) @@ -2943,11 +2960,7 @@ class Parser(metaclass=_Parser): ) while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - this = exp.DataType( - this=exp.DataType.Type.ARRAY, - expressions=[this], - nested=True, - ) + this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True) return this @@ -2973,23 +2986,14 @@ class Parser(metaclass=_Parser): value: t.Optional[exp.Expression] = None if type_token in self.TIMESTAMPS: - if self._match_text_seq("WITH", "TIME", "ZONE") or type_token == TokenType.TIMESTAMPTZ: + if self._match_text_seq("WITH", "TIME", "ZONE"): + maybe_func = False value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) - elif ( - self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE") - or type_token == TokenType.TIMESTAMPLTZ - ): + elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"): + maybe_func = False value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) elif self._match_text_seq("WITHOUT", "TIME", "ZONE"): - if type_token == TokenType.TIME: - value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions) - else: - value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) - - maybe_func = maybe_func and value is None - - if value is None: - value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) + maybe_func = False elif type_token == TokenType.INTERVAL: unit = self._parse_var() @@ -3037,7 +3041,7 @@ class Parser(metaclass=_Parser): return self._parse_bracket(this) return self._parse_column_ops(this) - def _parse_column_ops(self, this: exp.Expression) -> exp.Expression: + def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: this = self._parse_bracket(this) while self._match_set(self.COLUMN_OPERATORS): @@ -3057,7 +3061,7 @@ class Parser(metaclass=_Parser): else exp.Literal.string(value) ) else: - field = self._parse_field(anonymous_func=True) + field = self._parse_field(anonymous_func=True, any_token=True) if isinstance(field, exp.Func): # bigquery allows function calls like x.y.count(...) @@ -3089,8 +3093,10 @@ class Parser(metaclass=_Parser): expressions = [primary] while self._match(TokenType.STRING): expressions.append(exp.Literal.string(self._prev.text)) + if len(expressions) > 1: return self.expression(exp.Concat, expressions=expressions) + return primary if self._match_pair(TokenType.DOT, TokenType.NUMBER): @@ -3118,8 +3124,8 @@ class Parser(metaclass=_Parser): if this: this.add_comments(comments) - self._match_r_paren(expression=this) + self._match_r_paren(expression=this) return this return None @@ -3137,18 +3143,21 @@ class Parser(metaclass=_Parser): ) def _parse_function( - self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False + self, + functions: t.Optional[t.Dict[str, t.Callable]] = None, + anonymous: bool = False, + optional_parens: bool = True, ) -> t.Optional[exp.Expression]: if not self._curr: return None token_type = self._curr.token_type - if self._match_set(self.NO_PAREN_FUNCTION_PARSERS): + if optional_parens and self._match_set(self.NO_PAREN_FUNCTION_PARSERS): return self.NO_PAREN_FUNCTION_PARSERS[token_type](self) if not self._next or self._next.token_type != TokenType.L_PAREN: - if token_type in self.NO_PAREN_FUNCTIONS: + if optional_parens and token_type in self.NO_PAREN_FUNCTIONS: self._advance() return self.expression(self.NO_PAREN_FUNCTIONS[token_type]) @@ -3182,8 +3191,7 @@ class Parser(metaclass=_Parser): args = self._parse_csv(lambda: self._parse_lambda(alias=alias)) if function and not anonymous: - this = function(args) - self.validate_expression(this, args) + this = self.validate_expression(function(args), args) else: this = self.expression(exp.Anonymous, this=this, expressions=args) @@ -3210,14 +3218,14 @@ class Parser(metaclass=_Parser): exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True ) - def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]: + def _parse_introducer(self, token: Token) -> exp.Introducer | exp.Identifier: literal = self._parse_primary() if literal: return self.expression(exp.Introducer, this=token.text, expression=literal) return self.expression(exp.Identifier, this=token.text) - def _parse_session_parameter(self) -> exp.Expression: + def _parse_session_parameter(self) -> exp.SessionParameter: kind = None this = self._parse_id_var() or self._parse_primary() @@ -3255,7 +3263,7 @@ class Parser(metaclass=_Parser): if isinstance(this, exp.EQ): left = this.this if isinstance(left, exp.Column): - left.replace(exp.Var(this=left.text("this"))) + left.replace(exp.var(left.text("this"))) return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this))) @@ -3279,6 +3287,7 @@ class Parser(metaclass=_Parser): lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(any_token=True)) ) + self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -3286,6 +3295,7 @@ class Parser(metaclass=_Parser): # column defs are not really columns, they're identifiers if isinstance(this, exp.Column): this = this.this + kind = self._parse_types(schema=True) if self._match_text_seq("FOR", "ORDINALITY"): @@ -3303,7 +3313,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) - def _parse_auto_increment(self) -> exp.Expression: + def _parse_auto_increment( + self, + ) -> exp.GeneratedAsIdentityColumnConstraint | exp.AutoIncrementColumnConstraint: start = None increment = None @@ -3321,7 +3333,7 @@ class Parser(metaclass=_Parser): return exp.AutoIncrementColumnConstraint() - def _parse_compress(self) -> exp.Expression: + def _parse_compress(self) -> exp.CompressColumnConstraint: if self._match(TokenType.L_PAREN, advance=False): return self.expression( exp.CompressColumnConstraint, this=self._parse_wrapped_csv(self._parse_bitwise) @@ -3329,7 +3341,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise()) - def _parse_generated_as_identity(self) -> exp.Expression: + def _parse_generated_as_identity(self) -> exp.GeneratedAsIdentityColumnConstraint: if self._match_text_seq("BY", "DEFAULT"): on_null = self._match_pair(TokenType.ON, TokenType.NULL) this = self.expression( @@ -3364,11 +3376,13 @@ class Parser(metaclass=_Parser): return this - def _parse_inline(self) -> t.Optional[exp.Expression]: + def _parse_inline(self) -> exp.InlineLengthColumnConstraint: self._match_text_seq("LENGTH") return self.expression(exp.InlineLengthColumnConstraint, this=self._parse_bitwise()) - def _parse_not_constraint(self) -> t.Optional[exp.Expression]: + def _parse_not_constraint( + self, + ) -> t.Optional[exp.NotNullColumnConstraint | exp.CaseSpecificColumnConstraint]: if self._match_text_seq("NULL"): return self.expression(exp.NotNullColumnConstraint) if self._match_text_seq("CASESPECIFIC"): @@ -3417,7 +3431,7 @@ class Parser(metaclass=_Parser): return self.CONSTRAINT_PARSERS[constraint](self) - def _parse_unique(self) -> exp.Expression: + def _parse_unique(self) -> exp.UniqueColumnConstraint: self._match_text_seq("KEY") return self.expression( exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False)) @@ -3460,7 +3474,7 @@ class Parser(metaclass=_Parser): return options - def _parse_references(self, match: bool = True) -> t.Optional[exp.Expression]: + def _parse_references(self, match: bool = True) -> t.Optional[exp.Reference]: if match and not self._match(TokenType.REFERENCES): return None @@ -3473,7 +3487,7 @@ class Parser(metaclass=_Parser): options = self._parse_key_constraint_options() return self.expression(exp.Reference, this=this, expressions=expressions, options=options) - def _parse_foreign_key(self) -> exp.Expression: + def _parse_foreign_key(self) -> exp.ForeignKey: expressions = self._parse_wrapped_id_vars() reference = self._parse_references() options = {} @@ -3501,7 +3515,7 @@ class Parser(metaclass=_Parser): def _parse_primary_key( self, wrapped_optional: bool = False, in_props: bool = False - ) -> exp.Expression: + ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey: desc = ( self._match_set((TokenType.ASC, TokenType.DESC)) and self._prev.token_type == TokenType.DESC @@ -3514,15 +3528,7 @@ class Parser(metaclass=_Parser): options = self._parse_key_constraint_options() return self.expression(exp.PrimaryKey, expressions=expressions, options=options) - @t.overload - def _parse_bracket(self, this: exp.Expression) -> exp.Expression: - ... - - @t.overload def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - ... - - def _parse_bracket(self, this): if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): return this @@ -3541,7 +3547,7 @@ class Parser(metaclass=_Parser): elif not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) else: - expressions = apply_index_offset(this, expressions, -self.index_offset) + expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET) this = self.expression(exp.Bracket, this=this, expressions=expressions) if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET: @@ -3582,8 +3588,7 @@ class Parser(metaclass=_Parser): def _parse_if(self) -> t.Optional[exp.Expression]: if self._match(TokenType.L_PAREN): args = self._parse_csv(self._parse_conjunction) - this = exp.If.from_arg_list(args) - self.validate_expression(this, args) + this = self.validate_expression(exp.If.from_arg_list(args), args) self._match_r_paren() else: index = self._index - 1 @@ -3601,7 +3606,7 @@ class Parser(metaclass=_Parser): return self._parse_window(this) - def _parse_extract(self) -> exp.Expression: + def _parse_extract(self) -> exp.Extract: this = self._parse_function() or self._parse_var() or self._parse_type() if self._match(TokenType.FROM): @@ -3630,9 +3635,37 @@ class Parser(metaclass=_Parser): elif to.this == exp.DataType.Type.CHAR: if self._match(TokenType.CHARACTER_SET): to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) + elif to.this in exp.DataType.TEMPORAL_TYPES and self._match(TokenType.FORMAT): + fmt = self._parse_string() + + return self.expression( + exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime, + this=this, + format=exp.Literal.string( + format_time( + fmt.this if fmt else "", + self.FORMAT_MAPPING or self.TIME_MAPPING, + self.FORMAT_TRIE or self.TIME_TRIE, + ) + ), + ) return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_concat(self) -> t.Optional[exp.Expression]: + args = self._parse_csv(self._parse_conjunction) + if self.CONCAT_NULL_OUTPUTS_STRING: + args = [exp.func("COALESCE", arg, exp.Literal.string("")) for arg in args] + + # Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when + # we find such a call we replace it with its argument. + if len(args) == 1: + return args[0] + + return self.expression( + exp.Concat if self.STRICT_STRING_CONCAT else exp.SafeConcat, expressions=args + ) + def _parse_string_agg(self) -> exp.Expression: expression: t.Optional[exp.Expression] @@ -3654,9 +3687,7 @@ class Parser(metaclass=_Parser): # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them. if not self._match_text_seq("WITHIN", "GROUP"): self._retreat(index) - this = exp.GroupConcat.from_arg_list(args) - self.validate_expression(this, args) - return this + return self.validate_expression(exp.GroupConcat.from_arg_list(args), args) self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller) order = self._parse_order(this=expression) @@ -3679,7 +3710,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_decode(self) -> t.Optional[exp.Expression]: + def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]: """ There are generally two variants of the DECODE function: @@ -3726,18 +3757,20 @@ class Parser(metaclass=_Parser): return exp.Case(ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None) - def _parse_json_key_value(self) -> t.Optional[exp.Expression]: + def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]: self._match_text_seq("KEY") key = self._parse_field() self._match(TokenType.COLON) self._match_text_seq("VALUE") value = self._parse_field() + if not key and not value: return None return self.expression(exp.JSONKeyValue, this=key, expression=value) - def _parse_json_object(self) -> exp.Expression: - expressions = self._parse_csv(self._parse_json_key_value) + def _parse_json_object(self) -> exp.JSONObject: + star = self._parse_star() + expressions = [star] if star else self._parse_csv(self._parse_json_key_value) null_handling = None if self._match_text_seq("NULL", "ON", "NULL"): @@ -3767,7 +3800,7 @@ class Parser(metaclass=_Parser): encoding=encoding, ) - def _parse_logarithm(self) -> exp.Expression: + def _parse_logarithm(self) -> exp.Func: # Default argument order is base, expression args = self._parse_csv(self._parse_range) @@ -3780,7 +3813,7 @@ class Parser(metaclass=_Parser): exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0) ) - def _parse_match_against(self) -> exp.Expression: + def _parse_match_against(self) -> exp.MatchAgainst: expressions = self._parse_csv(self._parse_column) self._match_text_seq(")", "AGAINST", "(") @@ -3803,15 +3836,16 @@ class Parser(metaclass=_Parser): ) # https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16 - def _parse_open_json(self) -> exp.Expression: + def _parse_open_json(self) -> exp.OpenJSON: this = self._parse_bitwise() path = self._match(TokenType.COMMA) and self._parse_string() - def _parse_open_json_column_def() -> exp.Expression: + def _parse_open_json_column_def() -> exp.OpenJSONColumnDef: this = self._parse_field(any_token=True) kind = self._parse_types() path = self._parse_string() as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON) + return self.expression( exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json ) @@ -3823,7 +3857,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.OpenJSON, this=this, path=path, expressions=expressions) - def _parse_position(self, haystack_first: bool = False) -> exp.Expression: + def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: args = self._parse_csv(self._parse_bitwise) if self._match(TokenType.IN): @@ -3838,17 +3872,15 @@ class Parser(metaclass=_Parser): needle = seq_get(args, 0) haystack = seq_get(args, 1) - this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2)) - - self.validate_expression(this, args) - - return this + return self.expression( + exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2) + ) - def _parse_join_hint(self, func_name: str) -> exp.Expression: + def _parse_join_hint(self, func_name: str) -> exp.JoinHint: args = self._parse_csv(self._parse_table) return exp.JoinHint(this=func_name.upper(), expressions=args) - def _parse_substring(self) -> exp.Expression: + def _parse_substring(self) -> exp.Substring: # Postgres supports the form: substring(string [from int] [for int]) # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 @@ -3859,12 +3891,9 @@ class Parser(metaclass=_Parser): if self._match(TokenType.FOR): args.append(self._parse_bitwise()) - this = exp.Substring.from_arg_list(args) - self.validate_expression(this, args) - - return this + return self.validate_expression(exp.Substring.from_arg_list(args), args) - def _parse_trim(self) -> exp.Expression: + def _parse_trim(self) -> exp.Trim: # https://www.w3resource.com/sql/character-functions/trim.php # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html @@ -3885,11 +3914,7 @@ class Parser(metaclass=_Parser): collation = self._parse_bitwise() return self.expression( - exp.Trim, - this=this, - position=position, - expression=expression, - collation=collation, + exp.Trim, this=this, position=position, expression=expression, collation=collation ) def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: @@ -4047,7 +4072,7 @@ class Parser(metaclass=_Parser): return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) return self._parse_placeholder() - def _parse_string_as_identifier(self) -> t.Optional[exp.Expression]: + def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]: return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True) def _parse_number(self) -> t.Optional[exp.Expression]: @@ -4097,7 +4122,7 @@ class Parser(metaclass=_Parser): return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) return None - def _parse_parameter(self) -> exp.Expression: + def _parse_parameter(self) -> exp.Parameter: wrapped = self._match(TokenType.L_BRACE) this = self._parse_var() or self._parse_identifier() or self._parse_primary() self._match(TokenType.R_BRACE) @@ -4183,7 +4208,7 @@ class Parser(metaclass=_Parser): self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False)) ) - def _parse_transaction(self) -> exp.Expression: + def _parse_transaction(self) -> exp.Transaction: this = None if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text @@ -4203,7 +4228,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Transaction, this=this, modes=modes) - def _parse_commit_or_rollback(self) -> exp.Expression: + def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: chain = None savepoint = None is_rollback = self._prev.token_type == TokenType.ROLLBACK @@ -4220,6 +4245,7 @@ class Parser(metaclass=_Parser): if is_rollback: return self.expression(exp.Rollback, savepoint=savepoint) + return self.expression(exp.Commit, chain=chain) def _parse_add_column(self) -> t.Optional[exp.Expression]: @@ -4243,19 +4269,19 @@ class Parser(metaclass=_Parser): return expression - def _parse_drop_column(self) -> t.Optional[exp.Expression]: + def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]: drop = self._match(TokenType.DROP) and self._parse_drop() if drop and not isinstance(drop, exp.Command): drop.set("kind", drop.args.get("kind", "COLUMN")) return drop # https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html - def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression: + def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.DropPartition: return self.expression( exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists ) - def _parse_add_constraint(self) -> t.Optional[exp.Expression]: + def _parse_add_constraint(self) -> exp.AddConstraint: this = None kind = self._prev.token_type @@ -4288,7 +4314,7 @@ class Parser(metaclass=_Parser): self._retreat(index) return self._parse_csv(self._parse_add_column) - def _parse_alter_table_alter(self) -> exp.Expression: + def _parse_alter_table_alter(self) -> exp.AlterColumn: self._match(TokenType.COLUMN) column = self._parse_field(any_token=True) @@ -4316,11 +4342,11 @@ class Parser(metaclass=_Parser): self._retreat(index) return self._parse_csv(self._parse_drop_column) - def _parse_alter_table_rename(self) -> exp.Expression: + def _parse_alter_table_rename(self) -> exp.RenameTable: self._match_text_seq("TO") return self.expression(exp.RenameTable, this=self._parse_table(schema=True)) - def _parse_alter(self) -> t.Optional[exp.Expression]: + def _parse_alter(self) -> exp.AlterTable | exp.Command: start = self._prev if not self._match(TokenType.TABLE): @@ -4345,7 +4371,7 @@ class Parser(metaclass=_Parser): ) return self._parse_as_command(start) - def _parse_merge(self) -> exp.Expression: + def _parse_merge(self) -> exp.Merge: self._match(TokenType.INTO) target = self._parse_table() @@ -4412,7 +4438,7 @@ class Parser(metaclass=_Parser): ) def _parse_show(self) -> t.Optional[exp.Expression]: - parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore + parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE) if parser: return parser(self) self._advance() @@ -4433,17 +4459,9 @@ class Parser(metaclass=_Parser): return None right = self._parse_statement() or self._parse_id_var() - this = self.expression( - exp.EQ, - this=left, - expression=right, - ) + this = self.expression(exp.EQ, this=left, expression=right) - return self.expression( - exp.SetItem, - this=this, - kind=kind, - ) + return self.expression(exp.SetItem, this=this, kind=kind) def _parse_set_transaction(self, global_: bool = False) -> exp.Expression: self._match_text_seq("TRANSACTION") @@ -4458,10 +4476,10 @@ class Parser(metaclass=_Parser): ) def _parse_set_item(self) -> t.Optional[exp.Expression]: - parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore + parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE) return parser(self) if parser else self._parse_set_item_assignment(kind=None) - def _parse_set(self) -> exp.Expression: + def _parse_set(self) -> exp.Set | exp.Command: index = self._index set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) @@ -4471,10 +4489,10 @@ class Parser(metaclass=_Parser): return set_ - def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Expression]: + def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Var]: for option in options: if self._match_text_seq(*option.split(" ")): - return exp.Var(this=option) + return exp.var(option) return None def _parse_as_command(self, start: Token) -> exp.Command: diff --git a/sqlglot/planner.py b/sqlglot/planner.py index eccad35..4ed7449 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -302,7 +302,7 @@ class Join(Step): for join in joins: source_key, join_key, condition = join_condition(join) - step.joins[join.this.alias_or_name] = { + step.joins[join.alias_or_name] = { "side": join.side, # type: ignore "join_key": join_key, "source_key": source_key, diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f1c4a09..f73adee 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -285,8 +285,6 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): elif isinstance(column_type, str): return self._to_data_type(column_type.upper(), dialect=dialect) - raise SchemaError(f"Unknown column type '{column_type}'") - return exp.DataType.build("unknown") def _normalize(self, schema: t.Dict) -> t.Dict: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index a30ec24..42628b9 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -144,6 +144,7 @@ class TokenType(AutoName): VARIANT = auto() OBJECT = auto() INET = auto() + ENUM = auto() # keywords ALIAS = auto() @@ -346,6 +347,7 @@ class Token: col: The column that the token ends on. start: The start index of the token. end: The ending index of the token. + comments: The comments to attach to the token. """ self.token_type = token_type self.text = text @@ -391,12 +393,15 @@ class _Tokenizer(type): klass._STRING_ESCAPES = set(klass.STRING_ESCAPES) klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES) - klass._COMMENTS = dict( - (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) - for comment in klass.COMMENTS - ) + klass._COMMENTS = { + **dict( + (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) + for comment in klass.COMMENTS + ), + "{#": "#}", # Ensure Jinja comments are tokenized correctly in all dialects + } - klass.KEYWORD_TRIE = new_trie( + klass._KEYWORD_TRIE = new_trie( key.upper() for key in ( *klass.KEYWORDS, @@ -456,20 +461,22 @@ class Tokenizer(metaclass=_Tokenizer): STRING_ESCAPES = ["'"] VAR_SINGLE_TOKENS: t.Set[str] = set() + # Autofilled + IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False + _COMMENTS: t.Dict[str, str] = {} _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {} _IDENTIFIERS: t.Dict[str, str] = {} _IDENTIFIER_ESCAPES: t.Set[str] = set() _QUOTES: t.Dict[str, str] = {} _STRING_ESCAPES: t.Set[str] = set() + _KEYWORD_TRIE: t.Dict = {} - KEYWORDS: t.Dict[t.Optional[str], TokenType] = { + KEYWORDS: t.Dict[str, TokenType] = { **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")}, **{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")}, - "{{+": TokenType.BLOCK_START, - "{{-": TokenType.BLOCK_START, - "+}}": TokenType.BLOCK_END, - "-}}": TokenType.BLOCK_END, + **{f"{{{{{postfix}": TokenType.BLOCK_START for postfix in ("+", "-")}, + **{f"{prefix}}}}}": TokenType.BLOCK_END for prefix in ("+", "-")}, "/*+": TokenType.HINT, "==": TokenType.EQ, "::": TokenType.DCOLON, @@ -594,6 +601,7 @@ class Tokenizer(metaclass=_Tokenizer): "RECURSIVE": TokenType.RECURSIVE, "REGEXP": TokenType.RLIKE, "REPLACE": TokenType.REPLACE, + "RETURNING": TokenType.RETURNING, "REFERENCES": TokenType.REFERENCES, "RIGHT": TokenType.RIGHT, "RLIKE": TokenType.RLIKE, @@ -732,8 +740,7 @@ class Tokenizer(metaclass=_Tokenizer): NUMERIC_LITERALS: t.Dict[str, str] = {} ENCODE: t.Optional[str] = None - COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")] - KEYWORD_TRIE: t.Dict = {} # autofilled + COMMENTS = ["--", ("/*", "*/")] __slots__ = ( "sql", @@ -748,7 +755,6 @@ class Tokenizer(metaclass=_Tokenizer): "_end", "_peek", "_prev_token_line", - "identifiers_can_start_with_digit", ) def __init__(self) -> None: @@ -894,7 +900,7 @@ class Tokenizer(metaclass=_Tokenizer): char = chars prev_space = False skip = False - trie = self.KEYWORD_TRIE + trie = self._KEYWORD_TRIE single_token = char in self.SINGLE_TOKENS while chars: @@ -994,7 +1000,7 @@ class Tokenizer(metaclass=_Tokenizer): self._advance() elif self._peek == "." and not decimal: after = self.peek(1) - if after.isdigit() or not after.strip(): + if after.isdigit() or not after.isalpha(): decimal = True self._advance() else: @@ -1013,13 +1019,13 @@ class Tokenizer(metaclass=_Tokenizer): literal += self._peek.upper() self._advance() - token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) + token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal, "")) if token_type: self._add(TokenType.NUMBER, number_text) self._add(TokenType.DCOLON, "::") return self._add(token_type, literal) - elif self.identifiers_can_start_with_digit: # type: ignore + elif self.IDENTIFIERS_CAN_START_WITH_DIGIT: return self._add(TokenType.VAR) self._add(TokenType.NUMBER, number_text) |