diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 12 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 23 | ||||
-rw-r--r-- | sqlglot/dialects/drill.py | 6 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 23 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 14 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 26 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 23 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/starrocks.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 20 |
14 files changed, 116 insertions, 80 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index df9065f..71977dd 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -34,7 +34,7 @@ def _date_add_sql( this = self.sql(expression, "this") unit = expression.args.get("unit") unit = exp.var(unit.name.upper() if unit else "DAY") - interval = exp.Interval(this=expression.expression, unit=unit) + interval = exp.Interval(this=expression.expression.copy(), unit=unit) return f"{data_type}_{kind}({this}, {self.sql(interval)})" return func @@ -76,16 +76,12 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope def _create_sql(self: generator.Generator, expression: exp.Create) -> str: kind = expression.args["kind"] returns = expression.find(exp.ReturnsProperty) + if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"): expression = expression.copy() expression.set("kind", "TABLE FUNCTION") - if isinstance( - expression.expression, - ( - exp.Subquery, - exp.Literal, - ), - ): + + if isinstance(expression.expression, (exp.Subquery, exp.Literal)): expression.set("expression", expression.expression.this) return self.create_sql(expression) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index ce1a486..e6b7743 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -77,7 +77,7 @@ class ClickHouse(Dialect): FUNCTION_PARSERS.pop("MATCH") NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy() - NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY) + NO_PAREN_FUNCTION_PARSERS.pop("ANY") RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, @@ -355,6 +355,7 @@ class ClickHouse(Dialect): def safeconcat_sql(self, expression: exp.SafeConcat) -> str: # Clickhouse errors out if we try to cast a NULL value to TEXT + expression = expression.copy() return self.func( "CONCAT", *[ @@ -389,11 +390,7 @@ class ClickHouse(Dialect): def oncluster_sql(self, expression: exp.OnCluster) -> str: return f"ON CLUSTER {self.sql(expression, 'this')}" - def createable_sql( - self, - expression: exp.Create, - locations: dict[exp.Properties.Location, list[exp.Property]], - ) -> str: + def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: kind = self.sql(expression, "kind").upper() if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME): this_name = self.sql(expression.this, "this") @@ -402,4 +399,5 @@ class ClickHouse(Dialect): ) this_schema = self.schema_columns_sql(expression.this) return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}" + return super().createable_sql(expression, locations) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 05e81ce..1d0584c 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -346,7 +346,9 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str: def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: return self.like_sql( - exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) + exp.Like( + this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy() + ) ) @@ -410,7 +412,7 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: this = self.sql(expression, "this") - struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True)) + struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True)) return f"{this}.{struct_key}" @@ -571,6 +573,17 @@ def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: return self.sql(exp.cast(expression.this, "date")) +# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 +def encode_decode_sql( + self: Generator, expression: exp.Expression, name: str, replace: bool = True +) -> str: + charset = expression.args.get("charset") + if charset and charset.name.lower() != "utf-8": + self.unsupported(f"Expected utf-8 character set, got {charset}.") + + return self.func(name, expression.this, expression.args.get("replace") if replace else None) + + def min_or_least(self: Generator, expression: exp.Min) -> str: name = "LEAST" if expression.expressions else "MIN" return rename_func(name)(self, expression) @@ -588,7 +601,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: cond = expression.this.expressions[0] self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") - return self.func("sum", exp.func("if", cond, 1, 0)) + return self.func("sum", exp.func("if", cond.copy(), 1, 0)) def trim_sql(self: Generator, expression: exp.Trim) -> str: @@ -625,6 +638,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: + expression = expression.copy() this, *rest_args = expression.expressions for arg in rest_args: this = exp.DPipe(this=this, expression=arg) @@ -674,11 +688,10 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp return names -def simplify_literal(expression: E, copy: bool = True) -> E: +def simplify_literal(expression: E) -> E: if not isinstance(expression.expression, exp.Literal): from sqlglot.optimizer.simplify import simplify - expression = exp.maybe_copy(expression, copy) simplify(expression.expression) return expression diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 26d09ce..1b2681d 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -20,9 +20,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = exp.var(expression.text("unit").upper() or "DAY") - return ( - f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" - ) + return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" return func @@ -145,7 +143,7 @@ class Drill(Dialect): exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})", + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression.copy(), unit=exp.var('DAY')))})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 219b1aa..5428e86 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import ( binary_from_function, date_trunc_to_time, datestrtodate_sql, + encode_decode_sql, format_time_lambda, no_comment_column_constraint_sql, no_properties_sql, @@ -32,14 +33,14 @@ from sqlglot.tokens import TokenType def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" - return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" + return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}" def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" op = "+" if isinstance(expression, exp.DateAdd) else "-" - return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}" + return f"{this} {op} {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}" # BigQuery -> DuckDB conversion for the DATE function @@ -167,6 +168,16 @@ class DuckDB(Dialect): "XOR": binary_from_function(exp.BitwiseXor), } + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, + "DECODE": lambda self: self.expression( + exp.Decode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8") + ), + "ENCODE": lambda self: self.expression( + exp.Encode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8") + ), + } + TYPE_TOKENS = { *parser.Parser.TYPE_TOKENS, TokenType.UBIGINT, @@ -215,7 +226,9 @@ class DuckDB(Dialect): ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)", + exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False), exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)", + exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False), exp.Explode: rename_func("UNNEST"), exp.IntDiv: lambda self, e: self.binary(e, "//"), exp.JSONExtract: arrow_json_extract_sql, @@ -228,8 +241,8 @@ class DuckDB(Dialect): exp.MonthsBetween: lambda self, e: self.func( "DATEDIFF", "'month'", - exp.cast(e.expression, "timestamp"), - exp.cast(e.this, "timestamp"), + exp.cast(e.expression, "timestamp", copy=True), + exp.cast(e.this, "timestamp", copy=True), ), exp.Properties: no_properties_sql, exp.RegexpExtract: regexp_extract_sql, @@ -290,7 +303,7 @@ class DuckDB(Dialect): multiplier = 90 if multiplier: - return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})" + return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this.copy(), unit=exp.var('day')))})" return super().interval_sql(expression) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 4e84085..aa4d845 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -59,7 +59,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS if expression.expression.is_number: modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier) else: - modified_increment = expression.expression + modified_increment = expression.expression.copy() if multiplier != 1: modified_increment = exp.Mul( # type: ignore this=modified_increment, expression=exp.Literal.number(multiplier) @@ -272,8 +272,8 @@ class Hive(Dialect): "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, + NO_PAREN_FUNCTION_PARSERS = { + **parser.Parser.NO_PAREN_FUNCTION_PARSERS, "TRANSFORM": lambda self: self._parse_transform(), } @@ -284,10 +284,12 @@ class Hive(Dialect): ), } - def _parse_transform(self) -> exp.Transform | exp.QueryTransform: - args = self._parse_csv(self._parse_lambda) - self._match_r_paren() + def _parse_transform(self) -> t.Optional[exp.Transform | exp.QueryTransform]: + if not self._match(TokenType.L_PAREN, advance=False): + self._retreat(self._index - 1) + return None + args = self._parse_wrapped_csv(self._parse_lambda) row_format_before = self._parse_row_format(match_row=True) record_writer = None diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index a54f076..3cd99e7 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -87,9 +87,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" - return ( - f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" - ) + return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" return func @@ -522,7 +520,7 @@ class MySQL(Dialect): exp.StrToTime: _str_to_date_sql, exp.TableSample: no_tablesample_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), - exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime")), + exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)), exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, @@ -556,12 +554,12 @@ class MySQL(Dialect): def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: # MySQL requires simple literal values for its LIMIT clause. - expression = simplify_literal(expression) + expression = simplify_literal(expression.copy()) return super().limit_sql(expression, top=top) def offset_sql(self, expression: exp.Offset) -> str: # MySQL requires simple literal values for its OFFSET clause. - expression = simplify_literal(expression) + expression = simplify_literal(expression.copy()) return super().offset_sql(expression) def xor_sql(self, expression: exp.Xor) -> str: diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index ef100b1..ca44b70 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -40,10 +40,12 @@ DATE_DIFF_FACTOR = { def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: + expression = expression.copy() + this = self.sql(expression, "this") unit = expression.args.get("unit") - expression = simplify_literal(expression.copy(), copy=False).expression + expression = simplify_literal(expression).expression if not isinstance(expression, exp.Literal): self.unsupported("Cannot add non literal") diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 14ec3dd..291b478 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import ( Dialect, binary_from_function, date_trunc_to_time, + encode_decode_sql, format_time_lambda, if_sql, left_to_substring_sql, @@ -21,7 +22,6 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, ) from sqlglot.dialects.mysql import MySQL -from sqlglot.errors import UnsupportedError from sqlglot.helper import apply_index_offset, seq_get from sqlglot.tokens import TokenType @@ -41,6 +41,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, (exp.Explode, exp.Posexplode)): + expression = expression.copy() return self.sql( exp.Join( this=exp.Unnest( @@ -59,16 +60,6 @@ def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str: return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" -def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str: - _ensure_utf8(expression.args["charset"]) - return self.func("FROM_UTF8", expression.this, expression.args.get("replace")) - - -def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str: - _ensure_utf8(expression.args["charset"]) - return f"TO_UTF8({self.sql(expression, 'this')})" - - def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" @@ -106,14 +97,14 @@ def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDat time_format = self.format_time(expression) if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto") - return exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE").sql(dialect="presto") + return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto") def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: this = expression.this if not isinstance(this, exp.CurrentDate): - this = exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE") + this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE") return self.func( "DATE_ADD", @@ -123,11 +114,6 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s ) -def _ensure_utf8(charset: exp.Literal) -> None: - if charset.name.lower() != "utf-8": - raise UnsupportedError(f"Unsupported charset {charset}") - - def _approx_percentile(args: t.List) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( @@ -288,9 +274,9 @@ class Presto(Dialect): ), exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", - exp.Decode: _decode_sql, + exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"), exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", - exp.Encode: _encode_sql, + exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index f687ba7..cdb8d0d 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -3,7 +3,11 @@ from __future__ import annotations import typing as t from sqlglot import exp, transforms -from sqlglot.dialects.dialect import concat_to_dpipe_sql, rename_func +from sqlglot.dialects.dialect import ( + concat_to_dpipe_sql, + rename_func, + ts_or_ds_to_date_sql, +) from sqlglot.dialects.postgres import Postgres from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -13,6 +17,14 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx return f'{self.sql(expression, "this")}."{expression.expression.name}"' +def _parse_date_add(args: t.List) -> exp.DateAdd: + return exp.DateAdd( + this=exp.TsOrDsToDate(this=seq_get(args, 2)), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ) + + class Redshift(Postgres): # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html RESOLVES_IDENTIFIERS_AS_UPPERCASE = None @@ -32,11 +44,8 @@ class Redshift(Postgres): expression=seq_get(args, 1), unit=exp.var("month"), ), - "DATEADD": lambda args: exp.DateAdd( - this=exp.TsOrDsToDate(this=seq_get(args, 2)), - expression=seq_get(args, 1), - unit=seq_get(args, 0), - ), + "DATEADD": _parse_date_add, + "DATE_ADD": _parse_date_add, "DATEDIFF": lambda args: exp.DateDiff( this=exp.TsOrDsToDate(this=seq_get(args, 2)), expression=exp.TsOrDsToDate(this=seq_get(args, 1)), @@ -123,7 +132,7 @@ class Redshift(Postgres): exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", - exp.TsOrDsToDate: lambda self, e: self.sql(e.this), + exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"), } # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 499e085..9733a85 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -297,9 +297,10 @@ class Snowflake(Dialect): return super()._parse_id_var(any_token=any_token, tokens=tokens) class Tokenizer(tokens.Tokenizer): - QUOTES = ["'", "$$"] + QUOTES = ["'"] STRING_ESCAPES = ["\\", "'"] HEX_STRINGS = [("x'", "'"), ("X'", "'")] + RAW_STRINGS = ["$$"] COMMENTS = ["--", "//", ("/*", "*/")] KEYWORDS = { @@ -363,6 +364,7 @@ class Snowflake(Dialect): exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StarMap: rename_func("OBJECT_CONSTRUCT"), + exp.StartsWith: rename_func("STARTSWITH"), exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), @@ -382,6 +384,7 @@ class Snowflake(Dialect): exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.WeekOfYear: rename_func("WEEKOFYEAR"), } TYPE_MAPPING = { diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index baa62e8..4f6183c 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -17,6 +17,13 @@ class StarRocks(MySQL): "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), + "DATEDIFF": lambda args: exp.DateDiff( + this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") + ), + "DATE_DIFF": lambda args: exp.DateDiff( + this=seq_get(args, 1), expression=seq_get(args, 2), unit=seq_get(args, 0) + ), + "REGEXP": exp.RegexpLike.from_arg_list, } class Generator(MySQL.Generator): @@ -32,9 +39,11 @@ class StarRocks(MySQL): TRANSFORMS = { **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.this, e.expression + ), exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, - exp.DateDiff: rename_func("DATEDIFF"), exp.RegexpLike: rename_func("REGEXP"), exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimestampTrunc: lambda self, e: self.func( diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 3fac4f5..2be1a62 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least from sqlglot.tokens import TokenType @@ -194,11 +196,7 @@ class Teradata(Dialect): return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})" - def createable_sql( - self, - expression: exp.Create, - locations: dict[exp.Properties.Location, list[exp.Property]], - ) -> str: + def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: kind = self.sql(expression, "kind").upper() if kind == "TABLE" and locations.get(exp.Properties.Location.POST_NAME): this_name = self.sql(expression.this, "this") @@ -209,4 +207,5 @@ class Teradata(Dialect): ) this_schema = self.schema_columns_sql(expression.this) return f"{this_name}{this_properties}{self.sep()}{this_schema}" + return super().createable_sql(expression, locations) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 0eb0906..131307f 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -395,6 +395,20 @@ class TSQL(Dialect): CONCAT_NULL_OUTPUTS_STRING = True + def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]: + """ + T-SQL supports the syntax alias = expression in the SELECT's projection list, + so we transform all parsed Selects to convert their EQ projections into Aliases. + + See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax + """ + return [ + exp.alias_(projection.expression, projection.this.this, copy=False) + if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column) + else projection + for projection in super()._parse_projections() + ] + def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: """Applies to SQL Server and Azure SQL Database COMMIT [ { TRAN | TRANSACTION } @@ -625,11 +639,7 @@ class TSQL(Dialect): LIMIT_FETCH = "FETCH" - def createable_sql( - self, - expression: exp.Create, - locations: dict[exp.Properties.Location, list[exp.Property]], - ) -> str: + def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: sql = self.sql(expression, "this") properties = expression.args.get("properties") |