diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/databricks.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/drill.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 26 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 16 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 51 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 11 |
14 files changed, 106 insertions, 56 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 6a43846..a3f9e6d 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( Dialect, datestrtodate_sql, inline_array_sql, + max_or_greatest, min_or_least, no_ilike_sql, rename_func, @@ -212,6 +213,9 @@ class BigQuery(Dialect): ), } + LOG_BASE_FIRST = False + LOG_DEFAULTS_TO_LN = True + class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -227,6 +231,7 @@ class BigQuery(Dialect): exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.Select: transforms.preprocess( [_unqualify_unnest], transforms.delegate("select_sql") @@ -253,17 +258,19 @@ class BigQuery(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore - exp.DataType.Type.TINYINT: "INT64", - exp.DataType.Type.SMALLINT: "INT64", - exp.DataType.Type.INT: "INT64", exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.CHAR: "STRING", exp.DataType.Type.DECIMAL: "NUMERIC", - exp.DataType.Type.FLOAT: "FLOAT64", exp.DataType.Type.DOUBLE: "FLOAT64", - exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.FLOAT: "FLOAT64", + exp.DataType.Type.INT: "INT64", + exp.DataType.Type.NCHAR: "STRING", + exp.DataType.Type.NVARCHAR: "STRING", + exp.DataType.Type.SMALLINT: "INT64", exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.VARCHAR: "STRING", - exp.DataType.Type.NVARCHAR: "STRING", } PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore @@ -271,6 +278,7 @@ class BigQuery(Dialect): } EXPLICIT_UNION = True + LIMIT_FETCH = "LIMIT" def array_sql(self, expression: exp.Array) -> str: first_arg = seq_get(expression.expressions, 0) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index b54a77d..89e2296 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -68,6 +68,8 @@ class ClickHouse(Dialect): TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore + LOG_DEFAULTS_TO_LN = True + def _parse_in( self, this: t.Optional[exp.Expression], is_global: bool = False ) -> exp.Expression: diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 4268f1b..2f93ee7 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -16,6 +16,8 @@ class Databricks(Spark): "DATEDIFF": parse_date_delta(exp.DateDiff), } + LOG_DEFAULTS_TO_LN = True + class Generator(Spark.Generator): TRANSFORMS = { **Spark.Generator.TRANSFORMS, # type: ignore diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index b267521..839589d 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -430,6 +430,11 @@ def min_or_least(self: Generator, expression: exp.Min) -> str: return rename_func(name)(self, expression) +def max_or_greatest(self: Generator, expression: exp.Max) -> str: + name = "GREATEST" if expression.expressions else "MAX" + return rename_func(name)(self, expression) + + def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: cond = expression.this diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index dc0e519..a33aadc 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re import typing as t from sqlglot import exp, generator, parser, tokens @@ -102,6 +101,8 @@ class Drill(Dialect): "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), } + LOG_DEFAULTS_TO_LN = True + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -154,4 +155,4 @@ class Drill(Dialect): } def normalize_func(self, name: str) -> str: - return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`" + return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`" diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index f1d2266..c034208 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -80,6 +80,7 @@ class DuckDB(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "~": TokenType.RLIKE, ":=": TokenType.EQ, "ATTACH": TokenType.COMMAND, "BINARY": TokenType.VARBINARY, @@ -212,5 +213,7 @@ class DuckDB(Dialect): "except": "EXCLUDE", } + LIMIT_FETCH = "LIMIT" + def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str: return super().tablesample_sql(expression, seed_prefix="REPEATABLE") diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 0110eee..68137ae 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, locate_to_strposition, + max_or_greatest, min_or_least, no_ilike_sql, no_recursive_cte_sql, @@ -34,6 +35,13 @@ DATE_DELTA_INTERVAL = { "DAY": ("DATE_ADD", 1), } +TIME_DIFF_FACTOR = { + "MILLISECOND": " * 1000", + "SECOND": "", + "MINUTE": " / 60", + "HOUR": " / 3600", +} + DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") @@ -51,6 +59,14 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str: def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() + + factor = TIME_DIFF_FACTOR.get(unit) + if factor is not None: + left = self.sql(expression, "this") + right = self.sql(expression, "expression") + sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})" + return f"({sec_diff}){factor}" if factor else sec_diff + sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" @@ -237,11 +253,6 @@ class Hive(Dialect): "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, "LOCATE": locate_to_strposition, - "LOG": ( - lambda args: exp.Log.from_arg_list(args) - if len(args) > 1 - else exp.Ln.from_arg_list(args) - ), "MAP": parse_var_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, @@ -261,6 +272,8 @@ class Hive(Dialect): ), } + LOG_DEFAULTS_TO_LN = True + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -293,6 +306,7 @@ class Hive(Dialect): exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.Map: var_map_sql, + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, @@ -338,6 +352,8 @@ class Hive(Dialect): exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, } + LIMIT_FETCH = "LIMIT" + def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: return self.func( "COLLECT_LIST", diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 1e2cfa3..5dfa811 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -3,7 +3,9 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + arrow_json_extract_scalar_sql, locate_to_strposition, + max_or_greatest, min_or_least, no_ilike_sql, no_paren_current_date_sql, @@ -288,6 +290,8 @@ class MySQL(Dialect): "SWAPS", } + LOG_DEFAULTS_TO_LN = True + def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): @@ -303,7 +307,13 @@ class MySQL(Dialect): db = None else: position = None - db = self._parse_id_var() if self._match_text_seq("FROM") else None + db = None + + if self._match(TokenType.FROM): + db = self._parse_id_var() + elif self._match(TokenType.DOT): + db = target_id + target_id = self._parse_id_var() channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None @@ -384,6 +394,8 @@ class MySQL(Dialect): exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ILike: no_ilike_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, @@ -415,6 +427,8 @@ class MySQL(Dialect): exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, } + LIMIT_FETCH = "LIMIT" + def show_sql(self, expression): this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 7028a04..fad6c4a 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -4,7 +4,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql -from sqlglot.helper import csv, seq_get +from sqlglot.helper import seq_get from sqlglot.tokens import TokenType PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { @@ -13,10 +13,6 @@ PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { } -def _limit_sql(self, expression): - return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression)) - - def _parse_xml_table(self) -> exp.XMLTable: this = self._parse_string() @@ -89,6 +85,20 @@ class Oracle(Dialect): column.set("join_mark", self._match(TokenType.JOIN_MARKER)) return column + def _parse_hint(self) -> t.Optional[exp.Expression]: + if self._match(TokenType.HINT): + start = self._curr + while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH): + self._advance() + + if not self._curr: + self.raise_error("Expected */ after HINT") + + end = self._tokens[self._index - 3] + return exp.Hint(expressions=[self._find_sql(start, end)]) + + return None + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True @@ -110,41 +120,20 @@ class Oracle(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore + exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, - exp.Limit: _limit_sql, - exp.Trim: trim_sql, exp.Matches: rename_func("DECODE"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), + exp.Substring: rename_func("SUBSTR"), exp.Table: lambda self, e: self.table_sql(e, sep=" "), exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", - exp.Substring: rename_func("SUBSTR"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), + exp.Trim: trim_sql, + exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", } - def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: - return csv( - *sqls, - *[self.sql(sql) for sql in expression.args.get("joins") or []], - self.sql(expression, "match"), - *[self.sql(sql) for sql in expression.args.get("laterals") or []], - self.sql(expression, "where"), - self.sql(expression, "group"), - self.sql(expression, "having"), - self.sql(expression, "qualify"), - self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True) - if expression.args.get("windows") - else "", - self.sql(expression, "distribute"), - self.sql(expression, "sort"), - self.sql(expression, "cluster"), - self.sql(expression, "order"), - self.sql(expression, "offset"), # offset before limit in oracle - self.sql(expression, "limit"), - self.sql(expression, "lock"), - sep="", - ) + LIMIT_FETCH = "FETCH" def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 5f556a5..31b7e45 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_scalar_sql, arrow_json_extract_sql, format_time_lambda, + max_or_greatest, min_or_least, no_paren_current_date_sql, no_tablesample_sql, @@ -315,6 +316,7 @@ class Postgres(Dialect): exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 799e9a6..c50961c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, inline_array_sql, + max_or_greatest, min_or_least, rename_func, timestamptrunc_sql, @@ -275,6 +276,9 @@ class Snowflake(Dialect): exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this), + exp.DateDiff: lambda self, e: self.func( + "DATEDIFF", e.text("unit"), e.expression, e.this + ), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), @@ -296,6 +300,7 @@ class Snowflake(Dialect): exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.Max: max_or_greatest, exp.Min: min_or_least, } @@ -314,12 +319,6 @@ class Snowflake(Dialect): exp.SetProperty: exp.Properties.Location.UNSUPPORTED, } - def ilikeany_sql(self, expression: exp.ILikeAny) -> str: - return self.binary(expression, "ILIKE ANY") - - def likeany_sql(self, expression: exp.LikeAny) -> str: - return self.binary(expression, "LIKE ANY") - def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index ab78b6e..4091dbb 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -82,6 +82,8 @@ class SQLite(Dialect): exp.TryCast: no_trycast_sql, } + LIMIT_FETCH = "LIMIT" + def cast_sql(self, expression: exp.Cast) -> str: if expression.to.this == exp.DataType.Type.DATE: return self.func("DATE", expression.this) @@ -115,9 +117,6 @@ class SQLite(Dialect): return f"CAST({sql} AS INTEGER)" - def fetch_sql(self, expression: exp.Fetch) -> str: - return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) - # https://www.sqlite.org/lang_aggfunc.html#group_concat def groupconcat_sql(self, expression): this = expression.this diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 8bd0a0c..3d43793 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, min_or_least +from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least from sqlglot.tokens import TokenType @@ -128,6 +128,7 @@ class Teradata(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.ToChar: lambda self, e: self.function_fallback_sql(e), } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 7b52047..8e9b6c3 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -6,6 +6,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + max_or_greatest, min_or_least, parse_date_delta, rename_func, @@ -269,7 +270,6 @@ class TSQL(Dialect): # TSQL allows @, # to appear as a variable/identifier prefix SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy() - SINGLE_TOKENS.pop("@") SINGLE_TOKENS.pop("#") class Parser(parser.Parser): @@ -313,6 +313,9 @@ class TSQL(Dialect): TokenType.END: lambda self: self._parse_command(), } + LOG_BASE_FIRST = False + LOG_DEFAULTS_TO_LN = True + def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"): return None @@ -435,11 +438,17 @@ class TSQL(Dialect): exp.NumberToStr: _format_sql, exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, + exp.Max: max_or_greatest, exp.Min: min_or_least, } TRANSFORMS.pop(exp.ReturnsProperty) + LIMIT_FETCH = "FETCH" + + def offset_sql(self, expression: exp.Offset) -> str: + return f"{super().offset_sql(expression)} ROWS" + def systemtime_sql(self, expression: exp.SystemTime) -> str: kind = expression.args["kind"] if kind == "ALL": |