Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py   | 20
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 13
-rw-r--r-- | sqlglot/dialects/dialect.py    | 21
-rw-r--r-- | sqlglot/dialects/drill.py      |  2
-rw-r--r-- | sqlglot/dialects/hive.py       | 14
-rw-r--r-- | sqlglot/dialects/mysql.py      |  1
-rw-r--r-- | sqlglot/dialects/oracle.py     | 11
-rw-r--r-- | sqlglot/dialects/postgres.py   | 48
-rw-r--r-- | sqlglot/dialects/presto.py     | 14
-rw-r--r-- | sqlglot/dialects/redshift.py   | 18
-rw-r--r-- | sqlglot/dialects/snowflake.py  | 25
-rw-r--r-- | sqlglot/dialects/spark.py      |  2
-rw-r--r-- | sqlglot/dialects/sqlite.py     | 18
-rw-r--r-- | sqlglot/dialects/tsql.py       | 41
14 files changed, 175 insertions, 73 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 4550d65..5b44912 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -56,12 +56,12 @@ def _derived_table_values_to_unnest(self, expression):
 
 
 def _returnsproperty_sql(self, expression):
-    value = expression.args.get("value")
-    if isinstance(value, exp.Schema):
-        value = f"{value.this} <{self.expressions(value)}>"
+    this = expression.this
+    if isinstance(this, exp.Schema):
+        this = f"{this.this} <{self.expressions(this)}>"
     else:
-        value = self.sql(value)
-    return f"RETURNS {value}"
+        this = self.sql(this)
+    return f"RETURNS {this}"
 
 
 def _create_sql(self, expression):
@@ -142,6 +142,11 @@ class BigQuery(Dialect):
             ),
         }
 
+        FUNCTION_PARSERS = {
+            **parser.Parser.FUNCTION_PARSERS,
+        }
+        FUNCTION_PARSERS.pop("TRIM")
+
         NO_PAREN_FUNCTIONS = {
             **parser.Parser.NO_PAREN_FUNCTIONS,
             TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
@@ -174,6 +179,7 @@ class BigQuery(Dialect):
             exp.Values: _derived_table_values_to_unnest,
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Create: _create_sql,
+            exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
             exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
             if e.name == "IMMUTABLE"
             else "NOT DETERMINISTIC",
@@ -200,9 +206,7 @@ class BigQuery(Dialect):
             exp.VolatilityProperty,
         }
 
-        WITH_PROPERTIES = {
-            exp.AnonymousProperty,
-        }
+        WITH_PROPERTIES = {exp.Property}
 
         EXPLICIT_UNION = True
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 332b4c1..cbed72e 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -21,14 +21,15 @@ class ClickHouse(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "FINAL": TokenType.FINAL,
+            "ASOF": TokenType.ASOF,
             "DATETIME64": TokenType.DATETIME,
-            "INT8": TokenType.TINYINT,
+            "FINAL": TokenType.FINAL,
+            "FLOAT32": TokenType.FLOAT,
+            "FLOAT64": TokenType.DOUBLE,
             "INT16": TokenType.SMALLINT,
             "INT32": TokenType.INT,
             "INT64": TokenType.BIGINT,
-            "FLOAT32": TokenType.FLOAT,
-            "FLOAT64": TokenType.DOUBLE,
+            "INT8": TokenType.TINYINT,
             "TUPLE": TokenType.STRUCT,
         }
@@ -38,6 +39,10 @@ class ClickHouse(Dialect):
             "MAP": parse_var_map,
         }
 
+        JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
+
+        TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
+
         def _parse_table(self, schema=False):
             this = super()._parse_table(schema)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 8c497ab..c87f8d8 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -289,19 +289,19 @@ def struct_extract_sql(self, expression):
     return f"{this}.{struct_key}"
 
 
-def var_map_sql(self, expression):
+def var_map_sql(self, expression, map_func_name="MAP"):
     keys = expression.args["keys"]
     values = expression.args["values"]
 
     if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
         self.unsupported("Cannot convert array columns into map.")
-        return f"MAP({self.format_args(keys, values)})"
+        return f"{map_func_name}({self.format_args(keys, values)})"
 
     args = []
     for key, value in zip(keys.expressions, values.expressions):
         args.append(self.sql(key))
         args.append(self.sql(value))
-    return f"MAP({self.format_args(*args)})"
+    return f"{map_func_name}({self.format_args(*args)})"
 
 
 def format_time_lambda(exp_class, dialect, default=None):
@@ -336,18 +336,13 @@ def create_with_partitions_sql(self, expression):
     if has_schema and is_partitionable:
         expression = expression.copy()
         prop = expression.find(exp.PartitionedByProperty)
-        value = prop and prop.args.get("value")
-        if prop and not isinstance(value, exp.Schema):
+        this = prop and prop.this
+        if prop and not isinstance(this, exp.Schema):
             schema = expression.this
-            columns = {v.name.upper() for v in value.expressions}
+            columns = {v.name.upper() for v in this.expressions}
             partitions = [col for col in schema.expressions if col.name.upper() in columns]
-            schema.set(
-                "expressions",
-                [e for e in schema.expressions if e not in partitions],
-            )
-            prop.replace(
-                exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
-            )
+            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
+            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
             expression.set("this", schema)
 
     return self.create_sql(expression)
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index eb420aa..358eced 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -153,7 +153,7 @@ class Drill(Dialect):
             exp.If: if_sql,
             exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
-            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.Pivot: no_pivot_sql,
             exp.RegexpLike: rename_func("REGEXP_MATCHES"),
             exp.StrPosition: str_position_sql,
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index cff7139..cbb39c2 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -61,9 +61,7 @@ def _array_sort(self, expression):
 
 
 def _property_sql(self, expression):
-    key = expression.name
-    value = self.sql(expression, "value")
-    return f"'{key}'={value}"
+    return f"'{expression.name}'={self.sql(expression, 'value')}"
 
 
 def _str_to_unix(self, expression):
@@ -250,7 +248,7 @@ class Hive(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             **transforms.UNALIAS_GROUP,  # type: ignore
-            exp.AnonymousProperty: _property_sql,
+            exp.Property: _property_sql,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.ArrayAgg: rename_func("COLLECT_LIST"),
             exp.ArrayConcat: rename_func("CONCAT"),
@@ -262,7 +260,7 @@ class Hive(Dialect):
             exp.DateStrToDate: rename_func("TO_DATE"),
             exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
-            exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}",
+            exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
             exp.If: if_sql,
             exp.Index: _index_sql,
             exp.ILike: no_ilike_sql,
@@ -285,7 +283,7 @@ class Hive(Dialect):
             exp.StrToTime: _str_to_time,
             exp.StrToUnix: _str_to_unix,
             exp.StructExtract: struct_extract_sql,
-            exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}",
+            exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
             exp.TimeStrToDate: rename_func("TO_DATE"),
             exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -298,11 +296,11 @@ class Hive(Dialect):
             exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
-            exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
+            exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
             exp.NumberToStr: rename_func("FORMAT_NUMBER"),
         }
 
-        WITH_PROPERTIES = {exp.AnonymousProperty}
+        WITH_PROPERTIES = {exp.Property}
 
         ROOT_PROPERTIES = {
             exp.PartitionedByProperty,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 93a60f4..7627b6e 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -453,6 +453,7 @@ class MySQL(Dialect):
             exp.CharacterSetProperty,
             exp.CollateProperty,
             exp.SchemaCommentProperty,
+            exp.LikeProperty,
         }
 
         WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 870d2b9..ceaf9ba 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-from sqlglot import exp, generator, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, no_ilike_sql
+from sqlglot import exp, generator, parser, tokens, transforms
+from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
 from sqlglot.helper import csv
 from sqlglot.tokens import TokenType
 
@@ -37,6 +37,12 @@ class Oracle(Dialect):
         "YYYY": "%Y",  # 2015
     }
 
+    class Parser(parser.Parser):
+        FUNCTIONS = {
+            **parser.Parser.FUNCTIONS,
+            "DECODE": exp.Matches.from_arg_list,
+        }
+
     class Generator(generator.Generator):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -58,6 +64,7 @@ class Oracle(Dialect):
             **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ILike: no_ilike_sql,
             exp.Limit: _limit_sql,
+            exp.Matches: rename_func("DECODE"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 4353164..1cb5025 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -74,6 +74,27 @@ def _trim_sql(self, expression):
     return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 
 
+def _string_agg_sql(self, expression):
+    expression = expression.copy()
+    separator = expression.args.get("separator") or exp.Literal.string(",")
+
+    order = ""
+    this = expression.this
+    if isinstance(this, exp.Order):
+        if this.this:
+            this = this.this
+            this.pop()
+        order = self.sql(expression.this)  # Order has a leading space
+
+    return f"STRING_AGG({self.format_args(this, separator)}{order})"
+
+
+def _datatype_sql(self, expression):
+    if expression.this == exp.DataType.Type.ARRAY:
+        return f"{self.expressions(expression, flat=True)}[]"
+    return self.datatype_sql(expression)
+
+
 def _auto_increment_to_serial(expression):
     auto = expression.find(exp.AutoIncrementColumnConstraint)
@@ -191,25 +212,27 @@ class Postgres(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "ALWAYS": TokenType.ALWAYS,
-            "BY DEFAULT": TokenType.BY_DEFAULT,
-            "IDENTITY": TokenType.IDENTITY,
-            "GENERATED": TokenType.GENERATED,
-            "DOUBLE PRECISION": TokenType.DOUBLE,
-            "BIGSERIAL": TokenType.BIGSERIAL,
-            "SERIAL": TokenType.SERIAL,
-            "SMALLSERIAL": TokenType.SMALLSERIAL,
-            "UUID": TokenType.UUID,
-            "TEMP": TokenType.TEMPORARY,
-            "BEGIN TRANSACTION": TokenType.BEGIN,
             "BEGIN": TokenType.COMMAND,
+            "BEGIN TRANSACTION": TokenType.BEGIN,
+            "BIGSERIAL": TokenType.BIGSERIAL,
+            "BY DEFAULT": TokenType.BY_DEFAULT,
             "COMMENT ON": TokenType.COMMAND,
             "DECLARE": TokenType.COMMAND,
             "DO": TokenType.COMMAND,
+            "DOUBLE PRECISION": TokenType.DOUBLE,
+            "GENERATED": TokenType.GENERATED,
+            "GRANT": TokenType.COMMAND,
+            "HSTORE": TokenType.HSTORE,
+            "IDENTITY": TokenType.IDENTITY,
+            "JSONB": TokenType.JSONB,
             "REFRESH": TokenType.COMMAND,
             "REINDEX": TokenType.COMMAND,
             "RESET": TokenType.COMMAND,
             "REVOKE": TokenType.COMMAND,
-            "GRANT": TokenType.COMMAND,
+            "SERIAL": TokenType.SERIAL,
+            "SMALLSERIAL": TokenType.SMALLSERIAL,
+            "TEMP": TokenType.TEMPORARY,
+            "UUID": TokenType.UUID,
             **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
             **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
         }
@@ -265,4 +288,7 @@ class Postgres(Dialect):
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
             exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
+            exp.DataType: _datatype_sql,
+            exp.GroupConcat: _string_agg_sql,
+            exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
         }
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 9d5cc11..1a09037 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -171,16 +171,7 @@ class Presto(Dialect):
 
         STRUCT_DELIMITER = ("(", ")")
 
-        ROOT_PROPERTIES = {
-            exp.SchemaCommentProperty,
-        }
-
-        WITH_PROPERTIES = {
-            exp.PartitionedByProperty,
-            exp.FileFormatProperty,
-            exp.AnonymousProperty,
-            exp.TableFormatProperty,
-        }
+        ROOT_PROPERTIES = {exp.SchemaCommentProperty}
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -231,7 +222,8 @@ class Presto(Dialect):
             exp.StrToTime: _str_to_time_sql,
             exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
             exp.StructExtract: struct_extract_sql,
-            exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'",
+            exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
+            exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
             exp.TimeStrToDate: _date_parse_sql,
             exp.TimeStrToTime: _date_parse_sql,
             exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index a9b12fb..cd50979 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from sqlglot import exp
+from sqlglot import exp, transforms
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.tokens import TokenType
 
@@ -18,12 +18,14 @@ class Redshift(Postgres):
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,  # type: ignore
+            "COPY": TokenType.COMMAND,
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "HLLSKETCH": TokenType.HLLSKETCH,
             "SUPER": TokenType.SUPER,
             "TIME": TokenType.TIMESTAMP,
             "TIMETZ": TokenType.TIMESTAMPTZ,
+            "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,
             "SIMILAR TO": TokenType.SIMILAR_TO,
         }
@@ -35,3 +37,17 @@ class Redshift(Postgres):
             exp.DataType.Type.VARBINARY: "VARBYTE",
             exp.DataType.Type.INT: "INTEGER",
         }
+
+        ROOT_PROPERTIES = {
+            exp.DistKeyProperty,
+            exp.SortKeyProperty,
+            exp.DistStyleProperty,
+        }
+
+        TRANSFORMS = {
+            **Postgres.Generator.TRANSFORMS,  # type: ignore
+            **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
+            exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
+            exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+            exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+        }
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index a96bd80..46155ff 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
     format_time_lambda,
     inline_array_sql,
     rename_func,
+    var_map_sql,
 )
 from sqlglot.expressions import Literal
 from sqlglot.helper import seq_get
@@ -100,6 +101,14 @@ def _parse_date_part(self):
     return self.expression(exp.Extract, this=this, expression=expression)
 
 
+def _datatype_sql(self, expression):
+    if expression.this == exp.DataType.Type.ARRAY:
+        return "ARRAY"
+    elif expression.this == exp.DataType.Type.MAP:
+        return "OBJECT"
+    return self.datatype_sql(expression)
+
+
 class Snowflake(Dialect):
     null_ordering = "nulls_are_large"
     time_format = "'yyyy-mm-dd hh24:mi:ss'"
@@ -142,6 +151,8 @@ class Snowflake(Dialect):
             "TO_TIMESTAMP": _snowflake_to_timestamp,
             "ARRAY_CONSTRUCT": exp.Array.from_arg_list,
             "RLIKE": exp.RegexpLike.from_arg_list,
+            "DECODE": exp.Matches.from_arg_list,
+            "OBJECT_CONSTRUCT": parser.parse_var_map,
         }
 
         FUNCTION_PARSERS = {
@@ -195,16 +206,20 @@ class Snowflake(Dialect):
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
+            exp.Array: inline_array_sql,
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
+            exp.DataType: _datatype_sql,
             exp.If: rename_func("IFF"),
+            exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+            exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+            exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
+            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+            exp.Matches: rename_func("DECODE"),
+            exp.StrPosition: rename_func("POSITION"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time_sql,
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
-            exp.Array: inline_array_sql,
-            exp.StrPosition: rename_func("POSITION"),
-            exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
-            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
             exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
+            exp.UnixToTime: _unix_to_time_sql,
         }
 
         TYPE_MAPPING = {
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 4e404b8..16083d1 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -98,7 +98,7 @@ class Spark(Hive):
         TRANSFORMS = {
             **Hive.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
-            exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
+            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
             exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 87b98a5..bbb752b 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -13,6 +13,23 @@ from sqlglot.dialects.dialect import (
 )
 from sqlglot.tokens import TokenType
 
 
+# https://www.sqlite.org/lang_aggfunc.html#group_concat
+def _group_concat_sql(self, expression):
+    this = expression.this
+    distinct = expression.find(exp.Distinct)
+    if distinct:
+        this = distinct.expressions[0]
+        distinct = "DISTINCT "
+
+    if isinstance(expression.this, exp.Order):
+        self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
+        if expression.this.this and not distinct:
+            this = expression.this.this
+
+    separator = expression.args.get("separator")
+    return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
+
+
 class SQLite(Dialect):
     class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -62,6 +79,7 @@ class SQLite(Dialect):
             exp.Levenshtein: rename_func("EDITDIST3"),
             exp.TableSample: no_tablesample_sql,
             exp.TryCast: no_trycast_sql,
+            exp.GroupConcat: _group_concat_sql,
         }
 
         def transaction_sql(self, expression):
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index d3b83de..07ce38b 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -17,6 +17,7 @@ FULL_FORMAT_TIME_MAPPING = {
     "mm": "%B",
     "m": "%B",
 }
+
 DATE_DELTA_INTERVAL = {
     "year": "year",
     "yyyy": "year",
@@ -37,11 +38,12 @@ DATE_DELTA_INTERVAL = {
 
 
 DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
+
 # N = Numeric, C=Currency
 TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
 
 
-def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
+def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
     def _format_time(args):
         return exp_class(
             this=seq_get(args, 1),
@@ -58,7 +60,7 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
     return _format_time
 
 
-def parse_format(args):
+def _parse_format(args):
     fmt = seq_get(args, 1)
     number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
     if number_fmt:
@@ -78,7 +80,7 @@ def parse_format(args):
     return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
 
 
-def generate_format_sql(self, e):
+def _format_sql(self, e):
     fmt = (
         e.args["format"]
         if isinstance(e, exp.NumberToStr)
@@ -87,6 +89,28 @@ def generate_format_sql(self, e):
     return f"FORMAT({self.format_args(e.this, fmt)})"
 
 
+def _string_agg_sql(self, e):
+    e = e.copy()
+
+    this = e.this
+    distinct = e.find(exp.Distinct)
+    if distinct:
+        # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
+        self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
+        this = distinct.expressions[0]
+        distinct.pop()
+
+    order = ""
+    if isinstance(e.this, exp.Order):
+        if e.this.this:
+            this = e.this.this
+            e.this.this.pop()
+        order = f" WITHIN GROUP ({self.sql(e.this)[1:]})"  # Order has a leading space
+
+    separator = e.args.get("separator") or exp.Literal.string(",")
+    return f"STRING_AGG({self.format_args(this, separator)}){order}"
+
+
 class TSQL(Dialect):
     null_ordering = "nulls_are_small"
     time_format = "'yyyy-mm-dd hh:mm:ss'"
@@ -228,14 +252,14 @@ class TSQL(Dialect):
             "ISNULL": exp.Coalesce.from_arg_list,
             "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
             "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
-            "DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
-            "DATEPART": tsql_format_time_lambda(exp.TimeToStr),
+            "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
+            "DATEPART": _format_time_lambda(exp.TimeToStr),
             "GETDATE": exp.CurrentDate.from_arg_list,
             "IIF": exp.If.from_arg_list,
             "LEN": exp.Length.from_arg_list,
             "REPLICATE": exp.Repeat.from_arg_list,
             "JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
-            "FORMAT": parse_format,
+            "FORMAT": _parse_format,
         }
 
         VAR_LENGTH_DATATYPES = {
@@ -298,6 +322,7 @@ class TSQL(Dialect):
             exp.DateDiff: generate_date_delta_with_unit_sql,
             exp.CurrentDate: rename_func("GETDATE"),
             exp.If: rename_func("IIF"),
-            exp.NumberToStr: generate_format_sql,
-            exp.TimeToStr: generate_format_sql,
+            exp.NumberToStr: _format_sql,
+            exp.TimeToStr: _format_sql,
+            exp.GroupConcat: _string_agg_sql,
         }
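
Not part of the commit: a minimal smoke-test sketch of the new GroupConcat handling added above, assuming the top-level sqlglot.transpile API; the exact SQL text emitted depends on the installed sqlglot version.

    import sqlglot

    # GROUP_CONCAT parses to exp.GroupConcat; with the transforms above it should
    # render as STRING_AGG for Postgres and stay GROUP_CONCAT for SQLite.
    sql = "SELECT GROUP_CONCAT(x) FROM t"
    print(sqlglot.transpile(sql, read="sqlite", write="postgres")[0])
    print(sqlglot.transpile(sql, read="sqlite", write="sqlite")[0])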