From 8d36f5966675e23bee7026ba37ae0647fbf47300 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Mon, 8 Apr 2024 10:11:53 +0200
Subject: Merging upstream version 23.7.0.

Signed-off-by: Daniel Baumann
---
 sqlglot/dialects/__init__.py   |   2 +
 sqlglot/dialects/athena.py     |  37 ++++++
 sqlglot/dialects/bigquery.py   |  54 ++++++-----
 sqlglot/dialects/clickhouse.py |  77 ++++++++++++++--
 sqlglot/dialects/dialect.py    | 115 ++++++++++++++++++-----
 sqlglot/dialects/doris.py      |  14 +--
 sqlglot/dialects/drill.py      |  16 +---
 sqlglot/dialects/duckdb.py     |  42 +++++++--
 sqlglot/dialects/hive.py       |   7 +-
 sqlglot/dialects/mysql.py      |  82 +++++++++++------
 sqlglot/dialects/oracle.py     |   2 +
 sqlglot/dialects/postgres.py   |  12 +++
 sqlglot/dialects/presto.py     |  40 +++++---
 sqlglot/dialects/prql.py       | 109 ++++++++++++++++++++++
 sqlglot/dialects/redshift.py   |  22 +----
 sqlglot/dialects/snowflake.py  | 201 ++++++++++++++++++++++++++++-------
 sqlglot/dialects/spark.py      |  11 ++-
 sqlglot/dialects/spark2.py     |   9 +-
 sqlglot/dialects/sqlite.py     |  12 +++
 sqlglot/dialects/starrocks.py  |   7 +-
 sqlglot/dialects/tableau.py    |   2 +
 sqlglot/dialects/teradata.py   |   9 +-
 sqlglot/dialects/trino.py      |   1 +
 sqlglot/dialects/tsql.py       |  24 ++++-
 24 files changed, 685 insertions(+), 222 deletions(-)
 create mode 100644 sqlglot/dialects/athena.py
 create mode 100644 sqlglot/dialects/prql.py

(limited to 'sqlglot/dialects')

diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 82552c9..29c6580 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -61,6 +61,7 @@ dialect implementations in order to understand how their various components can
 ----
 """
 
+from sqlglot.dialects.athena import Athena
 from sqlglot.dialects.bigquery import BigQuery
 from sqlglot.dialects.clickhouse import ClickHouse
 from sqlglot.dialects.databricks import Databricks
@@ -73,6 +74,7 @@ from sqlglot.dialects.mysql import MySQL
 from sqlglot.dialects.oracle import Oracle
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.prql import PRQL
 from sqlglot.dialects.redshift import Redshift
 from sqlglot.dialects.snowflake import Snowflake
 from sqlglot.dialects.spark import Spark
diff --git a/sqlglot/dialects/athena.py b/sqlglot/dialects/athena.py
new file mode 100644
index 0000000..f2deec8
--- /dev/null
+++ b/sqlglot/dialects/athena.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from sqlglot import exp
+from sqlglot.dialects.trino import Trino
+from sqlglot.tokens import TokenType
+
+
+class Athena(Trino):
+    class Parser(Trino.Parser):
+        STATEMENT_PARSERS = {
+            **Trino.Parser.STATEMENT_PARSERS,
+            TokenType.USING: lambda self: self._parse_as_command(self._prev),
+        }
+
+    class Generator(Trino.Generator):
+        PROPERTIES_LOCATION = {
+            **Trino.Generator.PROPERTIES_LOCATION,
+            exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
+        }
+
+        TYPE_MAPPING = {
+            **Trino.Generator.TYPE_MAPPING,
+            exp.DataType.Type.TEXT: "STRING",
+        }
+
+        TRANSFORMS = {
+            **Trino.Generator.TRANSFORMS,
+            exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}",
+        }
+
+        def property_sql(self, expression: exp.Property) -> str:
+            return (
+                f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"
+            )
+
+        def with_properties(self, properties: exp.Properties) -> str:
+            return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5bfc3ea..2167ba2 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -24,6 +24,7 @@ from sqlglot.dialects.dialect import (
     rename_func,
     timestrtotime_sql,
     ts_or_ds_add_cast,
+    unit_to_var,
 )
 from sqlglot.helper import seq_get, split_num_words
 from sqlglot.tokens import TokenType
@@ -41,14 +42,22 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
     structs = []
     alias = expression.args.get("alias")
     for tup in expression.find_all(exp.Tuple):
-        field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
+        field_aliases = (
+            alias.columns
+            if alias and alias.columns
+            else (f"_c{i}" for i in range(len(tup.expressions)))
+        )
         expressions = [
             exp.PropertyEQ(this=exp.to_identifier(name), expression=fld)
             for name, fld in zip(field_aliases, tup.expressions)
         ]
         structs.append(exp.Struct(expressions=expressions))
 
-    return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))
+    # Due to `UNNEST_COLUMN_ONLY`, it is expected that the table alias be contained in the columns expression
+    alias_name_only = exp.TableAlias(columns=[alias.this]) if alias else None
+    return self.unnest_sql(
+        exp.Unnest(expressions=[exp.array(*structs, copy=False)], alias=alias_name_only)
+    )
 
 
 def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
@@ -190,7 +199,7 @@ def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> st
 
 
 def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
     expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
     expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
-    unit = expression.args.get("unit") or "DAY"
+    unit = unit_to_var(expression)
     return self.func("DATE_DIFF", expression.this, expression.expression, unit)
@@ -238,16 +247,6 @@ class BigQuery(Dialect):
         "%E6S": "%S.%f",
     }
 
-    ESCAPE_SEQUENCES = {
-        "\\a": "\a",
-        "\\b": "\b",
-        "\\f": "\f",
-        "\\n": "\n",
-        "\\r": "\r",
-        "\\t": "\t",
-        "\\v": "\v",
-    }
-
     FORMAT_MAPPING = {
         "DD": "%d",
         "MM": "%m",
@@ -315,6 +314,7 @@ class BigQuery(Dialect):
             "BEGIN TRANSACTION": TokenType.BEGIN,
             "BYTES": TokenType.BINARY,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
+            "DATETIME": TokenType.TIMESTAMP,
             "DECLARE": TokenType.COMMAND,
             "ELSEIF": TokenType.COMMAND,
             "EXCEPTION": TokenType.COMMAND,
@@ -486,14 +486,14 @@ class BigQuery(Dialect):
                 table.set("db", exp.Identifier(this=parts[0]))
                 table.set("this", exp.Identifier(this=parts[1]))
 
-            if isinstance(table.this, exp.Identifier) and "." in table.name:
+            if any("." in p.name for p in table.parts):
                 catalog, db, this, *rest = (
-                    t.cast(t.Optional[exp.Expression], exp.to_identifier(x, quoted=True))
-                    for x in split_num_words(table.name, ".", 3)
+                    exp.to_identifier(p, quoted=True)
+                    for p in split_num_words(".".join(p.name for p in table.parts), ".", 3)
                 )
 
                 if rest and this:
-                    this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
+                    this = exp.Dot.build([this, *rest])  # type: ignore
 
                 table = exp.Table(this=this, db=db, catalog=catalog)
                 table.meta["quoted_table"] = True
@@ -527,7 +527,9 @@ class BigQuery(Dialect):
 
             return json_object
 
-        def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+        def _parse_bracket(
+            self, this: t.Optional[exp.Expression] = None
+        ) -> t.Optional[exp.Expression]:
             bracket = super()._parse_bracket(this)
 
             if this is bracket:
@@ -566,6 +568,7 @@ class BigQuery(Dialect):
         IGNORE_NULLS_IN_FUNC = True
         JSON_PATH_SINGLE_QUOTE_ESCAPE = True
         CAN_IMPLEMENT_ARRAY_ANY = True
+        SUPPORTS_TO_NUMBER = False
         NAMED_PLACEHOLDER_TOKEN = "@"
 
         TRANSFORMS = {
@@ -588,7 +591,7 @@ class BigQuery(Dialect):
             exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
             exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", e.this, e.expression, e.unit or "DAY"
+                "DATE_DIFF", e.this, e.expression, unit_to_var(e)
             ),
             exp.DateFromParts: rename_func("DATE"),
             exp.DateStrToDate: datestrtodate_sql,
@@ -607,6 +610,7 @@ class BigQuery(Dialect):
             exp.IntDiv: rename_func("DIV"),
             exp.JSONFormat: rename_func("TO_JSON_STRING"),
             exp.Max: max_or_greatest,
+            exp.Mod: rename_func("MOD"),
             exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
             exp.MD5Digest: rename_func("MD5"),
             exp.Min: min_or_least,
@@ -847,10 +851,10 @@ class BigQuery(Dialect):
             return inline_array_sql(self, expression)
 
         def bracket_sql(self, expression: exp.Bracket) -> str:
-            this = self.sql(expression, "this")
+            this = expression.this
             expressions = expression.expressions
 
-            if len(expressions) == 1:
+            if len(expressions) == 1 and this and this.is_type(exp.DataType.Type.STRUCT):
                 arg = expressions[0]
                 if arg.type is None:
                     from sqlglot.optimizer.annotate_types import annotate_types
@@ -858,10 +862,10 @@ class BigQuery(Dialect):
                     arg = annotate_types(arg)
 
                 if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
-                    # BQ doesn't support bracket syntax with string values
-                    return f"{this}.{arg.name}"
+                    # BQ doesn't support bracket syntax with string values for structs
+                    return f"{self.sql(this)}.{arg.name}"
 
-            expressions_sql = ", ".join(self.sql(e) for e in expressions)
+            expressions_sql = self.expressions(expression, flat=True)
             offset = expression.args.get("offset")
 
             if offset == 0:
@@ -874,7 +878,7 @@ class BigQuery(Dialect):
             if expression.args.get("safe"):
                 expressions_sql = f"SAFE_{expressions_sql}"
 
-            return f"{this}[{expressions_sql}]"
+            return f"{self.sql(this)}[{expressions_sql}]"
 
         def in_unnest_op(self, expression: exp.Unnest) -> str:
             return self.sql(expression)
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 90167f6..631dc30 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -15,7 +15,6 @@ from sqlglot.dialects.dialect import (
     rename_func,
     var_map_sql,
 )
-from sqlglot.errors import ParseError
 from sqlglot.helper import is_int, seq_get
 from sqlglot.tokens import Token, TokenType
 
@@ -49,8 +48,9 @@ class ClickHouse(Dialect):
     NULL_ORDERING = "nulls_are_last"
     SUPPORTS_USER_DEFINED_TYPES = False
     SAFE_DIVISION = True
+    LOG_BASE_FIRST: t.Optional[bool] = None
 
-    ESCAPE_SEQUENCES = {
+    UNESCAPED_SEQUENCES = {
         "\\0": "\0",
     }
 
@@ -105,6 +105,7 @@ class ClickHouse(Dialect):
         # * select x from t1 union all select x from t2 limit 1;
         # * select x from t1 union all (select x from t2 limit 1);
         MODIFIERS_ATTACHED_TO_UNION = False
+        INTERVAL_SPANS = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -260,6 +261,11 @@ class ClickHouse(Dialect):
             "ArgMax",
         ]
 
+        FUNC_TOKENS = {
+            *parser.Parser.FUNC_TOKENS,
+            TokenType.SET,
+        }
+
         AGG_FUNC_MAPPING = (
             lambda functions, suffixes: {
                 f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
@@ -305,6 +311,10 @@ class ClickHouse(Dialect):
             TokenType.SETTINGS,
         }
 
+        ALIAS_TOKENS = parser.Parser.ALIAS_TOKENS - {
+            TokenType.FORMAT,
+        }
+
         LOG_DEFAULTS_TO_LN = True
 
         QUERY_MODIFIER_PARSERS = {
@@ -316,6 +326,17 @@ class ClickHouse(Dialect):
             TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()),
         }
 
+        CONSTRAINT_PARSERS = {
+            **parser.Parser.CONSTRAINT_PARSERS,
+            "INDEX": lambda self: self._parse_index_constraint(),
+            "CODEC": lambda self: self._parse_compress(),
+        }
+
+        SCHEMA_UNNAMED_CONSTRAINTS = {
+            *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
+            "INDEX",
+        }
+
         def _parse_conjunction(self) -> t.Optional[exp.Expression]:
             this = super()._parse_conjunction()
 
@@ -381,21 +402,20 @@ class ClickHouse(Dialect):
 
         # https://clickhouse.com/docs/en/sql-reference/statements/select/with/
         def _parse_cte(self) -> exp.CTE:
-            index = self._index
-            try:
-                # WITH <identifier> AS <subquery expression>
-                return super()._parse_cte()
-            except ParseError:
-                # WITH <expression> AS <identifier>
-                self._retreat(index)
+            # WITH <identifier> AS <subquery expression>
+            cte: t.Optional[exp.CTE] = self._try_parse(super()._parse_cte)
 
-            return self.expression(
+            if not cte:
+                # WITH <expression> AS <identifier>
+                cte = self.expression(
                 exp.CTE,
                 this=self._parse_conjunction(),
                 alias=self._parse_table_alias(),
                 scalar=True,
             )
 
+            return cte
+
         def _parse_join_parts(
             self,
         ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
@@ -508,6 +528,27 @@ class ClickHouse(Dialect):
             self._retreat(index)
             return None
 
+        def _parse_index_constraint(
+            self, kind: t.Optional[str] = None
+        ) -> exp.IndexColumnConstraint:
+            # INDEX name1 expr TYPE type1(args) GRANULARITY value
+            this = self._parse_id_var()
+            expression = self._parse_conjunction()
+
+            index_type = self._match_text_seq("TYPE") and (
+                self._parse_function() or self._parse_var()
+            )
+
+            granularity = self._match_text_seq("GRANULARITY") and self._parse_term()
+
+            return self.expression(
+                exp.IndexColumnConstraint,
+                this=this,
+                expression=expression,
+                index_type=index_type,
+                granularity=granularity,
+            )
+
     class Generator(generator.Generator):
         QUERY_HINTS = False
         STRUCT_DELIMITER = ("(", ")")
@@ -517,6 +558,7 @@ class ClickHouse(Dialect):
         TABLESAMPLE_KEYWORDS = "SAMPLE"
         LAST_DAY_SUPPORTS_DATE_PART = False
         CAN_IMPLEMENT_ARRAY_ANY = True
+        SUPPORTS_TO_NUMBER = False
 
         STRING_TYPE_MAPPING = {
             exp.DataType.Type.CHAR: "String",
@@ -585,6 +627,9 @@ class ClickHouse(Dialect):
             exp.Array: inline_array_sql,
             exp.CastToStrType: rename_func("CAST"),
             exp.CountIf: rename_func("countIf"),
+            exp.CompressColumnConstraint: lambda self,
+            e: f"CODEC({self.expressions(e, key='this', flat=True)})",
+            exp.ComputedColumnConstraint: lambda self, e: f"ALIAS {self.sql(e, 'this')}",
             exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
             exp.DateAdd: date_delta_sql("DATE_ADD"),
             exp.DateDiff: date_delta_sql("DATE_DIFF"),
@@ -737,3 +782,15 @@ class ClickHouse(Dialect):
         def prewhere_sql(self, expression: exp.PreWhere) -> str:
             this = self.indent(self.sql(expression, "this"))
             return f"{self.seg('PREWHERE')}{self.sep()}{this}"
+
+        def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
+            this = self.sql(expression, "this")
+            this = f" {this}" if this else ""
+            expr = self.sql(expression, "expression")
+            expr = f" {expr}" if expr else ""
+            index_type = self.sql(expression, "index_type")
+            index_type = f" TYPE {index_type}" if index_type else ""
+            granularity = self.sql(expression, "granularity")
+            granularity = f" GRANULARITY {granularity}" if granularity else ""
+
+            return f"INDEX{this}{expr}{index_type}{granularity}"
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 599505c..81057c2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -31,6 +31,7 @@ class Dialects(str, Enum):
 
     DIALECT = ""
 
+    ATHENA = "athena"
     BIGQUERY = "bigquery"
     CLICKHOUSE = "clickhouse"
     DATABRICKS = "databricks"
@@ -42,6 +43,7 @@ class Dialects(str, Enum):
     ORACLE = "oracle"
     POSTGRES = "postgres"
     PRESTO = "presto"
+    PRQL = "prql"
     REDSHIFT = "redshift"
     SNOWFLAKE = "snowflake"
     SPARK = "spark"
@@ -108,11 +110,18 @@ class _Dialect(type):
         klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
         klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 
-        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+        base = seq_get(bases, 0)
+        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
+        base_parser = (getattr(base, "parser_class", Parser),)
+        base_generator = (getattr(base, "generator_class", Generator),)
 
-        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
-        klass.parser_class = getattr(klass, "Parser", Parser)
-        klass.generator_class = getattr(klass, "Generator", Generator)
+        klass.tokenizer_class = klass.__dict__.get(
+            "Tokenizer", type("Tokenizer", base_tokenizer, {})
+        )
+        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
+        klass.generator_class = klass.__dict__.get(
+            "Generator", type("Generator", base_generator, {})
+        )
 
         klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
         klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
@@ -134,9 +143,31 @@ class _Dialect(type):
         klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
         klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 
+        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+            klass.UNESCAPED_SEQUENCES = {
+                "\\a": "\a",
+                "\\b": "\b",
+                "\\f": "\f",
+                "\\n": "\n",
+                "\\r": "\r",
+                "\\t": "\t",
+                "\\v": "\v",
+                "\\\\": "\\",
+                **klass.UNESCAPED_SEQUENCES,
+            }
+
+        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
+
         if enum not in ("", "bigquery"):
             klass.generator_class.SELECT_KINDS = ()
 
+        if enum not in ("", "databricks", "hive", "spark", "spark2"):
+            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
+            for modifier in ("cluster", "distribute", "sort"):
+                modifier_transforms.pop(modifier, None)
+
+            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
+
         if not klass.SUPPORTS_SEMI_ANTI_JOIN:
             klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                 TokenType.ANTI,
@@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
         False: Disables function name normalization.
     """
 
-    LOG_BASE_FIRST = True
-    """Whether the base comes first in the `LOG` function."""
+    LOG_BASE_FIRST: t.Optional[bool] = True
+    """
+    Whether the base comes first in the `LOG` function.
+    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
+    """
 
     NULL_ORDERING = "nulls_are_small"
     """
@@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
     If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
     """
 
-    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
-    """Mapping of an unescaped escape sequence to the corresponding character."""
+    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
+    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 
     PSEUDOCOLUMNS: t.Set[str] = set()
     """
@@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
 
     INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
     INVERSE_TIME_TRIE: t.Dict = {}
-    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 
     # Delimiters for string literals and identifiers
     QUOTE_START = "'"
@@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) ->
     return ""
 
 
-def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
+def str_position_sql(
+    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
+) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
     position = self.sql(expression, "position")
+    instance = expression.args.get("instance") if generate_instance else None
+    position_offset = ""
+
     if position:
-        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
-    return f"STRPOS({this}, {substr})"
+        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
+        this = self.func("SUBSTR", this, position)
+        position_offset = f" + {position} - 1"
+
+    return self.func("STRPOS", this, substr, instance) + position_offset
 
 
 def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
@@ -689,9 +731,7 @@ def build_date_delta_with_interval(
         if expression and expression.is_string:
             expression = exp.Literal.number(expression.this)
 
-        return expression_class(
-            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
-        )
+        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 
     return _builder
 
@@ -710,18 +750,14 @@ def date_add_interval_sql(
 ) -> t.Callable[[Generator, exp.Expression], str]:
     def func(self: Generator, expression: exp.Expression) -> str:
         this = self.sql(expression, "this")
-        unit = expression.args.get("unit")
-        unit = exp.var(unit.name.upper() if unit else "DAY")
-        interval = exp.Interval(this=expression.expression, unit=unit)
+        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
         return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 
     return func
 
 
 def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
-    return self.func(
-        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
-    )
+    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
 
 
 def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
 
         return self.func(
             name,
-            exp.var(expression.text("unit").upper() or "DAY"),
+            unit_to_var(expression),
             expression.expression,
             expression.this,
         )
@@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
     return _delta_sql
 
 
+def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, exp.Placeholder):
+        return unit
+    if unit:
+        return exp.Literal.string(unit.name)
+    return exp.Literal.string(default) if default else None
+
+
+def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, (exp.Var, exp.Placeholder)):
+        return unit
+    return exp.Var(this=default) if default else None
+
+
 def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
     trunc_curr_date = exp.func("date_trunc", "month", expression.this)
     plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 
 
 def build_json_extract_path(
-    expr_type: t.Type[F], zero_based_indexing: bool = True
+    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
 ) -> t.Callable[[t.List], F]:
     def _builder(args: t.List) -> F:
         segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@@ -1018,7 +1072,11 @@ def build_json_extract_path(
 
         # This is done to avoid failing in the expression validator due to the arg count
         del args[2:]
-        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+        return expr_type(
+            this=seq_get(args, 0),
+            expression=exp.JSONPath(expressions=segments),
+            only_json_types=arrow_req_json_type,
+        )
 
     return _builder
@@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
     unnest = exp.Unnest(expressions=[expression.this])
     filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
     return self.sql(exp.Array(expressions=[filtered]))
+
+
+def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
+    return self.func(
+        "TO_NUMBER",
+        expression.this,
+        expression.args.get("format"),
+        expression.args.get("nlsparam"),
+    )
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 9a84848..f4ec0e5 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
     build_timestamp_trunc,
     rename_func,
     time_format,
+    unit_to_str,
 )
 from sqlglot.dialects.mysql import MySQL
 
@@ -27,7 +28,7 @@ class Doris(MySQL):
         }
 
     class Generator(MySQL.Generator):
-        CAST_MAPPING = {}
+        LAST_DAY_SUPPORTS_DATE_PART = False
 
         TYPE_MAPPING = {
             **MySQL.Generator.TYPE_MAPPING,
@@ -36,8 +37,7 @@ class Doris(MySQL):
             exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
         }
 
-        LAST_DAY_SUPPORTS_DATE_PART = False
-
+        CAST_MAPPING = {}
         TIMESTAMP_FUNC_TYPES = set()
 
         TRANSFORMS = {
@@ -49,9 +49,7 @@ class Doris(MySQL):
             exp.ArrayAgg: rename_func("COLLECT_LIST"),
             exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
             exp.CurrentTimestamp: lambda self, _: self.func("NOW"),
-            exp.DateTrunc: lambda self, e: self.func(
-                "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
-            ),
+            exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
             exp.JSONExtractScalar: arrow_json_extract_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.Map: rename_func("ARRAY_MAP"),
@@ -63,9 +61,7 @@ class Doris(MySQL):
             exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
             exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
-            exp.TimestampTrunc: lambda self, e: self.func(
-                "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
-            ),
+            exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
             exp.UnixToStr: lambda self, e: self.func(
                 "FROM_UNIXTIME", e.this, time_format("doris")(self, e)
             ),
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index c1f6afa..0a00d92 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import typing as t
 
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
@@ -12,18 +11,10 @@ from sqlglot.dialects.dialect import (
     str_position_sql,
     timestrtotime_sql,
 )
+from sqlglot.dialects.mysql import date_add_sql
 from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by
 
 
-def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
-    def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
-        this = self.sql(expression, "this")
-        unit = exp.var(expression.text("unit").upper() or "DAY")
-        return self.func(f"DATE_{kind}", this, exp.Interval(this=expression.expression, unit=unit))
-
-    return func
-
-
 def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
@@ -84,7 +75,6 @@ class Drill(Dialect):
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
-            "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "drill"),
             "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
             "TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"),
         }
@@ -124,9 +114,9 @@ class Drill(Dialect):
             exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
             exp.ArraySize: rename_func("REPEATED_COUNT"),
             exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
-            exp.DateAdd: _date_add_sql("ADD"),
+            exp.DateAdd: date_add_sql("ADD"),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateSub: _date_add_sql("SUB"),
+            exp.DateSub: date_add_sql("SUB"),
             exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
             exp.DiToDate: lambda self,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f74dc97..6a1d07a 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -26,6 +26,7 @@ from sqlglot.dialects.dialect import (
     str_to_time_sql,
     timestamptrunc_sql,
     timestrtotime_sql,
+    unit_to_var,
 )
 from sqlglot.helper import flatten, seq_get
 from sqlglot.tokens import TokenType
@@ -33,15 +34,16 @@ from sqlglot.tokens import TokenType
 
 def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
     this = self.sql(expression, "this")
-    unit = self.sql(expression, "unit").strip("'") or "DAY"
-    interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
+    interval = self.sql(exp.Interval(this=expression.expression, unit=unit_to_var(expression)))
     return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"
 
 
-def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_delta_sql(
+    self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub | exp.TimeAdd
+) -> str:
     this = self.sql(expression, "this")
-    unit = self.sql(expression, "unit").strip("'") or "DAY"
-    op = "+" if isinstance(expression, exp.DateAdd) else "-"
+    unit = unit_to_var(expression)
+    op = "+" if isinstance(expression, (exp.DateAdd, exp.TimeAdd)) else "-"
     return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
@@ -186,6 +188,11 @@ class DuckDB(Dialect):
             return super().to_json_path(path)
 
     class Tokenizer(tokens.Tokenizer):
+        HEREDOC_STRINGS = ["$"]
+
+        HEREDOC_TAG_IS_IDENTIFIER = True
+        HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER
+
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "//": TokenType.DIV,
@@ -199,6 +206,7 @@ class DuckDB(Dialect):
             "LOGICAL": TokenType.BOOLEAN,
             "ONLY": TokenType.ONLY,
             "PIVOT_WIDER": TokenType.PIVOT,
+            "POSITIONAL": TokenType.POSITIONAL,
             "SIGNED": TokenType.INT,
             "STRING": TokenType.VARCHAR,
             "UBIGINT": TokenType.UBIGINT,
@@ -227,8 +235,8 @@ class DuckDB(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARRAY_HAS": exp.ArrayContains.from_arg_list,
-            "ARRAY_SORT": exp.SortArray.from_arg_list,
             "ARRAY_REVERSE_SORT": _build_sort_array_desc,
+            "ARRAY_SORT": exp.SortArray.from_arg_list,
             "DATEDIFF": _build_date_diff,
             "DATE_DIFF": _build_date_diff,
             "DATE_TRUNC": date_trunc_to_time,
@@ -285,6 +293,11 @@ class DuckDB(Dialect):
         FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
         FUNCTION_PARSERS.pop("DECODE")
 
+        NO_PAREN_FUNCTION_PARSERS = {
+            **parser.Parser.NO_PAREN_FUNCTION_PARSERS,
+            "MAP": lambda self: self._parse_map(),
+        }
+
         TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
             TokenType.SEMI,
             TokenType.ANTI,
@@ -299,6 +312,13 @@ class DuckDB(Dialect):
             ),
         }
 
+        def _parse_map(self) -> exp.ToMap | exp.Map:
+            if self._match(TokenType.L_BRACE, advance=False):
+                return self.expression(exp.ToMap, this=self._parse_bracket())
+
+            args = self._parse_wrapped_csv(self._parse_conjunction)
+            return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))
+
         def _parse_types(
             self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
         ) -> t.Optional[exp.Expression]:
@@ -345,6 +365,7 @@ class DuckDB(Dialect):
         SUPPORTS_CREATE_TABLE_LIKE = False
         MULTI_ARG_DISTINCT = False
         CAN_IMPLEMENT_ARRAY_ANY = True
+        SUPPORTS_TO_NUMBER = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -425,6 +446,7 @@ class DuckDB(Dialect):
                 "EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
             ),
             exp.Struct: _struct_sql,
+            exp.TimeAdd: _date_delta_sql,
             exp.Timestamp: no_timestamp_sql,
             exp.TimestampDiff: lambda self, e: self.func(
                 "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
@@ -478,7 +500,7 @@ class DuckDB(Dialect):
 
         STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
 
-        UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
+        UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren)
 
         # DuckDB doesn't generally support CREATE TABLE .. properties
         # https://duckdb.org/docs/sql/statements/create_table.html
@@ -569,3 +591,9 @@ class DuckDB(Dialect):
                 return rename_func("RANGE")(self, expression)
 
             return super().generateseries_sql(expression)
+
+        def bracket_sql(self, expression: exp.Bracket) -> str:
+            if isinstance(expression.this, exp.Array):
+                expression.this.replace(exp.paren(expression.this))
+
+            return super().bracket_sql(expression)
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 55a9254..cc7debb 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -319,7 +319,9 @@ class Hive(Dialect):
             "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
             "TO_JSON": exp.JSONFormat.from_arg_list,
             "UNBASE64": exp.FromBase64.from_arg_list,
-            "UNIX_TIMESTAMP": build_formatted_time(exp.StrToUnix, "hive", True),
+            "UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)(
+                args or [exp.CurrentTimestamp()]
+            ),
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }
 
@@ -431,6 +433,7 @@ class Hive(Dialect):
         NVL2_SUPPORTED = False
         LAST_DAY_SUPPORTS_DATE_PART = False
         JSON_PATH_SINGLE_QUOTE_ESCAPE = True
+        SUPPORTS_TO_NUMBER = False
 
         EXPRESSIONS_WITHOUT_NESTED_CTES = {
             exp.Insert,
@@ -472,7 +475,7 @@ class Hive(Dialect):
             exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
             exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
             exp.ArrayConcat: rename_func("CONCAT"),
-            exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
+            exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
             exp.ArraySize: rename_func("SIZE"),
             exp.ArraySort: _array_sort_sql,
             exp.With: no_recursive_cte_sql,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 6ebae1e..1d53346 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import (
     build_date_delta_with_interval,
     rename_func,
     strposition_to_locate_sql,
+    unit_to_var,
 )
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
@@ -109,14 +110,14 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
     return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
 
 
-def _date_add_sql(
+def date_add_sql(
     kind: str,
-) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
-    def func(self: MySQL.Generator, expression: exp.Expression) -> str:
-        this = self.sql(expression, "this")
-        unit = expression.text("unit").upper() or "DAY"
-        return (
-            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+) -> t.Callable[[generator.Generator, exp.Expression], str]:
+    def func(self: generator.Generator, expression: exp.Expression) -> str:
+        return self.func(
+            f"DATE_{kind}",
+            expression.this,
+            exp.Interval(this=expression.expression, unit=unit_to_var(expression)),
         )
 
     return func
@@ -291,6 +292,7 @@ class MySQL(Dialect):
             "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+            "FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"),
             "ISNULL": isnull_to_is_null,
             "LOCATE": locate_to_strposition,
             "MAKETIME": exp.TimeFromParts.from_arg_list,
@@ -319,11 +321,7 @@ class MySQL(Dialect):
         FUNCTION_PARSERS = {
             **parser.Parser.FUNCTION_PARSERS,
             "CHAR": lambda self: self._parse_chr(),
-            "GROUP_CONCAT": lambda self: self.expression(
-                exp.GroupConcat,
-                this=self._parse_lambda(),
-                separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
-            ),
+            "GROUP_CONCAT": lambda self: self._parse_group_concat(),
             # https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
             "VALUES": lambda self: self.expression(
                 exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()]
             ),
@@ -412,6 +410,11 @@ class MySQL(Dialect):
             "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
         }
 
+        ALTER_PARSERS = {
+            **parser.Parser.ALTER_PARSERS,
+            "MODIFY": lambda self: self._parse_alter_table_alter(),
+        }
+
         SCHEMA_UNNAMED_CONSTRAINTS = {
             *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
             "FULLTEXT",
@@ -458,7 +461,7 @@ class MySQL(Dialect):
 
             this = self._parse_id_var(any_token=False)
             index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
-            schema = self._parse_schema()
+            expressions = self._parse_wrapped_csv(self._parse_ordered)
 
             options = []
             while True:
@@ -478,9 +481,6 @@ class MySQL(Dialect):
                 elif self._match_text_seq("ENGINE_ATTRIBUTE"):
                     self._match(TokenType.EQ)
                     opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
-                elif self._match_text_seq("ENGINE_ATTRIBUTE"):
-                    self._match(TokenType.EQ)
-                    opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
                 elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"):
                     self._match(TokenType.EQ)
                     opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string())
@@ -495,7 +495,7 @@ class MySQL(Dialect):
             return self.expression(
                 exp.IndexColumnConstraint,
                 this=this,
-                schema=schema,
+                expressions=expressions,
                 kind=kind,
                 index_type=index_type,
                 options=options,
@@ -617,6 +617,39 @@ class MySQL(Dialect):
 
             return self.expression(exp.Chr, **kwargs)
 
+        def _parse_group_concat(self) -> t.Optional[exp.Expression]:
+            def concat_exprs(
+                node: t.Optional[exp.Expression], exprs: t.List[exp.Expression]
+            ) -> exp.Expression:
+                if isinstance(node, exp.Distinct) and len(node.expressions) > 1:
+                    concat_exprs = [
+                        self.expression(exp.Concat, expressions=node.expressions, safe=True)
+                    ]
+                    node.set("expressions", concat_exprs)
+                    return node
+                if len(exprs) == 1:
+                    return exprs[0]
+                return self.expression(exp.Concat, expressions=args, safe=True)
+
+            args = self._parse_csv(self._parse_lambda)
+
+            if args:
+                order = args[-1] if isinstance(args[-1], exp.Order) else None
+
+                if order:
+                    # Order By is the last (or only) expression in the list and has consumed the 'expr' before it,
+                    # remove 'expr' from exp.Order and add it back to args
+                    args[-1] = order.this
+                    order.set("this", concat_exprs(order.this, args))
+
+                this = order or concat_exprs(args[0], args)
+            else:
+                this = None
+
+            separator = self._parse_field() if self._match(TokenType.SEPARATOR) else None
+
+            return self.expression(exp.GroupConcat, this=this, separator=separator)
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = None
@@ -630,6 +663,7 @@ class MySQL(Dialect):
         JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
         JSON_PATH_BRACKETED_KEY_SUPPORTED = False
         JSON_KEY_VALUE_PAIR_SEP = ","
+        SUPPORTS_TO_NUMBER = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -637,9 +671,9 @@ class MySQL(Dialect):
             exp.DateDiff: _remove_ts_or_ds_to_date(
                 lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
             ),
-            exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
+            exp.DateAdd: _remove_ts_or_ds_to_date(date_add_sql("ADD")),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
+            exp.DateSub: _remove_ts_or_ds_to_date(date_add_sql("SUB")),
             exp.DateTrunc: _date_trunc_sql,
             exp.Day: _remove_ts_or_ds_to_date(),
             exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
@@ -672,7 +706,7 @@ class MySQL(Dialect):
             exp.TimeFromParts: rename_func("MAKETIME"),
             exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
             exp.TimestampDiff: lambda self, e: self.func(
-                "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+                "TIMESTAMPDIFF", unit_to_var(e), e.expression, e.this
             ),
             exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -682,9 +716,10 @@ class MySQL(Dialect):
             ),
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: _date_add_sql("ADD"),
+            exp.TsOrDsAdd: date_add_sql("ADD"),
             exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+            exp.UnixToTime: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
             exp.Week: _remove_ts_or_ds_to_date(),
             exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
             exp.Year: _remove_ts_or_ds_to_date(),
@@ -751,11 +786,6 @@ class MySQL(Dialect):
                 result = f"{result} UNSIGNED"
 
             return result
 
-        def xor_sql(self, expression: exp.Xor) -> str:
-            if expression.expressions:
-                return self.expressions(expression, sep=" XOR ")
-            return super().xor_sql(expression)
-
         def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
             return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index bccdad0..e038400 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
     build_formatted_time,
     no_ilike_sql,
     rename_func,
+    to_number_with_nls_param,
     trim_sql,
 )
 from sqlglot.helper import seq_get
@@ -246,6 +247,7 @@ class Oracle(Dialect):
             exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
             exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+            exp.ToNumber: to_number_with_nls_param,
             exp.Trim: trim_sql,
             exp.UnixToTime: lambda self,
             e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index b53ae07..11398ed 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -278,6 +278,7 @@ class Postgres(Dialect):
             "REVOKE": TokenType.COMMAND,
             "SERIAL": TokenType.SERIAL,
             "SMALLSERIAL": TokenType.SMALLSERIAL,
+            "NAME": TokenType.NAME,
             "TEMP": TokenType.TEMPORARY,
             "CSTRING": TokenType.PSEUDO_TYPE,
             "OID": TokenType.OBJECT_IDENTIFIER,
@@ -356,6 +357,16 @@ class Postgres(Dialect):
 
         JSON_ARROWS_REQUIRE_JSON_TYPE = True
 
+        COLUMN_OPERATORS = {
+            **parser.Parser.COLUMN_OPERATORS,
+            TokenType.ARROW: lambda self, this, path: build_json_extract_path(
+                exp.JSONExtract, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+            )([this, path]),
+            TokenType.DARROW: lambda self, this, path: build_json_extract_path(
+                exp.JSONExtractScalar, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+            )([this, path]),
+        }
+
         def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
             while True:
                 if not self._match(TokenType.L_PAREN):
@@ -484,6 +495,7 @@ class Postgres(Dialect):
                 ]
             ),
             exp.StrPosition: str_position_sql,
+            exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
             exp.StructExtract: struct_extract_sql,
             exp.Substring: _substring_sql,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 3649bd2..25bba96 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -22,10 +22,13 @@ from sqlglot.dialects.dialect import (
     rename_func,
     right_to_substring_sql,
     struct_extract_sql,
+    str_position_sql,
     timestamptrunc_sql,
     timestrtotime_sql,
     ts_or_ds_add_cast,
+    unit_to_str,
 )
+from sqlglot.dialects.hive import Hive
 from sqlglot.dialects.mysql import MySQL
 from sqlglot.helper import apply_index_offset, seq_get
 from sqlglot.tokens import TokenType
@@ -93,14 +96,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
 
 def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
     expression = ts_or_ds_add_cast(expression)
-    unit = exp.Literal.string(expression.text("unit") or "DAY")
+    unit = unit_to_str(expression)
     return self.func("DATE_ADD", unit, expression.expression, expression.this)
 
 
 def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
     this = exp.cast(expression.this, "TIMESTAMP")
     expr = exp.cast(expression.expression, "TIMESTAMP")
-    unit = exp.Literal.string(expression.text("unit") or "DAY")
+    unit = unit_to_str(expression)
     return self.func("DATE_DIFF", unit, expr, this)
 
 
@@ -196,6 +199,7 @@ class Presto(Dialect):
     SUPPORTS_SEMI_ANTI_JOIN = False
     TYPED_DIVISION = True
    TABLESAMPLE_SIZE_IS_PERCENT = True
+    LOG_BASE_FIRST: t.Optional[bool] = None
 
     # https://github.com/trinodb/trino/issues/17
     # https://github.com/trinodb/trino/issues/12289
@@ -289,6 +293,7 @@ class Presto(Dialect):
         SUPPORTS_SINGLE_ARG_CONCAT = False
         LIKE_PROPERTY_INSIDE_SCHEMA = True
         MULTI_ARG_DISTINCT = False
+        SUPPORTS_TO_NUMBER = False
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,
@@ -323,6 +328,7 @@ class Presto(Dialect):
             exp.ArrayConcat: rename_func("CONCAT"),
             exp.ArrayContains: rename_func("CONTAINS"),
             exp.ArraySize: rename_func("CARDINALITY"),
+            exp.ArrayToString: rename_func("ARRAY_JOIN"),
             exp.ArrayUniqueAgg: rename_func("SET_AGG"),
             exp.AtTimeZone: rename_func("AT_TIMEZONE"),
             exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression),
@@ -339,19 +345,19 @@ class Presto(Dialect):
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: lambda self, e: self.func(
                 "DATE_ADD",
-                exp.Literal.string(e.text("unit") or "DAY"),
+                unit_to_str(e),
                 _to_int(e.expression),
                 e.this,
             ),
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
+                "DATE_DIFF", unit_to_str(e), e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self,
             e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
             exp.DateSub: lambda self, e: self.func(
                 "DATE_ADD",
-                exp.Literal.string(e.text("unit") or "DAY"),
+                unit_to_str(e),
                 _to_int(e.expression * -1),
                 e.this,
             ),
@@ -397,13 +403,10 @@ class Presto(Dialect):
                 ]
             ),
             exp.SortArray: _no_sort_array,
-            exp.StrPosition: rename_func("STRPOS"),
+            exp.StrPosition: lambda self, e: str_position_sql(self, e, generate_instance=True),
             exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
             exp.StrToMap: rename_func("SPLIT_TO_MAP"),
             exp.StrToTime: _str_to_time_sql,
-            exp.StrToUnix: lambda self, e: self.func(
-                "TO_UNIXTIME", self.func("DATE_PARSE", e.this, self.format_time(e))
-            ),
             exp.StructExtract: struct_extract_sql,
             exp.Table: transforms.preprocess([_unnest_sequence]),
             exp.Timestamp: no_timestamp_sql,
@@ -436,6 +439,22 @@ class Presto(Dialect):
             exp.Xor: bool_xor_sql,
         }
 
+        def strtounix_sql(self, expression: exp.StrToUnix) -> str:
+            # Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
+            # To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a
+            # timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
+            # which seems to be using the same time mapping as Hive, as per:
+            # https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
+            value_as_text = exp.cast(expression.this, "text")
+            parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
+            parse_with_tz = self.func(
+                "PARSE_DATETIME",
+                value_as_text,
+                self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE),
+            )
+            coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz)
+            return self.func("TO_UNIXTIME", coalesced)
+
         def bracket_sql(self, expression: exp.Bracket) -> str:
             if expression.args.get("safe"):
                 return self.func(
@@ -481,8 +500,7 @@ class Presto(Dialect):
             return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
 
         def interval_sql(self, expression: exp.Interval) -> str:
-            unit = self.sql(expression, "unit")
-            if expression.this and unit.startswith("WEEK"):
+            if expression.this and expression.text("unit").upper().startswith("WEEK"):
                 return f"({expression.this.name} * INTERVAL '7' DAY)"
             return super().interval_sql(expression)
 
diff --git a/sqlglot/dialects/prql.py b/sqlglot/dialects/prql.py
new file mode 100644
index 0000000..3005753
--- /dev/null
+++ b/sqlglot/dialects/prql.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, parser, tokens
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.tokens import TokenType
+
+
+class PRQL(Dialect):
+    class Tokenizer(tokens.Tokenizer):
+        IDENTIFIERS = ["`"]
+        QUOTES = ["'", '"']
+
+        SINGLE_TOKENS = {
+            **tokens.Tokenizer.SINGLE_TOKENS,
+            "=": TokenType.ALIAS,
+            "'": TokenType.QUOTE,
+            '"': TokenType.QUOTE,
+            "`": TokenType.IDENTIFIER,
+            "#": TokenType.COMMENT,
+        }
+
+        KEYWORDS = {
+            **tokens.Tokenizer.KEYWORDS,
+        }
+
+    class Parser(parser.Parser):
+        TRANSFORM_PARSERS = {
+            "DERIVE": lambda self, query: self._parse_selection(query),
+            "SELECT": lambda self, query: self._parse_selection(query, append=False),
+            "TAKE": lambda self, query: self._parse_take(query),
+        }
+
+        def _parse_statement(self) -> t.Optional[exp.Expression]:
+            expression = self._parse_expression()
+            expression = expression if expression else self._parse_query()
+            return expression
+
+        def _parse_query(self) -> t.Optional[exp.Query]:
+            from_ = self._parse_from()
+
+            if not from_:
+                return None
+
+            query = exp.select("*").from_(from_, copy=False)
+
+            while self._match_texts(self.TRANSFORM_PARSERS):
+                query = self.TRANSFORM_PARSERS[self._prev.text.upper()](self, query)
+
+            return query
+
+        def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
+            if self._match(TokenType.L_BRACE):
+                selects = self._parse_csv(self._parse_expression)
+
+                if not self._match(TokenType.R_BRACE, expression=query):
+                    self.raise_error("Expecting }")
+            else:
+                expression = self._parse_expression()
+                selects = [expression] if expression else []
+
+            projections = {
+                select.alias_or_name: select.this if isinstance(select, exp.Alias) else select
+                for select in query.selects
+            }
+
selects = [ + select.transform( + lambda s: (projections[s.name].copy() if s.name in projections else s) + if isinstance(s, exp.Column) + else s, + copy=False, + ) + for select in selects + ] + + return query.select(*selects, append=append, copy=False) + + def _parse_take(self, query: exp.Query) -> t.Optional[exp.Query]: + num = self._parse_number() # TODO: TAKE for ranges a..b + return query.limit(num) if num else None + + def _parse_expression(self) -> t.Optional[exp.Expression]: + if self._next and self._next.token_type == TokenType.ALIAS: + alias = self._parse_id_var(True) + self._match(TokenType.ALIAS) + return self.expression(exp.Alias, this=self._parse_conjunction(), alias=alias) + return self._parse_conjunction() + + def _parse_table( + self, + schema: bool = False, + joins: bool = False, + alias_tokens: t.Optional[t.Collection[TokenType]] = None, + parse_bracket: bool = False, + is_db_reference: bool = False, + ) -> t.Optional[exp.Expression]: + return self._parse_table_parts() + + def _parse_from( + self, joins: bool = False, skip_from_token: bool = False + ) -> t.Optional[exp.From]: + if not skip_from_token and not self._match(TokenType.FROM): + return None + + return self.expression( + exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins) + ) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 0db87ec..1f0c411 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -92,23 +92,6 @@ class Redshift(Postgres): return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table - def _parse_types( - self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True - ) -> t.Optional[exp.Expression]: - this = super()._parse_types( - check_func=check_func, schema=schema, allow_identifiers=allow_identifiers - ) - - if ( - isinstance(this, exp.DataType) - and this.is_type("varchar") - and this.expressions - and this.expressions[0].this == exp.column("MAX") - ): - this.set("expressions", [exp.var("MAX")]) - - return this - def _parse_convert( self, strict: bool, safe: t.Optional[bool] = None ) -> t.Optional[exp.Expression]: @@ -153,6 +136,7 @@ class Redshift(Postgres): NVL2_SUPPORTED = True LAST_DAY_SUPPORTS_DATE_PART = False CAN_IMPLEMENT_ARRAY_ANY = False + MULTI_ARG_DISTINCT = True TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, @@ -187,9 +171,13 @@ class Redshift(Postgres): ), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", + exp.StartsWith: lambda self, + e: f"{self.sql(e.this)} LIKE {self.sql(e.expression)} || '%'", exp.TableSample: no_tablesample_sql, exp.TsOrDsAdd: date_delta_sql("DATEADD"), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), + exp.UnixToTime: lambda self, + e: f"(TIMESTAMP 'epoch' + {self.sql(e.this)} * INTERVAL '1 SECOND')", } # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 20fdfb7..73a9166 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -20,8 +20,7 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, var_map_sql, ) -from sqlglot.expressions import Literal -from sqlglot.helper import flatten, is_int, seq_get +from sqlglot.helper import flatten, is_float, is_int, seq_get from sqlglot.tokens import TokenType if t.TYPE_CHECKING: @@ -29,33 +28,35 @@ if t.TYPE_CHECKING: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def 
_build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: - if len(args) == 2: - first_arg, second_arg = args - if second_arg.is_string: - # case: [ , ] - return build_formatted_time(exp.StrToTime, "snowflake")(args) - return exp.UnixToTime(this=first_arg, scale=second_arg) +def _build_datetime( + name: str, kind: exp.DataType.Type, safe: bool = False +) -> t.Callable[[t.List], exp.Func]: + def _builder(args: t.List) -> exp.Func: + value = seq_get(args, 0) + + if isinstance(value, exp.Literal): + int_value = is_int(value.this) - from sqlglot.optimizer.simplify import simplify_literals + # Converts calls like `TO_TIME('01:02:03')` into casts + if len(args) == 1 and value.is_string and not int_value: + return exp.cast(value, kind) - # The first argument might be an expression like 40 * 365 * 86400, so we try to - # reduce it using `simplify_literals` first and then check if it's a Literal. - first_arg = seq_get(args, 0) - if not isinstance(simplify_literals(first_arg, root=True), Literal): - # case: or other expressions such as columns - return exp.TimeStrToTime.from_arg_list(args) + # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special + # cases so we can transpile them, since they're relatively common + if kind == exp.DataType.Type.TIMESTAMP: + if int_value: + return exp.UnixToTime(this=value, scale=seq_get(args, 1)) + if not is_float(value.this): + return build_formatted_time(exp.StrToTime, "snowflake")(args) - if first_arg.is_string: - if is_int(first_arg.this): - # case: - return exp.UnixToTime.from_arg_list(args) + if len(args) == 2 and kind == exp.DataType.Type.DATE: + formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args) + formatted_exp.set("safe", safe) + return formatted_exp - # case: - return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args) + return exp.Anonymous(this=name, expressions=args) - # case: - return exp.UnixToTime.from_arg_list(args) + return _builder def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: @@ -77,6 +78,17 @@ def _build_datediff(args: t.List) -> exp.DateDiff: ) +def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: + def _builder(args: t.List) -> E: + return expr_type( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=_map_date_part(seq_get(args, 0)), + ) + + return _builder + + # https://docs.snowflake.com/en/sql-reference/functions/div0 def _build_if_from_div0(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) @@ -97,14 +109,6 @@ def _build_if_from_nullifzero(args: t.List) -> exp.If: return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) -def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str: - if expression.is_type("array"): - return "ARRAY" - elif expression.is_type("map"): - return "OBJECT" - return self.datatype_sql(expression) - - def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str: flag = expression.text("flag") @@ -258,6 +262,25 @@ def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression: return expression +def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.Expression: + assert isinstance(expression, exp.Create) + + def _flatten_structured_type(expression: exp.DataType) -> exp.DataType: + if expression.this in exp.DataType.NESTED_TYPES: + expression.set("expressions", None) + return expression + + props = 
expression.args.get("properties") + if isinstance(expression.this, exp.Schema) and not (props and props.find(exp.IcebergProperty)): + for schema_expression in expression.this.expressions: + if isinstance(schema_expression, exp.ColumnDef): + column_type = schema_expression.kind + if isinstance(column_type, exp.DataType): + column_type.transform(_flatten_structured_type, copy=False) + + return expression + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -312,7 +335,13 @@ class Snowflake(Dialect): class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True + ID_VAR_TOKENS = { + *parser.Parser.ID_VAR_TOKENS, + TokenType.MATCH_CONDITION, + } + TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW} + TABLE_ALIAS_TOKENS.discard(TokenType.MATCH_CONDITION) FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -327,17 +356,13 @@ class Snowflake(Dialect): end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)), step=seq_get(args, 2), ), - "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, "BITXOR": binary_from_function(exp.BitwiseXor), "BIT_XOR": binary_from_function(exp.BitwiseXor), "BOOLXOR": binary_from_function(exp.Xor), "CONVERT_TIMEZONE": _build_convert_timezone, + "DATE": _build_datetime("DATE", exp.DataType.Type.DATE), "DATE_TRUNC": _date_trunc_to_time, - "DATEADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=_map_date_part(seq_get(args, 0)), - ), + "DATEADD": _build_date_time_add(exp.DateAdd), "DATEDIFF": _build_datediff, "DIV0": _build_if_from_div0, "FLATTEN": exp.Explode.from_arg_list, @@ -349,17 +374,34 @@ class Snowflake(Dialect): this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) ), "LISTAGG": exp.GroupConcat.from_arg_list, + "MEDIAN": lambda args: exp.PercentileCont( + this=seq_get(args, 0), expression=exp.Literal.number(0.5) + ), "NULLIFZERO": _build_if_from_nullifzero, "OBJECT_CONSTRUCT": _build_object_construct, "REGEXP_REPLACE": _build_regexp_replace, "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), + "TIMEADD": _build_date_time_add(exp.TimeAdd), "TIMEDIFF": _build_datediff, + "TIMESTAMPADD": _build_date_time_add(exp.DateAdd), "TIMESTAMPDIFF": _build_datediff, "TIMESTAMPFROMPARTS": _build_timestamp_from_parts, "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts, - "TO_TIMESTAMP": _build_to_timestamp, + "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True), + "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE), + "TO_NUMBER": lambda args: exp.ToNumber( + this=seq_get(args, 0), + format=seq_get(args, 1), + precision=seq_get(args, 2), + scale=seq_get(args, 3), + ), + "TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME), + "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP), + "TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ), + "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP), + "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ), "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _build_if_from_zeroifnull, } @@ -377,7 +419,6 @@ class Snowflake(Dialect): **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: 
-            TokenType.COLON: lambda self, this: self._parse_colon_get_path(this),
         }
 
         ALTER_PARSERS = {
@@ -434,35 +475,35 @@ class Snowflake(Dialect):
 
         SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"}
 
-        def _parse_colon_get_path(
-            self: parser.Parser, this: t.Optional[exp.Expression]
-        ) -> t.Optional[exp.Expression]:
-            while True:
-                path = self._parse_bitwise() or self._parse_var(any_token=True)
+        def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+            this = super()._parse_column_ops(this)
+
+            casts = []
+            json_path = []
+
+            while self._match(TokenType.COLON):
+                path = super()._parse_column_ops(self._parse_field(any_token=True))
 
                 # The cast :: operator has a lower precedence than the extraction operator :, so
                 # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
-                if isinstance(path, exp.Cast):
-                    target_type = path.to
+                while isinstance(path, exp.Cast):
+                    casts.append(path.to)
                     path = path.this
-                else:
-                    target_type = None
 
-                if isinstance(path, exp.Expression):
-                    path = exp.Literal.string(path.sql(dialect="snowflake"))
+                if path:
+                    json_path.append(path.sql(dialect="snowflake", copy=False))
 
-                # The extraction operator : is left-associative
+            if json_path:
                 this = self.expression(
-                    exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
+                    exp.JSONExtract,
+                    this=this,
+                    expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
                 )
 
-                if target_type:
-                    this = exp.cast(this, target_type)
+                while casts:
+                    this = self.expression(exp.Cast, this=this, to=casts.pop())
 
-                if not self._match(TokenType.COLON):
-                    break
-
-            return self._parse_range(this)
+            return this
 
         # https://docs.snowflake.com/en/sql-reference/functions/date_part.html
         # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
@@ -663,6 +704,7 @@ class Snowflake(Dialect):
             "EXCLUDE": TokenType.EXCEPT,
             "ILIKE ANY": TokenType.ILIKE_ANY,
             "LIKE ANY": TokenType.LIKE_ANY,
+            "MATCH_CONDITION": TokenType.MATCH_CONDITION,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "MINUS": TokenType.EXCEPT,
             "NCHAR VARYING": TokenType.VARCHAR,
@@ -703,6 +745,7 @@ class Snowflake(Dialect):
         LIMIT_ONLY_LITERALS = True
         JSON_KEY_VALUE_PAIR_SEP = ","
         INSERT_OVERWRITE = " OVERWRITE INTO"
+        STRUCT_DELIMITER = ("(", ")")
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -711,15 +754,14 @@ class Snowflake(Dialect):
             exp.Array: inline_array_sql,
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
             exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
-            exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
             exp.AtTimeZone: lambda self, e: self.func(
                 "CONVERT_TIMEZONE", e.args.get("zone"), e.this
             ),
             exp.BitwiseXor: rename_func("BITXOR"),
+            exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]),
             exp.DateAdd: date_delta_sql("DATEADD"),
             exp.DateDiff: date_delta_sql("DATEDIFF"),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DataType: _datatype_sql,
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
             exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -769,6 +811,7 @@ class Snowflake(Dialect):
             ),
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
             exp.Stuff: rename_func("INSERT"),
+            exp.TimeAdd: date_delta_sql("TIMEADD"),
             exp.TimestampDiff: lambda self, e: self.func(
                 "TIMESTAMPDIFF", e.unit, e.expression, e.this
             ),
@@ -783,6 +826,9 @@ class Snowflake(Dialect):
             exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), + exp.TsOrDsToDate: lambda self, e: self.func( + "TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e) + ), exp.UnixToTime: rename_func("TO_TIMESTAMP"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.WeekOfYear: rename_func("WEEKOFYEAR"), @@ -797,6 +843,8 @@ class Snowflake(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.NESTED: "OBJECT", + exp.DataType.Type.STRUCT: "OBJECT", exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } @@ -811,6 +859,37 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + UNSUPPORTED_VALUES_EXPRESSIONS = { + exp.Struct, + } + + def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str: + if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS): + values_as_table = False + + return super().values_sql(expression, values_as_table=values_as_table) + + def datatype_sql(self, expression: exp.DataType) -> str: + expressions = expression.expressions + if ( + expressions + and expression.is_type(*exp.DataType.STRUCT_TYPES) + and any(isinstance(field_type, exp.DataType) for field_type in expressions) + ): + # The correct syntax is OBJECT [ ( str: + return self.func( + "TO_NUMBER", + expression.this, + expression.args.get("format"), + expression.args.get("precision"), + expression.args.get("scale"), + ) + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: milli = expression.args.get("milli") if milli is not None: diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 20c0fce..88b5ddc 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import exp -from sqlglot.dialects.dialect import rename_func +from sqlglot.dialects.dialect import rename_func, unit_to_var from sqlglot.dialects.hive import _build_with_ignore_nulls from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider from sqlglot.helper import seq_get @@ -78,6 +78,8 @@ class Spark(Spark2): return this class Generator(Spark2.Generator): + SUPPORTS_TO_NUMBER = True + TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, exp.DataType.Type.MONEY: "DECIMAL(15, 4)", @@ -100,7 +102,7 @@ class Spark(Spark2): e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}", exp.StartsWith: rename_func("STARTSWITH"), exp.TimestampAdd: lambda self, e: self.func( - "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this + "DATEADD", unit_to_var(e), e.expression, e.this ), exp.TryCast: lambda self, e: ( self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e) @@ -117,11 +119,10 @@ class Spark(Spark2): return self.function_fallback_sql(expression) def datediff_sql(self, expression: exp.DateDiff) -> str: - unit = self.sql(expression, "unit") end = self.sql(expression, "this") start = self.sql(expression, "expression") - if unit: - return self.func("DATEDIFF", unit, start, end) + if expression.unit: + return self.func("DATEDIFF", unit_to_var(expression), start, end) return self.func("DATEDIFF", end, start) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 63eae6e..069916f 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect 
     pivot_column_names,
     rename_func,
     trim_sql,
+    unit_to_str,
 )
 from sqlglot.dialects.hive import Hive
 from sqlglot.helper import seq_get
@@ -203,6 +204,7 @@ class Spark2(Hive):
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArraySum: lambda self,
             e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+            exp.ArrayToString: rename_func("ARRAY_JOIN"),
             exp.AtTimeZone: lambda self, e: self.func(
                 "FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
             ),
@@ -218,7 +220,7 @@ class Spark2(Hive):
                 ]
             ),
             exp.DateFromParts: rename_func("MAKE_DATE"),
-            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
+            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)),
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
             exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -241,9 +243,7 @@ class Spark2(Hive):
             ),
             exp.StrToDate: _str_to_date,
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
-            exp.TimestampTrunc: lambda self, e: self.func(
-                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
-            ),
+            exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
             exp.Trim: trim_sql,
             exp.UnixToTime: _unix_to_time_sql,
             exp.VariancePop: rename_func("VAR_POP"),
@@ -252,7 +252,6 @@ class Spark2(Hive):
                 [transforms.remove_within_group_for_percentiles]
             ),
         }
-        TRANSFORMS.pop(exp.ArrayJoin)
         TRANSFORMS.pop(exp.ArraySort)
         TRANSFORMS.pop(exp.ILike)
         TRANSFORMS.pop(exp.Left)
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 2b17ff9..ef7d9aa 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -33,6 +33,14 @@ def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> st
     return arrow_json_extract_sql(self, expression)
 
 
+def _build_strftime(args: t.List) -> exp.Anonymous | exp.TimeToStr:
+    if len(args) == 1:
+        args.append(exp.CurrentTimestamp())
+    if len(args) == 2:
+        return exp.TimeToStr(this=exp.TsOrDsToTimestamp(this=args[1]), format=args[0])
+    return exp.Anonymous(this="STRFTIME", expressions=args)
+
+
 def _transform_create(expression: exp.Expression) -> exp.Expression:
     """Move primary key to a column and enforce auto_increment on primary keys."""
     schema = expression.this
@@ -82,6 +90,7 @@ class SQLite(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "EDITDIST3": exp.Levenshtein.from_arg_list,
+            "STRFTIME": _build_strftime,
         }
         STRING_ALIASES = True
 
@@ -93,6 +102,7 @@ class SQLite(Dialect):
         JSON_PATH_BRACKETED_KEY_SUPPORTED = False
         SUPPORTS_CREATE_TABLE_LIKE = False
         SUPPORTS_TABLE_ALIAS_COLUMNS = False
+        SUPPORTS_TO_NUMBER = False
 
         SUPPORTED_JSON_PATH_PARTS = {
             exp.JSONPathKey,
@@ -151,7 +161,9 @@ class SQLite(Dialect):
             ),
             exp.TableSample: no_tablesample_sql,
             exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
+            exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.args.get("format"), e.this),
             exp.TryCast: no_trycast_sql,
+            exp.TsOrDsToTimestamp: lambda self, e: self.sql(e, "this"),
         }
 
         # SQLite doesn't generally support CREATE TABLE .. properties
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 12ac600..5691f58 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
     arrow_json_extract_sql,
     build_timestamp_trunc,
     rename_func,
+    unit_to_str,
 )
 from sqlglot.dialects.mysql import MySQL
 from sqlglot.helper import seq_get
@@ -39,15 +40,13 @@ class StarRocks(MySQL):
             **MySQL.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.this, e.expression
+                "DATE_DIFF", unit_to_str(e), e.this, e.expression
             ),
             exp.JSONExtractScalar: arrow_json_extract_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.RegexpLike: rename_func("REGEXP"),
             exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
-            exp.TimestampTrunc: lambda self, e: self.func(
-                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
-            ),
+            exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
             exp.TimeStrToDate: rename_func("TO_DATE"),
             exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index b736918..40feb67 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -5,6 +5,8 @@ from sqlglot.dialects.dialect import Dialect, rename_func
 
 
 class Tableau(Dialect):
+    LOG_BASE_FIRST = False
+
     class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = [("[", "]")]
         QUOTES = ["'", '"']
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 0663a1d..a65e10e 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -3,7 +3,13 @@ from __future__ import annotations
 import typing as t
 
 from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func
+from sqlglot.dialects.dialect import (
+    Dialect,
+    max_or_greatest,
+    min_or_least,
+    rename_func,
+    to_number_with_nls_param,
+)
 from sqlglot.tokens import TokenType
 
 
@@ -206,6 +212,7 @@ class Teradata(Dialect):
             exp.StrToDate: lambda self,
             e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+            exp.ToNumber: to_number_with_nls_param,
             exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
         }
 
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 1bbed67..457e2f0 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.presto import Presto
 
 class Trino(Presto):
     SUPPORTS_USER_DEFINED_TYPES = False
+    LOG_BASE_FIRST = True
 
     class Generator(Presto.Generator):
         TRANSFORMS = {
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index b6f491f..8e06be6 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
     NormalizationStrategy,
     any_value_to_max_sql,
     date_delta_sql,
+    datestrtodate_sql,
     generatedasidentitycolumnconstraint_sql,
     max_or_greatest,
     min_or_least,
@@ -724,6 +725,7 @@ class TSQL(Dialect):
         TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
         SUPPORTS_SELECT_INTO = True
         JSON_PATH_BRACKETED_KEY_SUPPORTED = False
+        SUPPORTS_TO_NUMBER = False
 
         EXPRESSIONS_WITHOUT_NESTED_CTES = {
             exp.Delete,
@@ -760,12 +762,14 @@ class TSQL(Dialect):
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             exp.AnyValue: any_value_to_max_sql,
+            exp.ArrayToString: rename_func("STRING_AGG"),
             exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
             exp.DateAdd: date_delta_sql("DATEADD"),
             exp.DateDiff: date_delta_sql("DATEDIFF"),
             exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
             exp.CurrentDate: rename_func("GETDATE"),
             exp.CurrentTimestamp: rename_func("GETDATE"),
+            exp.DateStrToDate: datestrtodate_sql,
             exp.Extract: rename_func("DATEPART"),
             exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
             exp.GroupConcat: _string_agg_sql,
@@ -808,6 +812,22 @@ class TSQL(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
+        def select_sql(self, expression: exp.Select) -> str:
+            if expression.args.get("offset"):
+                if not expression.args.get("order"):
+                    # ORDER BY is required in order to use OFFSET in a query, so we use
+                    # a noop order by, since we don't really care about the order.
+                    # See: https://www.microsoftpressstore.com/articles/article.aspx?p=2314819
+                    expression.order_by(exp.select(exp.null()).subquery(), copy=False)
+
+                limit = expression.args.get("limit")
+                if isinstance(limit, exp.Limit):
+                    # TOP and OFFSET can't be combined, so we need to use FETCH instead of TOP;
+                    # we replace the node here because otherwise TOP would be generated in select_sql
+                    limit.replace(exp.Fetch(direction="FIRST", count=limit.expression))
+
+            return super().select_sql(expression)
+
         def convert_sql(self, expression: exp.Convert) -> str:
             name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT"
             return self.func(
@@ -862,12 +882,12 @@ class TSQL(Dialect):
 
             return rename_func("DATETIMEFROMPARTS")(self, expression)
 
-        def set_operation(self, expression: exp.Union, op: str) -> str:
+        def set_operations(self, expression: exp.Union) -> str:
             limit = expression.args.get("limit")
             if limit:
                 return self.sql(expression.limit(limit.pop(), copy=False))
 
-            return super().set_operation(expression, op)
+            return super().set_operations(expression)
 
         def setitem_sql(self, expression: exp.SetItem) -> str:
             this = expression.this
-- 
cgit v1.2.3
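
As a quick illustration of the transpilation behavior the Snowflake changes above enable, here is a minimal sketch (not part of the patch; it assumes sqlglot 23.7.0 is installed, and the SQL strings are invented examples) that exercises the new `_build_datetime` parsing paths through the public `sqlglot.transpile` API:

    import sqlglot

    # TO_TIMESTAMP over an integer literal now parses to exp.UnixToTime, so it
    # can be rewritten for dialects that spell unix-to-time conversion differently.
    print(sqlglot.transpile("SELECT TO_TIMESTAMP(1659981729)", read="snowflake", write="duckdb")[0])

    # TRY_TO_DATE(<str>, <fmt>) parses to exp.TsOrDsToDate with safe=True, so it
    # round-trips back to TRY_TO_DATE when generating Snowflake SQL again.
    print(sqlglot.transpile("SELECT TRY_TO_DATE('05-01-2024', 'mm-dd-yyyy')", read="snowflake", write="snowflake")[0])

The exact output strings depend on the write dialect's generator rules, but the key point is that these calls no longer fall through to an opaque anonymous function and can therefore be rewritten per dialect.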