diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 30 | ||||
-rw-r--r-- | sqlglot/dialects/databricks.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 22 | ||||
-rw-r--r-- | sqlglot/dialects/drill.py | 21 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 14 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 12 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 62 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 17 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 7 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 12 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 36 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 17 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 45 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 6 |
16 files changed, 242 insertions, 90 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 6a19b46..7fd9e35 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re import typing as t from sqlglot import exp, generator, parser, tokens, transforms @@ -31,13 +32,6 @@ def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]: return func -def _date_trunc(args: t.Sequence) -> exp.Expression: - unit = seq_get(args, 1) - if isinstance(unit, exp.Column): - unit = exp.Var(this=unit.name) - return exp.DateTrunc(this=seq_get(args, 0), expression=unit) - - def _date_add_sql( data_type: str, kind: str ) -> t.Callable[[generator.Generator, exp.Expression], str]: @@ -158,11 +152,23 @@ class BigQuery(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "DATE_TRUNC": _date_trunc, + "DATE_TRUNC": lambda args: exp.DateTrunc( + unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore + this=seq_get(args, 0), + ), "DATE_ADD": _date_add(exp.DateAdd), "DATETIME_ADD": _date_add(exp.DatetimeAdd), "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), + expression=seq_get(args, 1), + position=seq_get(args, 2), + occurrence=seq_get(args, 3), + group=exp.Literal.number(1) + if re.compile(str(seq_get(args, 1))).groups == 1 + else None, + ), "TIME_ADD": _date_add(exp.TimeAdd), "TIMESTAMP_ADD": _date_add(exp.TimestampAdd), "DATE_SUB": _date_add(exp.DateSub), @@ -214,6 +220,7 @@ class BigQuery(Dialect): exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.DateStrToDate: datestrtodate_sql, + exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")), exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), @@ -226,11 +233,12 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.VariancePop: rename_func("VAR_POP"), exp.Values: _derived_table_values_to_unnest, exp.ReturnsProperty: _returnsproperty_sql, exp.Create: _create_sql, - exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", + exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", @@ -251,6 +259,10 @@ class BigQuery(Dialect): exp.DataType.Type.VARCHAR: "STRING", exp.DataType.Type.NVARCHAR: "STRING", } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + } EXPLICIT_UNION = True diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 2498c62..2e058e8 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -4,6 +4,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import parse_date_delta from sqlglot.dialects.spark import Spark from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql +from sqlglot.tokens import TokenType class Databricks(Spark): @@ -21,3 +22,11 @@ class Databricks(Spark): exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, } + + PARAMETER_TOKEN = "$" + + class Tokenizer(Spark.Tokenizer): + SINGLE_TOKENS = { + **Spark.Tokenizer.SINGLE_TOKENS, + "$": TokenType.PARAMETER, + } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 176a8ce..f4e8fd4 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -215,24 +215,19 @@ DialectType = t.Union[str, Dialect, t.Type[Dialect], None] def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: - def _rename(self, expression): - args = flatten(expression.args.values()) - return f"{self.normalize_func(name)}({self.format_args(*args)})" - - return _rename + return lambda self, expression: self.func(name, *flatten(expression.args.values())) def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: if expression.args.get("accuracy"): self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") - return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})" + return self.func("APPROX_COUNT_DISTINCT", expression.this) def if_sql(self: Generator, expression: exp.If) -> str: - expressions = self.format_args( - expression.this, expression.args.get("true"), expression.args.get("false") + return self.func( + "IF", expression.this, expression.args.get("true"), expression.args.get("false") ) - return f"IF({expressions})" def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: @@ -318,13 +313,13 @@ def var_map_sql( if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): self.unsupported("Cannot convert array columns into map.") - return f"{map_func_name}({self.format_args(keys, values)})" + return self.func(map_func_name, keys, values) args = [] for key, value in zip(keys.expressions, values.expressions): args.append(self.sql(key)) args.append(self.sql(value)) - return f"{map_func_name}({self.format_args(*args)})" + return self.func(map_func_name, *args) def format_time_lambda( @@ -400,10 +395,9 @@ def locate_to_strposition(args: t.Sequence) -> exp.Expression: def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: - args = self.format_args( - expression.args.get("substr"), expression.this, expression.args.get("position") + return self.func( + "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") ) - return f"LOCATE({args})" def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 1730eaf..e9c42e1 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -39,23 +39,6 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e return func -def if_sql(self: generator.Generator, expression: exp.If) -> str: - """ - Drill requires backticks around certain SQL reserved words, IF being one of them, This function - adds the backticks around the keyword IF. - Args: - self: The Drill dialect - expression: The input IF expression - - Returns: The expression with IF in backticks. - - """ - expressions = self.format_args( - expression.this, expression.args.get("true"), expression.args.get("false") - ) - return f"`IF`({expressions})" - - def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) @@ -134,7 +117,7 @@ class Drill(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, } TRANSFORMS = { @@ -148,7 +131,7 @@ class Drill(Dialect): exp.DateSub: _date_add_sql("SUB"), exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})", - exp.If: if_sql, + exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})", exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 959e5e2..cfec9a4 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -73,11 +73,24 @@ def _datatype_sql(self, expression): return self.datatype_sql(expression) +def _regexp_extract_sql(self, expression): + bad_args = list(filter(expression.args.get, ("position", "occurrence"))) + if bad_args: + self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}") + return self.func( + "REGEXP_EXTRACT", + expression.args.get("this"), + expression.args.get("expression"), + expression.args.get("group"), + ) + + class DuckDB(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, ":=": TokenType.EQ, + "ATTACH": TokenType.COMMAND, "CHARACTER VARYING": TokenType.VARCHAR, } @@ -117,7 +130,7 @@ class DuckDB(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0]) if isinstance(seq_get(e.expressions, 0), exp.Select) else rename_func("LIST_VALUE")(self, e), exp.ArraySize: rename_func("ARRAY_LENGTH"), @@ -125,7 +138,9 @@ class DuckDB(Dialect): exp.ArraySum: rename_func("LIST_SUM"), exp.DataType: _datatype_sql, exp.DateAdd: _date_add, - exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""", + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this + ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", @@ -137,6 +152,7 @@ class DuckDB(Dialect): exp.LogicalOr: rename_func("BOOL_OR"), exp.Pivot: no_pivot_sql, exp.Properties: no_properties_sql, + exp.RegexpExtract: _regexp_extract_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index c558b70..ea1191e 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -43,7 +43,7 @@ def _add_date_sql(self, expression): else expression.expression ) modified_increment = exp.Literal.number(modified_increment) - return f"{func}({self.format_args(expression.this, modified_increment.this)})" + return self.func(func, expression.this, modified_increment.this) def _date_diff_sql(self, expression): @@ -66,7 +66,7 @@ def _property_sql(self, expression): def _str_to_unix(self, expression): - return f"UNIX_TIMESTAMP({self.format_args(expression.this, _time_format(self, expression))})" + return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression)) def _str_to_date(self, expression): @@ -312,7 +312,9 @@ class Hive(Dialect): exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.TsOrDsToDate: _to_date_sql, exp.TryCast: no_trycast_sql, - exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})", + exp.UnixToStr: lambda self, e: self.func( + "FROM_UNIXTIME", e.this, _time_format(self, e) + ), exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", @@ -324,9 +326,9 @@ class Hive(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore - exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, } def with_properties(self, properties): diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index c2c2c8c..235eb77 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + rename_func, strposition_to_locate_sql, ) from sqlglot.helper import seq_get @@ -22,9 +23,8 @@ def _show_parser(*args, **kwargs): def _date_trunc_sql(self, expression): - unit = expression.name.lower() - - expr = self.sql(expression.expression) + expr = self.sql(expression, "this") + unit = expression.text("unit") if unit == "day": return f"DATE({expr})" @@ -42,7 +42,7 @@ def _date_trunc_sql(self, expression): concat = f"CONCAT(YEAR({expr}), ' 1 1')" date_format = "%Y %c %e" else: - self.unsupported("Unexpected interval unit: {unit}") + self.unsupported(f"Unexpected interval unit: {unit}") return f"DATE({expr})" return f"STR_TO_DATE({concat}, '{date_format}')" @@ -443,6 +443,10 @@ class MySQL(Dialect): exp.DateAdd: _date_add_sql("ADD"), exp.DateSub: _date_add_sql("SUB"), exp.DateTrunc: _date_trunc_sql, + exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.DayOfMonth: rename_func("DAYOFMONTH"), + exp.DayOfYear: rename_func("DAYOFYEAR"), + exp.WeekOfYear: rename_func("WEEKOFYEAR"), exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index fde845e..74baa8a 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -1,15 +1,49 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql from sqlglot.helper import csv from sqlglot.tokens import TokenType +PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { + TokenType.COLUMN, + TokenType.RETURNING, +} + def _limit_sql(self, expression): return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression)) +def _parse_xml_table(self) -> exp.XMLTable: + this = self._parse_string() + + passing = None + columns = None + + if self._match_text_seq("PASSING"): + # The BY VALUE keywords are optional and are provided for semantic clarity + self._match_text_seq("BY", "VALUE") + passing = self._parse_csv( + lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS) + ) + + by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") + + if self._match_text_seq("COLUMNS"): + columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True))) + + return self.expression( + exp.XMLTable, + this=this, + passing=passing, + columns=columns, + by_ref=by_ref, + ) + + class Oracle(Dialect): # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes @@ -43,6 +77,11 @@ class Oracle(Dialect): "DECODE": exp.Matches.from_arg_list, } + FUNCTION_PARSERS: t.Dict[str, t.Callable] = { + **parser.Parser.FUNCTION_PARSERS, + "XMLTABLE": _parse_xml_table, + } + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True @@ -74,7 +113,7 @@ class Oracle(Dialect): exp.Substring: rename_func("SUBSTR"), } - def query_modifiers(self, expression, *sqls): + def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, *[self.sql(sql) for sql in expression.args.get("joins") or []], @@ -97,19 +136,32 @@ class Oracle(Dialect): sep="", ) - def offset_sql(self, expression): + def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" - def table_sql(self, expression): - return super().table_sql(expression, sep=" ") + def table_sql(self, expression: exp.Table, sep: str = " ") -> str: + return super().table_sql(expression, sep=sep) + + def xmltable_sql(self, expression: exp.XMLTable) -> str: + this = self.sql(expression, "this") + passing = self.expressions(expression, "passing") + passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else "" + columns = self.expressions(expression, "columns") + columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else "" + by_ref = ( + f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else "" + ) + return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}" class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "COLUMNS": TokenType.COLUMN, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, + "NVARCHAR2": TokenType.NVARCHAR, + "RETURNING": TokenType.RETURNING, "START": TokenType.BEGIN, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, - "NVARCHAR2": TokenType.NVARCHAR, } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index c709665..7612330 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -58,17 +58,17 @@ def _date_diff_sql(self, expression): age = f"AGE({end}, {start})" if unit == "WEEK": - extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7" + unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7" elif unit == "MONTH": - extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" + unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" elif unit == "QUARTER": - extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3" + unit = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3" elif unit == "YEAR": - extract = f"EXTRACT(year FROM {age})" + unit = f"EXTRACT(year FROM {age})" else: - self.unsupported(f"Unsupported DATEDIFF unit {unit}") + unit = age - return f"CAST({extract} AS BIGINT)" + return f"CAST({unit} AS BIGINT)" def _substring_sql(self, expression): @@ -206,6 +206,8 @@ class Postgres(Dialect): } class Tokenizer(tokens.Tokenizer): + QUOTES = ["'", "$$"] + BIT_STRINGS = [("b'", "'"), ("B'", "'")] HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] @@ -236,7 +238,7 @@ class Postgres(Dialect): "UUID": TokenType.UUID, "CSTRING": TokenType.PSEUDO_TYPE, } - QUOTES = ["'", "$$"] + SINGLE_TOKENS = { **tokens.Tokenizer.SINGLE_TOKENS, "$": TokenType.PARAMETER, @@ -265,6 +267,7 @@ class Postgres(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True + PARAMETER_TOKEN = "$" TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 6c1a474..aef9de3 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -52,7 +52,7 @@ def _initcap_sql(self, expression): def _decode_sql(self, expression): _ensure_utf8(expression.args.get("charset")) - return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})" + return self.func("FROM_UTF8", expression.this, expression.args.get("replace")) def _encode_sql(self, expression): @@ -65,8 +65,7 @@ def _no_sort_array(self, expression): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: comparator = None - args = self.format_args(expression.this, comparator) - return f"ARRAY_SORT({args})" + return self.func("ARRAY_SORT", expression.this, comparator) def _schema_sql(self, expression): @@ -125,7 +124,7 @@ def _sequence_sql(self, expression): else: start = exp.Cast(this=start, to=to) - return f"SEQUENCE({self.format_args(start, end, step)})" + return self.func("SEQUENCE", start, end, step) def _ensure_utf8(charset): diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 813ee5f..b4268e6 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, transforms from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.postgres import Postgres +from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -19,6 +20,11 @@ class Redshift(Postgres): class Parser(Postgres.Parser): FUNCTIONS = { **Postgres.Parser.FUNCTIONS, # type: ignore + "DATEDIFF": lambda args: exp.DateDiff( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ), "DECODE": exp.Matches.from_arg_list, "NVL": exp.Coalesce.from_arg_list, } @@ -41,7 +47,6 @@ class Redshift(Postgres): KEYWORDS = { **Postgres.Tokenizer.KEYWORDS, # type: ignore - "ENCODE": TokenType.ENCODE, "GEOMETRY": TokenType.GEOMETRY, "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, @@ -62,12 +67,15 @@ class Redshift(Postgres): PROPERTIES_LOCATION = { **Postgres.Generator.PROPERTIES_LOCATION, # type: ignore - exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.LikeProperty: exp.Properties.Location.POST_WITH, } TRANSFORMS = { **Postgres.Generator.TRANSFORMS, # type: ignore **transforms.ELIMINATE_DISTINCT_ON, # type: ignore + exp.DateDiff: lambda self, e: self.func( + "DATEDIFF", e.args.get("unit") or "day", e.expression, e.this + ), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 55a6bd3..bb46135 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -178,18 +178,25 @@ class Snowflake(Dialect): ), } + RANGE_PARSERS = { + **parser.Parser.RANGE_PARSERS, # type: ignore + TokenType.LIKE_ANY: lambda self, this: self._parse_escape( + self.expression(exp.LikeAny, this=this, expression=self._parse_bitwise()) + ), + TokenType.ILIKE_ANY: lambda self, this: self._parse_escape( + self.expression(exp.ILikeAny, this=this, expression=self._parse_bitwise()) + ), + } + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] STRING_ESCAPES = ["\\", "'"] - SINGLE_TOKENS = { - **tokens.Tokenizer.SINGLE_TOKENS, - "$": TokenType.PARAMETER, - } - KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "EXCLUDE": TokenType.EXCEPT, + "ILIKE ANY": TokenType.ILIKE_ANY, + "LIKE ANY": TokenType.LIKE_ANY, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "PUT": TokenType.COMMAND, "RENAME": TokenType.REPLACE, @@ -201,8 +208,14 @@ class Snowflake(Dialect): "SAMPLE": TokenType.TABLE_SAMPLE, } + SINGLE_TOKENS = { + **tokens.Tokenizer.SINGLE_TOKENS, + "$": TokenType.PARAMETER, + } + class Generator(generator.Generator): CREATE_TRANSIENT = True + PARAMETER_TOKEN = "$" TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -214,14 +227,15 @@ class Snowflake(Dialect): exp.If: rename_func("IFF"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), - exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Matches: rename_func("DECODE"), - exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})", + exp.StrPosition: lambda self, e: self.func( + "POSITION", e.args.get("substr"), e.this, e.args.get("position") + ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", - exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", + exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.UnixToTime: _unix_to_time_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), } @@ -236,6 +250,12 @@ class Snowflake(Dialect): "replace": "RENAME", } + def ilikeany_sql(self, expression: exp.ILikeAny) -> str: + return self.binary(expression, "ILIKE ANY") + + def likeany_sql(self, expression: exp.LikeAny) -> str: + return self.binary(expression, "LIKE ANY") + def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 03ec211..dd3e0c8 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -86,6 +86,11 @@ class Spark(Hive): "WEEKOFYEAR": lambda args: exp.WeekOfYear( this=exp.TsOrDsToDate(this=seq_get(args, 0)), ), + "DATE_TRUNC": lambda args: exp.TimestampTrunc( + this=seq_get(args, 1), + unit=exp.var(seq_get(args, 0)), + ), + "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), } FUNCTION_PARSERS = { @@ -133,7 +138,7 @@ class Spark(Hive): exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), - exp.DateTrunc: rename_func("TRUNC"), + exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")), exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", @@ -142,7 +147,9 @@ class Spark(Hive): exp.Map: _map_sql, exp.Reduce: rename_func("AGGREGATE"), exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", - exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})", + exp.TimestampTrunc: lambda self, e: self.func( + "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this + ), exp.Trim: trim_sql, exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), @@ -157,16 +164,16 @@ class Spark(Hive): TRANSFORMS.pop(exp.ILike) WRAP_DERIVED_VALUES = False - CREATE_FUNCTION_AS = False + CREATE_FUNCTION_RETURN_AS = False def cast_sql(self, expression: exp.Cast) -> str: if isinstance(expression.this, exp.Cast) and expression.this.is_type( exp.DataType.Type.JSON ): schema = f"'{self.sql(expression, 'to')}'" - return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})" + return self.func("FROM_JSON", expression.this.this, schema) if expression.to.is_type(exp.DataType.Type.JSON): - return f"TO_JSON({self.sql(expression, 'this')})" + return self.func("TO_JSON", expression.this) return super(Spark.Generator, self).cast_sql(expression) diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index a428dd5..86603b5 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -39,7 +39,7 @@ def _date_add_sql(self, expression): modifier = expression.name if modifier.is_string else self.sql(modifier) unit = expression.args.get("unit") modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'" - return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})" + return self.func("DATE", expression.this, modifier) class SQLite(Dialect): diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 123da04..e3eec71 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,11 +1,33 @@ from __future__ import annotations -from sqlglot import exp, generator, parser +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect from sqlglot.tokens import TokenType class Teradata(Dialect): + class Tokenizer(tokens.Tokenizer): + # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance + KEYWORDS = { + **tokens.Tokenizer.KEYWORDS, + "BYTEINT": TokenType.SMALLINT, + "SEL": TokenType.SELECT, + "INS": TokenType.INSERT, + "MOD": TokenType.MOD, + "LT": TokenType.LT, + "LE": TokenType.LTE, + "GT": TokenType.GT, + "GE": TokenType.GTE, + "^=": TokenType.NEQ, + "NE": TokenType.NEQ, + "NOT=": TokenType.NEQ, + "ST_GEOMETRY": TokenType.GEOMETRY, + } + + # teradata does not support % for modulus + SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS} + SINGLE_TOKENS.pop("%") + class Parser(parser.Parser): CHARSET_TRANSLATORS = { "GRAPHIC_TO_KANJISJIS", @@ -42,6 +64,14 @@ class Teradata(Dialect): "UNICODE_TO_UNICODE_NFKD", } + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS} + FUNC_TOKENS.remove(TokenType.REPLACE) + + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, # type: ignore + TokenType.REPLACE: lambda self: self._parse_create(), + } + FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, # type: ignore "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST), @@ -76,6 +106,11 @@ class Teradata(Dialect): ) class Generator(generator.Generator): + TYPE_MAPPING = { + **generator.Generator.TYPE_MAPPING, # type: ignore + exp.DataType.Type.GEOMETRY: "ST_GEOMETRY", + } + PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, @@ -93,3 +128,11 @@ class Teradata(Dialect): where_sql = self.sql(expression, "where") sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}" return self.prepend_ctes(expression, sql) + + def mod_sql(self, expression: exp.Mod) -> str: + return self.binary(expression, "MOD") + + def datatype_sql(self, expression: exp.DataType) -> str: + type_sql = super().datatype_sql(expression) + prefix_sql = expression.args.get("prefix") + return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 05ba53a..b9f932b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -92,7 +92,7 @@ def _parse_eomonth(args): def generate_date_delta_with_unit_sql(self, e): func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" - return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" + return self.func(func, e.text("unit"), e.expression, e.this) def _format_sql(self, e): @@ -101,7 +101,7 @@ def _format_sql(self, e): if isinstance(e, exp.NumberToStr) else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping)) ) - return f"FORMAT({self.format_args(e.this, fmt)})" + return self.func("FORMAT", e.this, fmt) def _string_agg_sql(self, e): @@ -408,7 +408,7 @@ class TSQL(Dialect): ): return this - expressions = self._parse_csv(self._parse_udf_kwarg) + expressions = self._parse_csv(self._parse_function_parameter) return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) class Generator(generator.Generator): |