Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py   | 65
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 38
-rw-r--r-- | sqlglot/dialects/dialect.py    | 201
-rw-r--r-- | sqlglot/dialects/drill.py      | 34
-rw-r--r-- | sqlglot/dialects/duckdb.py     | 35
-rw-r--r-- | sqlglot/dialects/hive.py       | 40
-rw-r--r-- | sqlglot/dialects/mysql.py      | 26
-rw-r--r-- | sqlglot/dialects/oracle.py     | 17
-rw-r--r-- | sqlglot/dialects/postgres.py   | 13
-rw-r--r-- | sqlglot/dialects/presto.py     | 64
-rw-r--r-- | sqlglot/dialects/redshift.py   | 14
-rw-r--r-- | sqlglot/dialects/snowflake.py  | 19
-rw-r--r-- | sqlglot/dialects/spark2.py     | 10
-rw-r--r-- | sqlglot/dialects/sqlite.py     | 7
-rw-r--r-- | sqlglot/dialects/tableau.py    | 6
-rw-r--r-- | sqlglot/dialects/teradata.py   | 44
-rw-r--r-- | sqlglot/dialects/tsql.py       | 34
17 files changed, 351 insertions, 316 deletions
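The bulk of this change renames the lowercase dialect settings (time_mapping, null_ordering, ...) to uppercase class-level constants (TIME_MAPPING, NULL_ORDERING, ...) and has the _Dialect metaclass copy them onto each dialect's tokenizer, parser and generator classes, replacing the per-instance keyword plumbing that Dialect.parser() and Dialect.generator() used to do. Below is a minimal, self-contained sketch of that propagation pattern; Tok, Gen, DialectMeta and MyDialect are hypothetical stand-ins for illustration, not sqlglot's actual classes:

# Hypothetical stand-ins for the tokenizer/generator helper classes.
class Tok:
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

class Gen:
    NULL_ORDERING = "nulls_are_small"

class DialectMeta(type):
    def __new__(mcs, clsname, bases, namespace):
        klass = super().__new__(mcs, clsname, bases, namespace)
        # Collect every plain (non-callable, non-dunder) class attribute...
        props = {
            k: v
            for k, v in vars(klass).items()
            if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
        }
        # ...and mirror it onto any helper class that already declares the same name,
        # which is what the new dialect_properties loop in _Dialect.__new__ does.
        for sub in (klass.tokenizer_class, klass.generator_class):
            for name, value in props.items():
                if hasattr(sub, name):
                    setattr(sub, name, value)
        return klass

class MyDialect(metaclass=DialectMeta):
    tokenizer_class = Tok
    generator_class = Gen
    NULL_ORDERING = "nulls_are_last"
    IDENTIFIERS_CAN_START_WITH_DIGIT = True

print(Gen.NULL_ORDERING)                     # nulls_are_last
print(Tok.IDENTIFIERS_CAN_START_WITH_DIGIT)  # True

Because the helper classes only receive attributes they already define, a dialect-level constant acts as an override rather than injecting arbitrary names into the tokenizer, parser or generator.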
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5b10852..2166e65 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -7,6 +7,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     datestrtodate_sql,
+    format_time_lambda,
     inline_array_sql,
     max_or_greatest,
     min_or_least,
@@ -103,16 +104,26 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:


 class BigQuery(Dialect):
-    unnest_column_only = True
-    time_mapping = {
-        "%M": "%-M",
-        "%d": "%-d",
-        "%m": "%-m",
-        "%y": "%-y",
-        "%H": "%-H",
-        "%I": "%-I",
-        "%S": "%-S",
-        "%j": "%-j",
+    UNNEST_COLUMN_ONLY = True
+
+    TIME_MAPPING = {
+        "%D": "%m/%d/%y",
+    }
+
+    FORMAT_MAPPING = {
+        "DD": "%d",
+        "MM": "%m",
+        "MON": "%b",
+        "MONTH": "%B",
+        "YYYY": "%Y",
+        "YY": "%y",
+        "HH": "%I",
+        "HH12": "%I",
+        "HH24": "%H",
+        "MI": "%M",
+        "SS": "%S",
+        "SSSSS": "%f",
+        "TZH": "%z",
     }

     class Tokenizer(tokens.Tokenizer):
@@ -142,6 +153,7 @@ class BigQuery(Dialect):
             "FLOAT64": TokenType.DOUBLE,
             "INT64": TokenType.BIGINT,
             "RECORD": TokenType.STRUCT,
+            "TIMESTAMP": TokenType.TIMESTAMPTZ,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
             "UNKNOWN": TokenType.NULL,
         }
@@ -155,13 +167,21 @@ class BigQuery(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
+            "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
+            "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
             "DATE_TRUNC": lambda args: exp.DateTrunc(
                 unit=exp.Literal.string(str(seq_get(args, 1))),
                 this=seq_get(args, 0),
             ),
-            "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
             "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
+            "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
             "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+            "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
+                [seq_get(args, 1), seq_get(args, 0)]
+            ),
+            "PARSE_TIMESTAMP": lambda args: format_time_lambda(exp.StrToTime, "bigquery")(
+                [seq_get(args, 1), seq_get(args, 0)]
+            ),
             "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
             "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
                 this=seq_get(args, 0),
@@ -172,15 +192,15 @@ class BigQuery(Dialect):
                 if re.compile(str(seq_get(args, 1))).groups == 1
                 else None,
             ),
+            "SPLIT": lambda args: exp.Split(
+                # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1) or exp.Literal.string(","),
+            ),
             "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
-            "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
-            "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
-            "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
             "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
+            "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
             "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
-            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
-                this=seq_get(args, 1), format=seq_get(args, 0)
-            ),
         }

         FUNCTION_PARSERS = {
@@ -274,9 +294,18 @@ class BigQuery(Dialect):
             exp.IntDiv: rename_func("DIV"),
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
+            exp.RegexpExtract: lambda self, e: self.func(
+                "REGEXP_EXTRACT",
+                e.this,
+                e.expression,
+                e.args.get("position"),
+                e.args.get("occurrence"),
+            ),
+            exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
             exp.Select: transforms.preprocess(
                 [_unqualify_unnest, transforms.eliminate_distinct_on]
             ),
+            exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -295,7 +324,6 @@ class BigQuery(Dialect):
             exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
             if e.name == "IMMUTABLE"
             else "NOT DETERMINISTIC",
-            exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
         }

         TYPE_MAPPING = {
@@ -315,6 +343,7 @@ class BigQuery(Dialect):
             exp.DataType.Type.TEXT: "STRING",
             exp.DataType.Type.TIMESTAMP: "DATETIME",
             exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
             exp.DataType.Type.TINYINT: "INT64",
             exp.DataType.Type.VARBINARY: "BYTES",
             exp.DataType.Type.VARCHAR: "STRING",
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index fc48379..cfa9a7e 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -21,8 +21,9 @@ def _lower_func(sql: str) -> str:


 class ClickHouse(Dialect):
-    normalize_functions = None
-    null_ordering = "nulls_are_last"
+    NORMALIZE_FUNCTIONS: bool | str = False
+    NULL_ORDERING = "nulls_are_last"
+    STRICT_STRING_CONCAT = True

     class Tokenizer(tokens.Tokenizer):
         COMMENTS = ["--", "#", "#!", ("/*", "*/")]
@@ -163,11 +164,11 @@ class ClickHouse(Dialect):

             return this

-        def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
+        def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition:
             return super()._parse_position(haystack_first=True)

         # https://clickhouse.com/docs/en/sql-reference/statements/select/with/
-        def _parse_cte(self) -> exp.Expression:
+        def _parse_cte(self) -> exp.CTE:
             index = self._index
             try:
                 # WITH <identifier> AS <subquery expression>
@@ -187,17 +188,19 @@
         ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
             is_global = self._match(TokenType.GLOBAL) and self._prev
             kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev
+
             if kind_pre:
                 kind = self._match_set(self.JOIN_KINDS) and self._prev
                 side = self._match_set(self.JOIN_SIDES) and self._prev
                 return is_global, side, kind
+
             return (
                 is_global,
                 self._match_set(self.JOIN_SIDES) and self._prev,
                 self._match_set(self.JOIN_KINDS) and self._prev,
             )

-        def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+        def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
             join = super()._parse_join(skip_join_token)

             if join:
@@ -205,9 +208,14 @@ class ClickHouse(Dialect):
             return join

         def _parse_function(
-            self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
+            self,
+            functions: t.Optional[t.Dict[str, t.Callable]] = None,
+            anonymous: bool = False,
+            optional_parens: bool = True,
         ) -> t.Optional[exp.Expression]:
-            func = super()._parse_function(functions, anonymous)
+            func = super()._parse_function(
+                functions=functions, anonymous=anonymous, optional_parens=optional_parens
+            )

             if isinstance(func, exp.Anonymous):
                 params = self._parse_func_params(func)
@@ -227,10 +235,12 @@
         ) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
             if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
                 return self._parse_csv(self._parse_lambda)
+
             if self._match(TokenType.L_PAREN):
                 params = self._parse_csv(self._parse_lambda)
                 self._match_r_paren(this)
                 return params
+
             return None

         def _parse_quantile(self) -> exp.Quantile:
@@ -247,12 +257,12 @@

         def _parse_primary_key(
             self, wrapped_optional: bool = False, in_props: bool = False
-        ) -> exp.Expression:
+        ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey:
             return super()._parse_primary_key(
                 wrapped_optional=wrapped_optional or in_props, in_props=in_props
             )

-        def _parse_on_property(self) -> t.Optional[exp.Property]:
+        def _parse_on_property(self) -> t.Optional[exp.Expression]:
             index = self._index
             if self._match_text_seq("CLUSTER"):
                 this = self._parse_id_var()
@@ -329,6 +339,16 @@ class ClickHouse(Dialect):
             "NAMED COLLECTION",
         }

+        def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
+            # Clickhouse errors out if we try to cast a NULL value to TEXT
+            return self.func(
+                "CONCAT",
+                *[
+                    exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text"))
+                    for e in expression.expressions
+                ],
+            )
+
         def cte_sql(self, expression: exp.CTE) -> str:
             if isinstance(expression.this, exp.Alias):
                 return self.sql(expression, "this")
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 4958bc6..f5d523b 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -25,6 +25,8 @@ class Dialects(str, Enum):

     BIGQUERY = "bigquery"
     CLICKHOUSE = "clickhouse"
+    DATABRICKS = "databricks"
+    DRILL = "drill"
     DUCKDB = "duckdb"
     HIVE = "hive"
     MYSQL = "mysql"
@@ -38,11 +40,9 @@ class Dialects(str, Enum):
     SQLITE = "sqlite"
     STARROCKS = "starrocks"
     TABLEAU = "tableau"
+    TERADATA = "teradata"
     TRINO = "trino"
     TSQL = "tsql"
-    DATABRICKS = "databricks"
-    DRILL = "drill"
-    TERADATA = "teradata"


 class _Dialect(type):
@@ -76,16 +76,19 @@ class _Dialect(type):
         enum = Dialects.__members__.get(clsname.upper())
         cls.classes[enum.value if enum is not None else clsname.lower()] = klass

-        klass.time_trie = new_trie(klass.time_mapping)
-        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
-        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
+        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
+        klass.FORMAT_TRIE = (
+            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
+        )
+        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
+        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

         klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
         klass.parser_class = getattr(klass, "Parser", Parser)
         klass.generator_class = getattr(klass, "Generator", Generator)

-        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
-        klass.identifier_start, klass.identifier_end = list(
+        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
+        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
             klass.tokenizer_class._IDENTIFIERS.items()
         )[0]

@@ -99,43 +102,80 @@ class _Dialect(type):
             (None, None),
         )

-        klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
-        klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
-        klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
-        klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
+        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
+        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
+        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
+        klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING)

-        klass.tokenizer_class.identifiers_can_start_with_digit = (
-            klass.identifiers_can_start_with_digit
-        )
+        dialect_properties = {
+            **{
+                k: v
+                for k, v in vars(klass).items()
+                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
+            },
+            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
+            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
+        }
+
+        # Pass required dialect properties to the tokenizer, parser and generator classes
+        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
+            for name, value in dialect_properties.items():
+                if hasattr(subclass, name):
+                    setattr(subclass, name, value)
+
+        if not klass.STRICT_STRING_CONCAT:
+            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

         return klass


 class Dialect(metaclass=_Dialect):
-    index_offset = 0
-    unnest_column_only = False
-    alias_post_tablesample = False
-    identifiers_can_start_with_digit = False
-    normalize_functions: t.Optional[str] = "upper"
-    null_ordering = "nulls_are_small"
-
-    date_format = "'%Y-%m-%d'"
-    dateint_format = "'%Y%m%d'"
-    time_format = "'%Y-%m-%d %H:%M:%S'"
-    time_mapping: t.Dict[str, str] = {}
-
-    # autofilled
-    quote_start = None
-    quote_end = None
-    identifier_start = None
-    identifier_end = None
-
-    time_trie = None
-    inverse_time_mapping = None
-    inverse_time_trie = None
-    tokenizer_class = None
-    parser_class = None
-    generator_class = None
+    # Determines the base index offset for arrays
+    INDEX_OFFSET = 0
+
+    # If true unnest table aliases are considered only as column aliases
+    UNNEST_COLUMN_ONLY = False
+
+    # Determines whether or not the table alias comes after tablesample
+    ALIAS_POST_TABLESAMPLE = False
+
+    # Determines whether or not an unquoted identifier can start with a digit
+    IDENTIFIERS_CAN_START_WITH_DIGIT = False
+
+    # Determines whether or not CONCAT's arguments must be strings
+    STRICT_STRING_CONCAT = False
+
+    # Determines how function names are going to be normalized
+    NORMALIZE_FUNCTIONS: bool | str = "upper"
+
+    # Indicates the default null ordering method to use if not explicitly set
+    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
+    NULL_ORDERING = "nulls_are_small"
+
+    DATE_FORMAT = "'%Y-%m-%d'"
+    DATEINT_FORMAT = "'%Y%m%d'"
+    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
+
+    # Custom time mappings in which the key represents dialect time format
+    # and the value represents a python time format
+    TIME_MAPPING: t.Dict[str, str] = {}
+
+    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
+    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
+    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
+    FORMAT_MAPPING: t.Dict[str, str] = {}
+
+    # Autofilled
+    tokenizer_class = Tokenizer
+    parser_class = Parser
+    generator_class = Generator
+
+    # A trie of the time_mapping keys
+    TIME_TRIE: t.Dict = {}
+    FORMAT_TRIE: t.Dict = {}
+
+    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
+    INVERSE_TIME_TRIE: t.Dict = {}

     def __eq__(self, other: t.Any) -> bool:
         return type(self) == other
@@ -164,20 +204,13 @@ class Dialect(metaclass=_Dialect):
     ) -> t.Optional[exp.Expression]:
         if isinstance(expression, str):
             return exp.Literal.string(
-                format_time(
-                    expression[1:-1],  # the time formats are quoted
-                    cls.time_mapping,
-                    cls.time_trie,
-                )
+                # the time formats are quoted
+                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
             )
+
         if expression and expression.is_string:
-            return exp.Literal.string(
-                format_time(
-                    expression.this,
-                    cls.time_mapping,
-                    cls.time_trie,
-                )
-            )
+            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
+
         return expression

     def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
@@ -200,48 +233,14 @@ class Dialect(metaclass=_Dialect):
     @property
     def tokenizer(self) -> Tokenizer:
         if not hasattr(self, "_tokenizer"):
-            self._tokenizer = self.tokenizer_class()  # type: ignore
+            self._tokenizer = self.tokenizer_class()
         return self._tokenizer

     def parser(self, **opts) -> Parser:
-        return self.parser_class(  # type: ignore
-            **{
-                "index_offset": self.index_offset,
-                "unnest_column_only": self.unnest_column_only,
-                "alias_post_tablesample": self.alias_post_tablesample,
-                "null_ordering": self.null_ordering,
-                **opts,
-            },
-        )
+        return self.parser_class(**opts)

     def generator(self, **opts) -> Generator:
-        return self.generator_class(  # type: ignore
-            **{
-                "quote_start": self.quote_start,
-                "quote_end": self.quote_end,
-                "bit_start": self.bit_start,
-                "bit_end": self.bit_end,
-                "hex_start": self.hex_start,
-                "hex_end": self.hex_end,
-                "byte_start": self.byte_start,
-                "byte_end": self.byte_end,
-                "raw_start": self.raw_start,
-                "raw_end": self.raw_end,
-                "identifier_start": self.identifier_start,
-                "identifier_end": self.identifier_end,
-                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
-                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
-                "index_offset": self.index_offset,
-                "time_mapping": self.inverse_time_mapping,
-                "time_trie": self.inverse_time_trie,
-                "unnest_column_only": self.unnest_column_only,
-                "alias_post_tablesample": self.alias_post_tablesample,
-                "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
-                "normalize_functions": self.normalize_functions,
-                "null_ordering": self.null_ordering,
-                **opts,
-            }
-        )
+        return self.generator_class(**opts)


 DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
@@ -279,10 +278,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:

 def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
     return self.like_sql(
-        exp.Like(
-            this=exp.Lower(this=expression.this),
-            expression=expression.args["expression"],
-        )
+        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
     )


@@ -359,6 +355,7 @@ def var_map_sql(
     for key, value in zip(keys.expressions, values.expressions):
         args.append(self.sql(key))
         args.append(self.sql(value))
+
     return self.func(map_func_name, *args)


@@ -381,7 +378,7 @@ def format_time_lambda(
             this=seq_get(args, 0),
             format=Dialect[dialect].format_time(
                 seq_get(args, 1)
-                or (Dialect[dialect].time_format if default is True else default or None)
+                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
             ),
         )

@@ -437,9 +434,7 @@ def parse_date_delta_with_interval(
             expression = exp.Literal.number(expression.this)

         return expression_class(
-            this=args[0],
-            expression=expression,
-            unit=exp.Literal.string(interval.text("unit")),
+            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
         )

     return func
@@ -462,9 +457,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:

 def locate_to_strposition(args: t.List) -> exp.Expression:
     return exp.StrPosition(
-        this=seq_get(args, 1),
-        substr=seq_get(args, 0),
-        position=seq_get(args, 2),
+        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
     )


@@ -546,13 +539,21 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
     def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
         _dialect = Dialect.get_or_raise(dialect)
         time_format = self.format_time(expression)
-        if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
+        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
         return f"CAST({self.sql(expression, 'this')} AS DATE)"

     return _ts_or_ds_to_date_sql


+def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
+    this, *rest_args = expression.expressions
+    for arg in rest_args:
+        this = exp.DPipe(this=this, expression=arg)
+
+    return self.sql(this)
+
+
 # Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
 def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
     names = []
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 924b979..3cca986 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -16,21 +16,10 @@ from sqlglot.dialects.dialect import (
 )


-def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
-    return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
-
-
-def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
-    time_format = self.format_time(expression)
-    if time_format and time_format not in (Drill.time_format, Drill.date_format):
-        return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
-    return f"CAST({self.sql(expression, 'this')} AS DATE)"
-
-
 def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
     def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
         this = self.sql(expression, "this")
-        unit = exp.Var(this=expression.text("unit").upper() or "DAY")
+        unit = exp.var(expression.text("unit").upper() or "DAY")
         return (
             f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
         )
@@ -41,19 +30,19 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
 def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
-    if time_format == Drill.date_format:
+    if time_format == Drill.DATE_FORMAT:
         return f"CAST({this} AS DATE)"
     return f"TO_DATE({this}, {time_format})"


 class Drill(Dialect):
-    normalize_functions = None
-    null_ordering = "nulls_are_last"
-    date_format = "'yyyy-MM-dd'"
-    dateint_format = "'yyyyMMdd'"
-    time_format = "'yyyy-MM-dd HH:mm:ss'"
+    NORMALIZE_FUNCTIONS: bool | str = False
+    NULL_ORDERING = "nulls_are_last"
+    DATE_FORMAT = "'yyyy-MM-dd'"
+    DATEINT_FORMAT = "'yyyyMMdd'"
+    TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"

-    time_mapping = {
+    TIME_MAPPING = {
         "y": "%Y",
         "Y": "%Y",
         "YYYY": "%Y",
@@ -93,6 +82,7 @@ class Drill(Dialect):

     class Parser(parser.Parser):
         STRICT_CAST = False
+        CONCAT_NULL_OUTPUTS_STRING = True

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -135,8 +125,8 @@ class Drill(Dialect):
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: _date_add_sql("SUB"),
-            exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
-            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
+            exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
+            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
             exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
             exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@@ -154,7 +144,7 @@ class Drill(Dialect):
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
+            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
         }
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f31da73..f0c1820 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -56,11 +56,7 @@ def _sort_array_reverse(args: t.List) -> exp.Expression:


 def _parse_date_diff(args: t.List) -> exp.Expression:
-    return exp.DateDiff(
-        this=seq_get(args, 2),
-        expression=seq_get(args, 1),
-        unit=seq_get(args, 0),
-    )
+    return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))


 def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
@@ -90,7 +86,7 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract


 class DuckDB(Dialect):
-    null_ordering = "nulls_are_last"
+    NULL_ORDERING = "nulls_are_last"

     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
@@ -118,6 +114,8 @@ class DuckDB(Dialect):
         }

     class Parser(parser.Parser):
+        CONCAT_NULL_OUTPUTS_STRING = True
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
@@ -127,10 +125,7 @@ class DuckDB(Dialect):
             "DATE_DIFF": _parse_date_diff,
             "EPOCH": exp.TimeToUnix.from_arg_list,
             "EPOCH_MS": lambda args: exp.UnixToTime(
-                this=exp.Div(
-                    this=seq_get(args, 0),
-                    expression=exp.Literal.number(1000),
-                )
+                this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000))
             ),
             "LIST_REVERSE_SORT": _sort_array_reverse,
             "LIST_SORT": exp.SortArray.from_arg_list,
@@ -191,8 +186,8 @@ class DuckDB(Dialect):
                 "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
-            exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
+            exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
+            exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
             exp.Explode: rename_func("UNNEST"),
             exp.IntDiv: lambda self, e: self.binary(e, "//"),
             exp.JSONExtract: arrow_json_extract_sql,
@@ -242,11 +237,27 @@ class DuckDB(Dialect):

         STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}

+        UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
+
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

+        def interval_sql(self, expression: exp.Interval) -> str:
+            multiplier: t.Optional[int] = None
+            unit = expression.text("unit").lower()
+
+            if unit.startswith("week"):
+                multiplier = 7
+            if unit.startswith("quarter"):
+                multiplier = 90
+
+            if multiplier:
+                return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
+
+            return super().interval_sql(expression)
+
         def tablesample_sql(
             self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
         ) -> str:
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 650a1e1..8847119 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -80,12 +80,12 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
     _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
     multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
     diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})"
+
     return f"{diff_sql}{multiplier_sql}"


 def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
     this = expression.this
-
     if not this.type:
         from sqlglot.optimizer.annotate_types import annotate_types

@@ -113,7 +113,7 @@ def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> st
 def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
-    if time_format not in (Hive.time_format, Hive.date_format):
+    if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
         this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
     return f"CAST({this} AS DATE)"

@@ -121,7 +121,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st
 def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
-    if time_format not in (Hive.time_format, Hive.date_format):
+    if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
         this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
     return f"CAST({this} AS TIMESTAMP)"

@@ -130,7 +130,7 @@ def _time_format(
     self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
 ) -> t.Optional[str]:
     time_format = self.format_time(expression)
-    if time_format == Hive.time_format:
+    if time_format == Hive.TIME_FORMAT:
         return None
     return time_format

@@ -144,16 +144,16 @@ def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
 def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
-    if time_format and time_format not in (Hive.time_format, Hive.date_format):
+    if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
         return f"TO_DATE({this}, {time_format})"
     return f"TO_DATE({this})"


 class Hive(Dialect):
-    alias_post_tablesample = True
-    identifiers_can_start_with_digit = True
+    ALIAS_POST_TABLESAMPLE = True
+    IDENTIFIERS_CAN_START_WITH_DIGIT = True

-    time_mapping = {
+    TIME_MAPPING = {
         "y": "%Y",
         "Y": "%Y",
         "YYYY": "%Y",
@@ -184,9 +184,9 @@ class Hive(Dialect):
         "EEEE": "%A",
     }

-    date_format = "'yyyy-MM-dd'"
-    dateint_format = "'yyyyMMdd'"
-    time_format = "'yyyy-MM-dd HH:mm:ss'"
+    DATE_FORMAT = "'yyyy-MM-dd'"
+    DATEINT_FORMAT = "'yyyyMMdd'"
+    TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"

     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", '"']
@@ -224,9 +224,7 @@ class Hive(Dialect):
             "BASE64": exp.ToBase64.from_arg_list,
             "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
             "DATE_ADD": lambda args: exp.TsOrDsAdd(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-                unit=exp.Literal.string("DAY"),
+                this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
             ),
             "DATEDIFF": lambda args: exp.DateDiff(
                 this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@@ -234,10 +232,7 @@ class Hive(Dialect):
             ),
             "DATE_SUB": lambda args: exp.TsOrDsAdd(
                 this=seq_get(args, 0),
-                expression=exp.Mul(
-                    this=seq_get(args, 1),
-                    expression=exp.Literal.number(-1),
-                ),
+                expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)),
                 unit=exp.Literal.string("DAY"),
             ),
             "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
@@ -349,8 +344,8 @@ class Hive(Dialect):
             exp.DateDiff: _date_diff_sql,
             exp.DateStrToDate: rename_func("TO_DATE"),
             exp.DateSub: _add_date_sql,
-            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
-            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
+            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
+            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
             exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
             exp.FromBase64: rename_func("UNBASE64"),
             exp.If: if_sql,
@@ -415,10 +410,7 @@ class Hive(Dialect):
             )

         def with_properties(self, properties: exp.Properties) -> str:
-            return self.properties(
-                properties,
-                prefix=self.seg("TBLPROPERTIES"),
-            )
+            return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))

         def datatype_sql(self, expression: exp.DataType) -> str:
             if (
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 75023ff..d2462e1 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -94,10 +94,10 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e


 class MySQL(Dialect):
-    time_format = "'%Y-%m-%d %T'"
+    TIME_FORMAT = "'%Y-%m-%d %T'"

     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
-    time_mapping = {
+    TIME_MAPPING = {
         "%M": "%B",
         "%c": "%-m",
         "%e": "%-d",
@@ -128,6 +128,7 @@ class MySQL(Dialect):
             "MEDIUMBLOB": TokenType.MEDIUMBLOB,
             "MEDIUMTEXT": TokenType.MEDIUMTEXT,
             "SEPARATOR": TokenType.SEPARATOR,
+            "ENUM": TokenType.ENUM,
             "START": TokenType.BEGIN,
             "_ARMSCII8": TokenType.INTRODUCER,
             "_ASCII": TokenType.INTRODUCER,
@@ -279,6 +280,16 @@ class MySQL(Dialect):
             "SWAPS",
         }

+        TYPE_TOKENS = {
+            *parser.Parser.TYPE_TOKENS,
+            TokenType.SET,
+        }
+
+        ENUM_TYPE_TOKENS = {
+            *parser.Parser.ENUM_TYPE_TOKENS,
+            TokenType.SET,
+        }
+
         LOG_DEFAULTS_TO_LN = True

         def _parse_show_mysql(
@@ -372,12 +383,7 @@ class MySQL(Dialect):
             else:
                 collate = None

-            return self.expression(
-                exp.SetItem,
-                this=charset,
-                collate=collate,
-                kind="NAMES",
-            )
+            return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")

     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
@@ -472,9 +478,7 @@ class MySQL(Dialect):

         def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str:
             sql = self.sql(expression, arg)
-            if not sql:
-                return ""
-            return f" {prefix} {sql}"
+            return f" {prefix} {sql}" if sql else ""

         def _oldstyle_limit_sql(self, expression: exp.Show) -> str:
             limit = self.sql(expression, "limit")
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 7722753..8d35e92 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -24,21 +24,15 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
     if self._match_text_seq("COLUMNS"):
         columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))

-    return self.expression(
-        exp.XMLTable,
-        this=this,
-        passing=passing,
-        columns=columns,
-        by_ref=by_ref,
-    )
+    return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)


 class Oracle(Dialect):
-    alias_post_tablesample = True
+    ALIAS_POST_TABLESAMPLE = True

     # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
     # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
-    time_mapping = {
+    TIME_MAPPING = {
         "AM": "%p",  # Meridian indicator with or without periods
         "A.M.": "%p",  # Meridian indicator with or without periods
         "PM": "%p",  # Meridian indicator with or without periods
@@ -87,7 +81,7 @@ class Oracle(Dialect):
             column.set("join_mark", self._match(TokenType.JOIN_MARKER))
             return column

-        def _parse_hint(self) -> t.Optional[exp.Expression]:
+        def _parse_hint(self) -> t.Optional[exp.Hint]:
             if self._match(TokenType.HINT):
                 start = self._curr
                 while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):
@@ -129,7 +123,7 @@ class Oracle(Dialect):
             exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
             exp.ILike: no_ilike_sql,
-            exp.IfNull: rename_func("NVL"),
+            exp.Coalesce: rename_func("NVL"),
             exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
@@ -179,7 +173,6 @@ class Oracle(Dialect):
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "MINUS": TokenType.EXCEPT,
             "NVARCHAR2": TokenType.NVARCHAR,
-            "RETURNING": TokenType.RETURNING,
             "SAMPLE": TokenType.TABLE_SAMPLE,
             "START": TokenType.BEGIN,
             "TOP": TokenType.TOP,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 8d84024..8c2a4ab 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -183,9 +183,10 @@ def _to_timestamp(args: t.List) -> exp.Expression:


 class Postgres(Dialect):
-    null_ordering = "nulls_are_large"
-    time_format = "'YYYY-MM-DD HH24:MI:SS'"
-    time_mapping = {
+    INDEX_OFFSET = 1
+    NULL_ORDERING = "nulls_are_large"
+    TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
+    TIME_MAPPING = {
         "AM": "%p",
         "PM": "%p",
         "D": "%u",  # 1-based day of week
@@ -241,7 +242,6 @@ class Postgres(Dialect):
             "REFRESH": TokenType.COMMAND,
             "REINDEX": TokenType.COMMAND,
             "RESET": TokenType.COMMAND,
-            "RETURNING": TokenType.RETURNING,
             "REVOKE": TokenType.COMMAND,
             "SERIAL": TokenType.SERIAL,
             "SMALLSERIAL": TokenType.SMALLSERIAL,
@@ -258,6 +258,7 @@ class Postgres(Dialect):

     class Parser(parser.Parser):
         STRICT_CAST = False
+        CONCAT_NULL_OUTPUTS_STRING = True

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -268,6 +269,7 @@ class Postgres(Dialect):
             "NOW": exp.CurrentTimestamp.from_arg_list,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
             "TO_TIMESTAMP": _to_timestamp,
+            "UNNEST": exp.Explode.from_arg_list,
         }

         FUNCTION_PARSERS = {
@@ -303,7 +305,7 @@ class Postgres(Dialect):
             value = self._parse_bitwise()

             if part and part.is_string:
-                part = exp.Var(this=part.name)
+                part = exp.var(part.name)

             return self.expression(exp.Extract, this=part, expression=value)

@@ -328,6 +330,7 @@ class Postgres(Dialect):
             **generator.Generator.TRANSFORMS,
             exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
             exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
+            exp.Explode: rename_func("UNNEST"),
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index d839864..a8a9884 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -102,7 +102,7 @@ def _str_to_time_sql(

 def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     time_format = self.format_time(expression)
-    if time_format and time_format not in (Presto.time_format, Presto.date_format):
+    if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
         return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
     return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"

@@ -119,7 +119,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
             exp.Literal.number(1),
             exp.Literal.number(10),
         ),
-        Presto.date_format,
+        Presto.DATE_FORMAT,
     )

     return self.func(
@@ -145,9 +145,7 @@ def _approx_percentile(args: t.List) -> exp.Expression:
         )
     if len(args) == 3:
         return exp.ApproxQuantile(
-            this=seq_get(args, 0),
-            quantile=seq_get(args, 1),
-            accuracy=seq_get(args, 2),
+            this=seq_get(args, 0), quantile=seq_get(args, 1), accuracy=seq_get(args, 2)
         )
     return exp.ApproxQuantile.from_arg_list(args)

@@ -160,10 +158,8 @@ def _from_unixtime(args: t.List) -> exp.Expression:
             minutes=seq_get(args, 2),
         )
     if len(args) == 2:
-        return exp.UnixToTime(
-            this=seq_get(args, 0),
-            zone=seq_get(args, 1),
-        )
+        return exp.UnixToTime(this=seq_get(args, 0), zone=seq_get(args, 1))
+
     return exp.UnixToTime.from_arg_list(args)


@@ -173,21 +169,17 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
         unnest = exp.Unnest(expressions=[expression.this])

         if expression.alias:
-            return exp.alias_(
-                unnest,
-                alias="_u",
-                table=[expression.alias],
-                copy=False,
-            )
+            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
         return unnest

     return expression


 class Presto(Dialect):
-    index_offset = 1
-    null_ordering = "nulls_are_last"
-    time_format = MySQL.time_format
-    time_mapping = MySQL.time_mapping
+    INDEX_OFFSET = 1
+    NULL_ORDERING = "nulls_are_last"
+    TIME_FORMAT = MySQL.TIME_FORMAT
+    TIME_MAPPING = MySQL.TIME_MAPPING
+    STRICT_STRING_CONCAT = True

     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
@@ -205,14 +197,10 @@ class Presto(Dialect):
             "CARDINALITY": exp.ArraySize.from_arg_list,
             "CONTAINS": exp.ArrayContains.from_arg_list,
             "DATE_ADD": lambda args: exp.DateAdd(
-                this=seq_get(args, 2),
-                expression=seq_get(args, 1),
-                unit=seq_get(args, 0),
+                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
             ),
             "DATE_DIFF": lambda args: exp.DateDiff(
-                this=seq_get(args, 2),
-                expression=seq_get(args, 1),
-                unit=seq_get(args, 0),
+                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
             ),
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
             "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@@ -225,9 +213,7 @@ class Presto(Dialect):
             "NOW": exp.CurrentTimestamp.from_arg_list,
             "SEQUENCE": exp.GenerateSeries.from_arg_list,
             "STRPOS": lambda args: exp.StrPosition(
-                this=seq_get(args, 0),
-                substr=seq_get(args, 1),
-                instance=seq_get(args, 2),
+                this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
             ),
             "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
             "TO_HEX": exp.Hex.from_arg_list,
@@ -242,7 +228,7 @@ class Presto(Dialect):
         INTERVAL_ALLOWS_PLURAL_FORM = False
         JOIN_HINTS = False
         TABLE_HINTS = False
-        IS_BOOL = False
+        IS_BOOL_ALLOWED = False
         STRUCT_DELIMITER = ("(", ")")

         PROPERTIES_LOCATION = {
@@ -284,10 +270,10 @@ class Presto(Dialect):
             exp.DateDiff: lambda self, e: self.func(
                 "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
             ),
-            exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
-            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
+            exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
+            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
             exp.Decode: _decode_sql,
-            exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
+            exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
             exp.Encode: _encode_sql,
             exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
             exp.Group: transforms.preprocess([transforms.unalias_group]),
@@ -322,7 +308,7 @@ class Presto(Dialect):
             exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToDate: timestrtotime_sql,
             exp.TimeStrToTime: timestrtotime_sql,
-            exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
+            exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
             exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("TO_UNIXTIME"),
             exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
@@ -367,8 +353,16 @@ class Presto(Dialect):
                 to = target_type.copy()

                 if target_type is start.to:
-                    end = exp.Cast(this=end, to=to)
+                    end = exp.cast(end, to)
                 else:
-                    start = exp.Cast(this=start, to=to)
+                    start = exp.cast(start, to)

             return self.func("SEQUENCE", start, end, step)
+
+        def offset_limit_modifiers(
+            self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
+        ) -> t.List[str]:
+            return [
+                self.sql(expression, "offset"),
+                self.sql(limit),
+            ]
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index b0a6774..a7e25fa 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import typing as t

 from sqlglot import exp, transforms
-from sqlglot.dialects.dialect import rename_func
+from sqlglot.dialects.dialect import concat_to_dpipe_sql, rename_func
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
@@ -14,9 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx


 class Redshift(Postgres):
-    time_format = "'YYYY-MM-DD HH:MI:SS'"
-    time_mapping = {
-        **Postgres.time_mapping,
+    TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
+    TIME_MAPPING = {
+        **Postgres.TIME_MAPPING,
         "MON": "%b",
         "HH": "%H",
     }
@@ -51,7 +51,7 @@ class Redshift(Postgres):
                 and this.expressions
                 and this.expressions[0].this == exp.column("MAX")
             ):
-                this.set("expressions", [exp.Var(this="MAX")])
+                this.set("expressions", [exp.var("MAX")])

             return this

@@ -94,6 +94,7 @@ class Redshift(Postgres):

         TRANSFORMS = {
             **Postgres.Generator.TRANSFORMS,
+            exp.Concat: concat_to_dpipe_sql,
             exp.CurrentTimestamp: lambda self, e: "SYSDATE",
             exp.DateAdd: lambda self, e: self.func(
                 "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -106,6 +107,7 @@ class Redshift(Postgres):
             exp.FromBase: rename_func("STRTOL"),
             exp.JSONExtract: _json_sql,
             exp.JSONExtractScalar: _json_sql,
+            exp.SafeConcat: concat_to_dpipe_sql,
             exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
             exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
             exp.TsOrDsToDate: lambda self, e: self.sql(e.this),
@@ -170,6 +172,6 @@ class Redshift(Postgres):
             precision = expression.args.get("expressions")

             if not precision:
-                expression.append("expressions", exp.Var(this="MAX"))
+                expression.append("expressions", exp.var("MAX"))

             return super().datatype_sql(expression)
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 821d991..148b6d8 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -167,10 +167,10 @@ def _parse_convert_timezone(args: t.List) -> exp.Expression:


 class Snowflake(Dialect):
-    null_ordering = "nulls_are_large"
-    time_format = "'yyyy-mm-dd hh24:mi:ss'"
+    NULL_ORDERING = "nulls_are_large"
+    TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"

-    time_mapping = {
+    TIME_MAPPING = {
         "YYYY": "%Y",
         "yyyy": "%Y",
         "YY": "%y",
@@ -210,14 +210,10 @@ class Snowflake(Dialect):
             "CONVERT_TIMEZONE": _parse_convert_timezone,
             "DATE_TRUNC": date_trunc_to_time,
             "DATEADD": lambda args: exp.DateAdd(
-                this=seq_get(args, 2),
-                expression=seq_get(args, 1),
-                unit=seq_get(args, 0),
+                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
             ),
             "DATEDIFF": lambda args: exp.DateDiff(
-                this=seq_get(args, 2),
-                expression=seq_get(args, 1),
-                unit=seq_get(args, 0),
+                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
             ),
             "DIV0": _div0_to_if,
             "IFF": exp.If.from_arg_list,
@@ -246,9 +242,7 @@ class Snowflake(Dialect):
         COLUMN_OPERATORS = {
             **parser.Parser.COLUMN_OPERATORS,
             TokenType.COLON: lambda self, this, path: self.expression(
-                exp.Bracket,
-                this=this,
-                expressions=[path],
+                exp.Bracket, this=this, expressions=[path]
             ),
         }

@@ -275,6 +269,7 @@ class Snowflake(Dialect):
         QUOTES = ["'", "$$"]
         STRING_ESCAPES = ["\\", "'"]
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
+        COMMENTS = ["--", "//", ("/*", "*/")]

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index bf24240..ed6992d 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -38,7 +38,7 @@ def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
 def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
-    if time_format == Hive.date_format:
+    if time_format == Hive.DATE_FORMAT:
         return f"TO_DATE({this})"
     return f"TO_DATE({this}, {time_format})"

@@ -133,13 +133,13 @@ class Spark2(Hive):
             "WEEKOFYEAR": lambda args: exp.WeekOfYear(
                 this=exp.TsOrDsToDate(this=seq_get(args, 0)),
             ),
-            "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
             "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                 this=seq_get(args, 1),
                 unit=exp.var(seq_get(args, 0)),
             ),
             "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
             "BOOLEAN": _parse_as_cast("boolean"),
+            "DATE": _parse_as_cast("date"),
             "DOUBLE": _parse_as_cast("double"),
             "FLOAT": _parse_as_cast("float"),
             "INT": _parse_as_cast("int"),
@@ -162,11 +162,9 @@ class Spark2(Hive):
         def _parse_add_column(self) -> t.Optional[exp.Expression]:
             return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()

-        def _parse_drop_column(self) -> t.Optional[exp.Expression]:
+        def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
             return self._match_text_seq("DROP", "COLUMNS") and self.expression(
-                exp.Drop,
-                this=self._parse_schema(),
-                kind="COLUMNS",
+                exp.Drop, this=self._parse_schema(), kind="COLUMNS"
             )

         def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 4e800b0..3b837ea 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
     arrow_json_extract_sql,
+    concat_to_dpipe_sql,
     count_if_to_sum,
     no_ilike_sql,
     no_pivot_sql,
@@ -62,10 +63,6 @@ class SQLite(Dialect):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]

-        KEYWORDS = {
-            **tokens.Tokenizer.KEYWORDS,
-        }
-
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -100,6 +97,7 @@ class SQLite(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
+            exp.Concat: concat_to_dpipe_sql,
             exp.CountIf: count_if_to_sum,
             exp.Create: transforms.preprocess([_transform_create]),
             exp.CurrentDate: lambda *_: "CURRENT_DATE",
@@ -116,6 +114,7 @@ class SQLite(Dialect):
             exp.LogicalOr: rename_func("MAX"),
             exp.LogicalAnd: rename_func("MIN"),
             exp.Pivot: no_pivot_sql,
+            exp.SafeConcat: concat_to_dpipe_sql,
             exp.Select: transforms.preprocess(
                 [transforms.eliminate_distinct_on, transforms.eliminate_qualify]
             ),
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index d5fba17..67ef76b 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from sqlglot import exp, generator, parser, transforms
-from sqlglot.dialects.dialect import Dialect
+from sqlglot.dialects.dialect import Dialect, rename_func


 class Tableau(Dialect):
@@ -11,6 +11,7 @@ class Tableau(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
+            exp.Coalesce: rename_func("IFNULL"),
             exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
         }

@@ -25,9 +26,6 @@ class Tableau(Dialect):
             false = self.sql(expression, "false")
             return f"IF {this} THEN {true} ELSE {false} END"

-        def coalesce_sql(self, expression: exp.Coalesce) -> str:
-            return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
-
         def count_sql(self, expression: exp.Count) -> str:
             this = expression.this
             if isinstance(this, exp.Distinct):
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 514aecb..d5e5dd8 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -1,18 +1,32 @@
 from __future__ import annotations

-import typing as t
-
 from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import (
-    Dialect,
-    format_time_lambda,
-    max_or_greatest,
-    min_or_least,
-)
+from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
 from sqlglot.tokens import TokenType


 class Teradata(Dialect):
+    TIME_MAPPING = {
+        "Y": "%Y",
+        "YYYY": "%Y",
+        "YY": "%y",
+        "MMMM": "%B",
+        "MMM": "%b",
+        "DD": "%d",
+        "D": "%-d",
+        "HH": "%H",
+        "H": "%-H",
+        "MM": "%M",
+        "M": "%-M",
+        "SS": "%S",
+        "S": "%-S",
+        "SSSSSS": "%f",
+        "E": "%a",
+        "EE": "%a",
+        "EEE": "%a",
+        "EEEE": "%A",
+    }
+
     class Tokenizer(tokens.Tokenizer):
         # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
         KEYWORDS = {
@@ -31,7 +45,7 @@ class Teradata(Dialect):
             "ST_GEOMETRY": TokenType.GEOMETRY,
         }

-        # teradata does not support % for modulus
+        # Teradata does not support % as a modulo operator
         SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
         SINGLE_TOKENS.pop("%")

@@ -101,7 +115,7 @@ class Teradata(Dialect):

         # FROM before SET in Teradata UPDATE syntax
         # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
-        def _parse_update(self) -> exp.Expression:
+        def _parse_update(self) -> exp.Update:
             return self.expression(
                 exp.Update,
                 **{  # type: ignore
@@ -122,14 +136,6 @@ class Teradata(Dialect):

             return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)

-        def _parse_cast(self, strict: bool) -> exp.Expression:
-            cast = t.cast(exp.Cast, super()._parse_cast(strict))
-            if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
-                return format_time_lambda(exp.TimeToStr, "teradata")(
-                    [cast.this, self._parse_string()]
-                )
-            return cast
-
     class Generator(generator.Generator):
         JOIN_HINTS = False
         TABLE_HINTS = False
@@ -151,7 +157,7 @@ class Teradata(Dialect):
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
-            exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
+            exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
         }
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index f6ad888..6d674f5 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -64,9 +64,9 @@ def _format_time_lambda(
             format=exp.Literal.string(
                 format_time(
                     args[0].name,
-                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
+                    {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
                     if full_format_mapping
-                    else TSQL.time_mapping,
+                    else TSQL.TIME_MAPPING,
                 )
             ),
         )
@@ -86,9 +86,9 @@ def _parse_format(args: t.List) -> exp.Expression:
     return exp.TimeToStr(
         this=args[0],
         format=exp.Literal.string(
-            format_time(fmt.name, TSQL.format_time_mapping)
+            format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING)
             if len(fmt.name) == 1
-            else format_time(fmt.name, TSQL.time_mapping)
+            else format_time(fmt.name, TSQL.TIME_MAPPING)
         ),
     )

@@ -138,7 +138,7 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
         if isinstance(expression, exp.NumberToStr)
         else exp.Literal.string(
             format_time(
-                expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping)
+                expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING)
             )
         )
     )
@@ -166,10 +166,10 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s


 class TSQL(Dialect):
-    null_ordering = "nulls_are_small"
-    time_format = "'yyyy-mm-dd hh:mm:ss'"
+    NULL_ORDERING = "nulls_are_small"
+    TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"

-    time_mapping = {
+    TIME_MAPPING = {
         "year": "%Y",
         "qq": "%q",
         "q": "%q",
@@ -213,7 +213,7 @@ class TSQL(Dialect):
         "yy": "%y",
     }

-    convert_format_mapping = {
+    CONVERT_FORMAT_MAPPING = {
         "0": "%b %d %Y %-I:%M%p",
         "1": "%m/%d/%y",
         "2": "%y.%m.%d",
@@ -253,8 +253,8 @@ class TSQL(Dialect):
         "120": "%Y-%m-%d %H:%M:%S",
         "121": "%Y-%m-%d %H:%M:%S.%f",
     }
-    # not sure if complete
-    format_time_mapping = {
+
+    FORMAT_TIME_MAPPING = {
         "y": "%B %Y",
         "d": "%m/%d/%Y",
         "H": "%-H",
@@ -312,9 +312,7 @@ class TSQL(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "CHARINDEX": lambda args: exp.StrPosition(
-                this=seq_get(args, 1),
-                substr=seq_get(args, 0),
-                position=seq_get(args, 2),
+                this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
             ),
             "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
             "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@@ -363,6 +361,8 @@ class TSQL(Dialect):
         LOG_BASE_FIRST = False
         LOG_DEFAULTS_TO_LN = True

+        CONCAT_NULL_OUTPUTS_STRING = True
+
         def _parse_system_time(self) -> t.Optional[exp.Expression]:
             if not self._match_text_seq("FOR", "SYSTEM_TIME"):
                 return None
@@ -400,7 +400,7 @@ class TSQL(Dialect):
                 table.set("system_time", self._parse_system_time())
             return table

-        def _parse_returns(self) -> exp.Expression:
+        def _parse_returns(self) -> exp.ReturnsProperty:
             table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
             returns = super()._parse_returns()
             returns.set("table", table)
@@ -423,12 +423,12 @@ class TSQL(Dialect):
                     format_val = self._parse_number()
                     format_val_name = format_val.name if format_val else ""

-                    if format_val_name not in TSQL.convert_format_mapping:
+                    if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING:
                         raise ValueError(
                             f"CONVERT function at T-SQL does not support format style {format_val_name}"
                         )
-                    format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name])
+                    format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name])

                     # Check whether the convert entails a string to date format
                     if to.this == DataType.Type.DATE:
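For BigQuery, the commit replaces the old hand-rolled PARSE_TIMESTAMP lambda with format_time_lambda and adds PARSE_DATE/exp.StrToDate support, so date-parsing calls now flow through the normalized %-style time format. A quick sanity check of the sort this enables, using sqlglot's public transpile API (the output strings in comments are indicative of the post-commit behavior, not verified here):

import sqlglot

# BigQuery -> BigQuery should round-trip PARSE_DATE unchanged, since the
# parser now builds exp.StrToDate and the generator maps it straight back.
print(sqlglot.transpile(
    "SELECT PARSE_DATE('%m/%d/%Y', '12/25/2023')",
    read="bigquery", write="bigquery",
)[0])

# BigQuery -> Presto: exp.StrToTime is rendered with Presto's DATE_PARSE.
print(sqlglot.transpile(
    "SELECT PARSE_TIMESTAMP('%Y-%m-%d', '2023-12-25')",
    read="bigquery", write="presto",
)[0])
# expected along the lines of: SELECT DATE_PARSE('2023-12-25', '%Y-%m-%d')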
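The other recurring thread is concatenation semantics: the new concat_to_dpipe_sql helper (wired into Redshift and SQLite for both exp.Concat and exp.SafeConcat) folds a variadic CONCAT into nested || operators, while STRICT_STRING_CONCAT and CONCAT_NULL_OUTPUTS_STRING control how each dialect parses and types concatenation. A rough illustration, again with indicative output:

import sqlglot

# CONCAT(a, b, c) folded into the double-pipe operator for SQLite.
print(sqlglot.transpile("SELECT CONCAT(a, b, c)", read="postgres", write="sqlite")[0])
# expected along the lines of: SELECT a || b || c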