diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/__init__.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 31 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 62 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 73 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 16 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 87 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 28 |
14 files changed, 292 insertions, 37 deletions
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 2e42e7d..2084681 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -15,5 +15,6 @@ from sqlglot.dialects.spark import Spark from sqlglot.dialects.sqlite import SQLite from sqlglot.dialects.starrocks import StarRocks from sqlglot.dialects.tableau import Tableau +from sqlglot.dialects.teradata import Teradata from sqlglot.dialects.trino import Trino from sqlglot.dialects.tsql import TSQL diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index f0089e1..9ddfbea 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -165,6 +165,11 @@ class BigQuery(Dialect): TokenType.TABLE, } + ID_VAR_TOKENS = { + *parser.Parser.ID_VAR_TOKENS, # type: ignore + TokenType.VALUES, + } + class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 04d46d2..1c173a4 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -4,6 +4,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql +from sqlglot.errors import ParseError from sqlglot.parser import parse_var_map from sqlglot.tokens import TokenType @@ -72,6 +73,30 @@ class ClickHouse(Dialect): return this + def _parse_position(self) -> exp.Expression: + this = super()._parse_position() + # clickhouse position args are swapped + substr = this.this + this.args["this"] = this.args.get("substr") + this.args["substr"] = substr + return this + + # https://clickhouse.com/docs/en/sql-reference/statements/select/with/ + def _parse_cte(self) -> exp.Expression: + index = self._index + try: + # WITH <identifier> AS <subquery expression> + return super()._parse_cte() + except ParseError: + # WITH <expression> AS <identifier> + self._retreat(index) + statement = self._parse_statement() + + if statement and isinstance(statement.this, exp.Alias): + self.raise_error("Expected CTE to have alias") + + return self.expression(exp.CTE, this=statement, alias=statement and statement.this) + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") @@ -110,3 +135,9 @@ class ClickHouse(Dialect): params = self.format_args(self.expressions(expression, params_name)) args = self.format_args(self.expressions(expression, args_name)) return f"({params})({args})" + + def cte_sql(self, expression: exp.CTE) -> str: + if isinstance(expression.this, exp.Alias): + return self.sql(expression, "this") + + return super().cte_sql(expression) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 1c840da..0c2beba 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -33,6 +33,7 @@ class Dialects(str, Enum): TSQL = "tsql" DATABRICKS = "databricks" DRILL = "drill" + TERADATA = "teradata" class _Dialect(type): @@ -368,7 +369,7 @@ def locate_to_strposition(args): ) -def strposition_to_local_sql(self, expression): +def strposition_to_locate_sql(self, expression): args = self.format_args( expression.args.get("substr"), expression.this, expression.args.get("position") ) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index ead13b1..ddfd1e8 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import ( no_safe_divide_sql, no_trycast_sql, rename_func, - strposition_to_local_sql, + strposition_to_locate_sql, struct_extract_sql, timestrtotime_sql, var_map_sql, @@ -297,7 +297,7 @@ class Hive(Dialect): exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", - exp.StrPosition: strposition_to_local_sql, + exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date, exp.StrToTime: _str_to_time, exp.StrToUnix: _str_to_unix, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 0fd7992..1bddfe1 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, - strposition_to_local_sql, + strposition_to_locate_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -122,6 +122,8 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "MEDIUMTEXT": TokenType.MEDIUMTEXT, + "LONGTEXT": TokenType.LONGTEXT, "START": TokenType.BEGIN, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, @@ -442,7 +444,7 @@ class MySQL(Dialect): exp.Trim: _trim_sql, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), - exp.StrPosition: strposition_to_local_sql, + exp.StrPosition: strposition_to_locate_sql, } ROOT_PROPERTIES = { @@ -454,6 +456,10 @@ class MySQL(Dialect): exp.LikeProperty, } + TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() + TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) + TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) + WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() def show_sql(self, expression): diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index f3fec31..6f597f1 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -223,19 +223,15 @@ class Postgres(Dialect): "~~*": TokenType.ILIKE, "~*": TokenType.IRLIKE, "~": TokenType.RLIKE, - "ALWAYS": TokenType.ALWAYS, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, - "BY DEFAULT": TokenType.BY_DEFAULT, "CHARACTER VARYING": TokenType.VARCHAR, "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, - "GENERATED": TokenType.GENERATED, "GRANT": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, - "IDENTITY": TokenType.IDENTITY, "JSONB": TokenType.JSONB, "REFRESH": TokenType.COMMAND, "REINDEX": TokenType.COMMAND, @@ -299,6 +295,7 @@ class Postgres(Dialect): exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, exp.Trim: trim_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index e16ea1d..a79a9f9 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import ( no_ilike_sql, no_safe_divide_sql, rename_func, - str_position_sql, struct_extract_sql, timestrtotime_sql, ) @@ -24,14 +23,6 @@ def _approx_distinct_sql(self, expression): return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" -def _concat_ws_sql(self, expression): - sep, *args = expression.expressions - sep = self.sql(sep) - if len(args) > 1: - return f"ARRAY_JOIN(ARRAY[{self.format_args(*args)}], {sep})" - return f"ARRAY_JOIN({self.sql(args[0])}, {sep})" - - def _datatype_sql(self, expression): sql = self.datatype_sql(expression) if expression.this == exp.DataType.Type.TIMESTAMPTZ: @@ -61,7 +52,7 @@ def _initcap_sql(self, expression): def _decode_sql(self, expression): _ensure_utf8(expression.args.get("charset")) - return f"FROM_UTF8({self.sql(expression, 'this')})" + return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})" def _encode_sql(self, expression): @@ -119,6 +110,38 @@ def _ensure_utf8(charset): raise UnsupportedError(f"Unsupported charset {charset}") +def _approx_percentile(args): + if len(args) == 4: + return exp.ApproxQuantile( + this=seq_get(args, 0), + weight=seq_get(args, 1), + quantile=seq_get(args, 2), + accuracy=seq_get(args, 3), + ) + if len(args) == 3: + return exp.ApproxQuantile( + this=seq_get(args, 0), + quantile=seq_get(args, 1), + accuracy=seq_get(args, 2), + ) + return exp.ApproxQuantile.from_arg_list(args) + + +def _from_unixtime(args): + if len(args) == 3: + return exp.UnixToTime( + this=seq_get(args, 0), + hours=seq_get(args, 1), + minutes=seq_get(args, 2), + ) + if len(args) == 2: + return exp.UnixToTime( + this=seq_get(args, 0), + zone=seq_get(args, 1), + ) + return exp.UnixToTime.from_arg_list(args) + + class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" @@ -150,19 +173,25 @@ class Presto(Dialect): ), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), - "FROM_UNIXTIME": exp.UnixToTime.from_arg_list, - "STRPOS": exp.StrPosition.from_arg_list, + "FROM_UNIXTIME": _from_unixtime, + "STRPOS": lambda args: exp.StrPosition( + this=seq_get(args, 0), + substr=seq_get(args, 1), + instance=seq_get(args, 2), + ), "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, - "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "APPROX_PERCENTILE": _approx_percentile, "FROM_HEX": exp.Unhex.from_arg_list, "TO_HEX": exp.Hex.from_arg_list, "TO_UTF8": lambda args: exp.Encode( this=seq_get(args, 0), charset=exp.Literal.string("utf-8") ), "FROM_UTF8": lambda args: exp.Decode( - this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") ), } + FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() + FUNCTION_PARSERS.pop("TRIM") class Generator(generator.Generator): @@ -194,7 +223,6 @@ class Presto(Dialect): exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.ConcatWs: _concat_ws_sql, exp.DataType: _datatype_sql, exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", @@ -209,12 +237,13 @@ class Presto(Dialect): exp.Initcap: _initcap_sql, exp.Lateral: _explode_to_unnest_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), + exp.LogicalOr: rename_func("BOOL_OR"), exp.Quantile: _quantile_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.SortArray: _no_sort_array, - exp.StrPosition: str_position_sql, + exp.StrPosition: rename_func("STRPOS"), exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", @@ -233,6 +262,7 @@ class Presto(Dialect): exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", + exp.VariancePop: rename_func("VAR_POP"), } def transaction_sql(self, expression): diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 27dfb93..afd7913 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, transforms from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.postgres import Postgres @@ -21,6 +23,19 @@ class Redshift(Postgres): "NVL": exp.Coalesce.from_arg_list, } + def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: + this = super()._parse_types(check_func=check_func) + + if ( + isinstance(this, exp.DataType) + and this.this == exp.DataType.Type.VARCHAR + and this.expressions + and this.expressions[0] == exp.column("MAX") + ): + this.set("expressions", [exp.Var(this="MAX")]) + + return this + class Tokenizer(Postgres.Tokenizer): ESCAPES = ["\\"] @@ -52,6 +67,10 @@ class Redshift(Postgres): exp.DistStyleProperty, } + WITH_PROPERTIES = { + exp.LikeProperty, + } + TRANSFORMS = { **Postgres.Generator.TRANSFORMS, # type: ignore **transforms.ELIMINATE_DISTINCT_ON, # type: ignore @@ -60,3 +79,57 @@ class Redshift(Postgres): exp.DistStyleProperty: lambda self, e: self.naked_property(e), exp.Matches: rename_func("DECODE"), } + + def values_sql(self, expression: exp.Values) -> str: + """ + Converts `VALUES...` expression into a series of unions. + + Note: If you have a lot of unions then this will result in a large number of recursive statements to + evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be + very slow. + """ + if not isinstance(expression.unnest().parent, exp.From): + return super().values_sql(expression) + rows = [tuple_exp.expressions for tuple_exp in expression.expressions] + selects = [] + for i, row in enumerate(rows): + if i == 0: + row = [ + exp.alias_(value, column_name) + for value, column_name in zip(row, expression.args["alias"].args["columns"]) + ] + selects.append(exp.Select(expressions=row)) + subquery_expression = selects[0] + if len(selects) > 1: + for select in selects[1:]: + subquery_expression = exp.union(subquery_expression, select, distinct=False) + return self.subquery_sql(subquery_expression.subquery(expression.alias)) + + def with_properties(self, properties: exp.Properties) -> str: + """Redshift doesn't have `WITH` as part of their with_properties so we remove it""" + return self.properties(properties, prefix=" ", suffix="") + + def renametable_sql(self, expression: exp.RenameTable) -> str: + """Redshift only supports defining the table name itself (not the db) when renaming tables""" + expression = expression.copy() + target_table = expression.this + for arg in target_table.args: + if arg != "this": + target_table.set(arg, None) + this = self.sql(expression, "this") + return f"RENAME TO {this}" + + def datatype_sql(self, expression: exp.DataType) -> str: + """ + Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean + VARCHAR of max length which is `VARCHAR(max)` in Redshift. Therefore if we get a `TEXT` data type + without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert + `TEXT` to `VARCHAR`. + """ + if expression.this == exp.DataType.Type.TEXT: + expression = expression.copy() + expression.set("this", exp.DataType.Type.VARCHAR) + precision = expression.args.get("expressions") + if not precision: + expression.append("expressions", exp.Var(this="MAX")) + return super().datatype_sql(expression) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 24d3bdf..c44950a 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -210,6 +210,7 @@ class Snowflake(Dialect): **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), + exp.DateAdd: rename_func("DATEADD"), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), @@ -218,7 +219,7 @@ class Snowflake(Dialect): exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Matches: rename_func("DECODE"), - exp.StrPosition: rename_func("POSITION"), + exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})", exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 7f05dea..42d34c2 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -124,6 +124,7 @@ class Spark(Hive): exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})", exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), + exp.LogicalOr: rename_func("BOOL_OR"), } TRANSFORMS.pop(exp.ArraySort) TRANSFORMS.pop(exp.ILike) diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index a0c4942..1b39449 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -13,6 +13,10 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType +def _fetch_sql(self, expression): + return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) + + # https://www.sqlite.org/lang_aggfunc.html#group_concat def _group_concat_sql(self, expression): this = expression.this @@ -30,6 +34,14 @@ def _group_concat_sql(self, expression): return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" +def _date_add_sql(self, expression): + modifier = expression.expression + modifier = expression.name if modifier.is_string else self.sql(modifier) + unit = expression.args.get("unit") + modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'" + return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})" + + class SQLite(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] @@ -71,6 +83,7 @@ class SQLite(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore + exp.DateAdd: _date_add_sql, exp.ILike: no_ilike_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, @@ -78,8 +91,11 @@ class SQLite(Dialect): exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.Levenshtein: rename_func("EDITDIST3"), exp.TableSample: no_tablesample_sql, + exp.DateStrToDate: lambda self, e: self.sql(e, "this"), + exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), exp.TryCast: no_trycast_sql, exp.GroupConcat: _group_concat_sql, + exp.Fetch: _fetch_sql, } def transaction_sql(self, expression): diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py new file mode 100644 index 0000000..4340820 --- /dev/null +++ b/sqlglot/dialects/teradata.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from sqlglot import exp, generator, parser +from sqlglot.dialects.dialect import Dialect +from sqlglot.tokens import TokenType + + +class Teradata(Dialect): + class Parser(parser.Parser): + CHARSET_TRANSLATORS = { + "GRAPHIC_TO_KANJISJIS", + "GRAPHIC_TO_LATIN", + "GRAPHIC_TO_UNICODE", + "GRAPHIC_TO_UNICODE_PadSpace", + "KANJI1_KanjiEBCDIC_TO_UNICODE", + "KANJI1_KanjiEUC_TO_UNICODE", + "KANJI1_KANJISJIS_TO_UNICODE", + "KANJI1_SBC_TO_UNICODE", + "KANJISJIS_TO_GRAPHIC", + "KANJISJIS_TO_LATIN", + "KANJISJIS_TO_UNICODE", + "LATIN_TO_GRAPHIC", + "LATIN_TO_KANJISJIS", + "LATIN_TO_UNICODE", + "LOCALE_TO_UNICODE", + "UNICODE_TO_GRAPHIC", + "UNICODE_TO_GRAPHIC_PadGraphic", + "UNICODE_TO_GRAPHIC_VarGraphic", + "UNICODE_TO_KANJI1_KanjiEBCDIC", + "UNICODE_TO_KANJI1_KanjiEUC", + "UNICODE_TO_KANJI1_KANJISJIS", + "UNICODE_TO_KANJI1_SBC", + "UNICODE_TO_KANJISJIS", + "UNICODE_TO_LATIN", + "UNICODE_TO_LOCALE", + "UNICODE_TO_UNICODE_FoldSpace", + "UNICODE_TO_UNICODE_Fullwidth", + "UNICODE_TO_UNICODE_Halfwidth", + "UNICODE_TO_UNICODE_NFC", + "UNICODE_TO_UNICODE_NFD", + "UNICODE_TO_UNICODE_NFKC", + "UNICODE_TO_UNICODE_NFKD", + } + + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, # type: ignore + "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST), + } + + def _parse_translate(self, strict: bool) -> exp.Expression: + this = self._parse_conjunction() + + if not self._match(TokenType.USING): + self.raise_error("Expected USING in TRANSLATE") + + if self._match_texts(self.CHARSET_TRANSLATORS): + charset_split = self._prev.text.split("_TO_") + to = self.expression(exp.CharacterSet, this=charset_split[1]) + else: + self.raise_error("Expected a character set translator after USING in TRANSLATE") + + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + + # FROM before SET in Teradata UPDATE syntax + # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause + def _parse_update(self) -> exp.Expression: + return self.expression( + exp.Update, + **{ # type: ignore + "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), + "from": self._parse_from(), + "expressions": self._match(TokenType.SET) + and self._parse_csv(self._parse_equality), + "where": self._parse_where(), + }, + ) + + class Generator(generator.Generator): + # FROM before SET in Teradata UPDATE syntax + # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause + def update_sql(self, expression: exp.Update) -> str: + this = self.sql(expression, "this") + from_sql = self.sql(expression, "from") + set_sql = self.expressions(expression, flat=True) + where_sql = self.sql(expression, "where") + sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}" + return self.prepend_ctes(expression, sql) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 465f534..9342e6b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -243,28 +243,34 @@ class TSQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "BIT": TokenType.BOOLEAN, - "REAL": TokenType.FLOAT, - "NTEXT": TokenType.TEXT, - "SMALLDATETIME": TokenType.DATETIME, "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, - "TIME": TokenType.TIMESTAMP, + "DECLARE": TokenType.COMMAND, "IMAGE": TokenType.IMAGE, "MONEY": TokenType.MONEY, - "SMALLMONEY": TokenType.SMALLMONEY, + "NTEXT": TokenType.TEXT, + "NVARCHAR(MAX)": TokenType.TEXT, + "PRINT": TokenType.COMMAND, + "REAL": TokenType.FLOAT, "ROWVERSION": TokenType.ROWVERSION, - "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, - "XML": TokenType.XML, + "SMALLDATETIME": TokenType.DATETIME, + "SMALLMONEY": TokenType.SMALLMONEY, "SQL_VARIANT": TokenType.VARIANT, - "NVARCHAR(MAX)": TokenType.TEXT, - "VARCHAR(MAX)": TokenType.TEXT, + "TIME": TokenType.TIMESTAMP, "TOP": TokenType.TOP, + "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, + "VARCHAR(MAX)": TokenType.TEXT, + "XML": TokenType.XML, } class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "CHARINDEX": exp.StrPosition.from_arg_list, + "CHARINDEX": lambda args: exp.StrPosition( + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ), "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), @@ -288,7 +294,7 @@ class TSQL(Dialect): } # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table - TABLE_PREFIX_TOKENS = {TokenType.HASH} + TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER} def _parse_convert(self, strict): to = self._parse_types() |