diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/__init__.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 16 | ||||
-rw-r--r-- | sqlglot/dialects/drill.py | 174 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 76 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 38 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 38 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 2 |
13 files changed, 356 insertions, 32 deletions
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 0816831..2e42e7d 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery from sqlglot.dialects.clickhouse import ClickHouse from sqlglot.dialects.databricks import Databricks from sqlglot.dialects.dialect import Dialect, Dialects +from sqlglot.dialects.drill import Drill from sqlglot.dialects.duckdb import DuckDB from sqlglot.dialects.hive import Hive from sqlglot.dialects.mysql import MySQL diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 5bbff9d..4550d65 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -119,6 +119,8 @@ class BigQuery(Dialect): "UNKNOWN": TokenType.NULL, "WINDOW": TokenType.WINDOW, "NOT DETERMINISTIC": TokenType.VOLATILE, + "BEGIN": TokenType.COMMAND, + "BEGIN TRANSACTION": TokenType.BEGIN, } KEYWORDS.pop("DIV") @@ -204,6 +206,15 @@ class BigQuery(Dialect): EXPLICIT_UNION = True + def transaction_sql(self, *_): + return "BEGIN TRANSACTION" + + def commit_sql(self, *_): + return "COMMIT TRANSACTION" + + def rollback_sql(self, *_): + return "ROLLBACK TRANSACTION" + def in_unnest_op(self, unnest): return self.sql(unnest) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 3af08bb..8c497ab 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -32,6 +32,7 @@ class Dialects(str, Enum): TRINO = "trino" TSQL = "tsql" DATABRICKS = "databricks" + DRILL = "drill" class _Dialect(type): @@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None): return exp_class(this=this, expression=expression, unit=unit) return inner_func + + +def locate_to_strposition(args): + return exp.StrPosition( + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ) + + +def strposition_to_local_sql(self, expression): + args = self.format_args( + expression.args.get("substr"), expression.this, expression.args.get("position") + ) + return f"LOCATE({args})" diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py new file mode 100644 index 0000000..eb420aa --- /dev/null +++ b/sqlglot/dialects/drill.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import re + +from sqlglot import exp, generator, parser, tokens +from sqlglot.dialects.dialect import ( + Dialect, + create_with_partitions_sql, + format_time_lambda, + no_pivot_sql, + no_trycast_sql, + rename_func, + str_position_sql, +) +from sqlglot.dialects.postgres import _lateral_sql + + +def _to_timestamp(args): + # TO_TIMESTAMP accepts either a single double argument or (text, text) + if len(args) == 1 and args[0].is_number: + return exp.UnixToTime.from_arg_list(args) + return format_time_lambda(exp.StrToTime, "drill")(args) + + +def _str_to_time_sql(self, expression): + return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" + + +def _ts_or_ds_to_date_sql(self, expression): + time_format = self.format_time(expression) + if time_format and time_format not in (Drill.time_format, Drill.date_format): + return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" + return f"CAST({self.sql(expression, 'this')} AS DATE)" + + +def _date_add_sql(kind): + def func(self, expression): + this = self.sql(expression, "this") + unit = expression.text("unit").upper() or "DAY" + expression = self.sql(expression, "expression") + return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})" + + return func + + +def if_sql(self, expression): + """ + Drill requires backticks around certain SQL reserved words, IF being one of them, This function + adds the backticks around the keyword IF. + Args: + self: The Drill dialect + expression: The input IF expression + + Returns: The expression with IF in backticks. + + """ + expressions = self.format_args( + expression.this, expression.args.get("true"), expression.args.get("false") + ) + return f"`IF`({expressions})" + + +def _str_to_date(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format == Drill.date_format: + return f"CAST({this} AS DATE)" + return f"TO_DATE({this}, {time_format})" + + +class Drill(Dialect): + normalize_functions = None + null_ordering = "nulls_are_last" + date_format = "'yyyy-MM-dd'" + dateint_format = "'yyyyMMdd'" + time_format = "'yyyy-MM-dd HH:mm:ss'" + + time_mapping = { + "y": "%Y", + "Y": "%Y", + "YYYY": "%Y", + "yyyy": "%Y", + "YY": "%y", + "yy": "%y", + "MMMM": "%B", + "MMM": "%b", + "MM": "%m", + "M": "%-m", + "dd": "%d", + "d": "%-d", + "HH": "%H", + "H": "%-H", + "hh": "%I", + "h": "%-I", + "mm": "%M", + "m": "%-M", + "ss": "%S", + "s": "%-S", + "SSSSSS": "%f", + "a": "%p", + "DD": "%j", + "D": "%-j", + "E": "%a", + "EE": "%a", + "EEE": "%a", + "EEEE": "%A", + "''T''": "T", + } + + class Tokenizer(tokens.Tokenizer): + QUOTES = ["'"] + IDENTIFIERS = ["`"] + ESCAPES = ["\\"] + ENCODE = "utf-8" + + class Parser(parser.Parser): + STRICT_CAST = False + + FUNCTIONS = { + **parser.Parser.FUNCTIONS, + "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, + "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), + } + + class Generator(generator.Generator): + TYPE_MAPPING = { + **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.SMALLINT: "INTEGER", + exp.DataType.Type.TINYINT: "INTEGER", + exp.DataType.Type.BINARY: "VARBINARY", + exp.DataType.Type.TEXT: "VARCHAR", + exp.DataType.Type.NCHAR: "VARCHAR", + exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.DATETIME: "TIMESTAMP", + } + + ROOT_PROPERTIES = {exp.PartitionedByProperty} + + TRANSFORMS = { + **generator.Generator.TRANSFORMS, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.Lateral: _lateral_sql, + exp.ArrayContains: rename_func("REPEATED_CONTAINS"), + exp.ArraySize: rename_func("REPEATED_COUNT"), + exp.Create: create_with_partitions_sql, + exp.DateAdd: _date_add_sql("ADD"), + exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateSub: _date_add_sql("SUB"), + exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})", + exp.If: if_sql, + exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", + exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", + exp.Pivot: no_pivot_sql, + exp.RegexpLike: rename_func("REGEXP_MATCHES"), + exp.StrPosition: str_position_sql, + exp.StrToDate: _str_to_date, + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TryCast: no_trycast_sql, + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)", + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", + } + + def normalize_func(self, name): + return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`" diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 781edff..f1da72b 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -55,13 +55,13 @@ def _array_sort_sql(self, expression): def _sort_array_sql(self, expression): this = self.sql(expression, "this") - if expression.args.get("asc") == exp.FALSE: + if expression.args.get("asc") == exp.false(): return f"ARRAY_REVERSE_SORT({this})" return f"ARRAY_SORT({this})" def _sort_array_reverse(args): - return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE) + return exp.SortArray(this=seq_get(args, 0), asc=exp.false()) def _struct_pack_sql(self, expression): diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index ed7357c..cff7139 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import ( create_with_partitions_sql, format_time_lambda, if_sql, + locate_to_strposition, no_ilike_sql, no_recursive_cte_sql, no_safe_divide_sql, no_trycast_sql, rename_func, + strposition_to_local_sql, struct_extract_sql, var_map_sql, ) from sqlglot.helper import seq_get from sqlglot.parser import parse_var_map +from sqlglot.tokens import TokenType # (FuncType, Multiplier) DATE_DELTA_INTERVAL = { @@ -181,6 +184,15 @@ class Hive(Dialect): "F": "FLOAT", "BD": "DECIMAL", } + KEYWORDS = { + **tokens.Tokenizer.KEYWORDS, + "ADD ARCHIVE": TokenType.COMMAND, + "ADD ARCHIVES": TokenType.COMMAND, + "ADD FILE": TokenType.COMMAND, + "ADD FILES": TokenType.COMMAND, + "ADD JAR": TokenType.COMMAND, + "ADD JARS": TokenType.COMMAND, + } class Parser(parser.Parser): STRICT_CAST = False @@ -210,11 +222,7 @@ class Hive(Dialect): "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, - "LOCATE": lambda args: exp.StrPosition( - this=seq_get(args, 1), - substr=seq_get(args, 0), - position=seq_get(args, 2), - ), + "LOCATE": locate_to_strposition, "LOG": ( lambda args: exp.Log.from_arg_list(args) if len(args) > 1 @@ -272,7 +280,7 @@ class Hive(Dialect): exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", - exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})", + exp.StrPosition: strposition_to_local_sql, exp.StrToDate: _str_to_date, exp.StrToTime: _str_to_time, exp.StrToUnix: _str_to_unix, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index e742640..93a60f4 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -5,10 +5,12 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + locate_to_strposition, no_ilike_sql, no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + strposition_to_local_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -120,6 +122,7 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, @@ -172,13 +175,18 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} class Parser(parser.Parser): - STRICT_CAST = False + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE_ADD": _date_add(exp.DateAdd), "DATE_SUB": _date_add(exp.DateSub), "STR_TO_DATE": _str_to_date, + "LOCATE": locate_to_strposition, + "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), + "LEFT": lambda args: exp.Substring( + this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1) + ), } FUNCTION_PARSERS = { @@ -264,6 +272,7 @@ class MySQL(Dialect): "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "NAMES": lambda self: self._parse_set_item_names(), + "TRANSACTION": lambda self: self._parse_set_transaction(), } PROFILE_TYPES = { @@ -278,39 +287,48 @@ class MySQL(Dialect): "SWAPS", } + TRANSACTION_CHARACTERISTICS = { + "ISOLATION LEVEL REPEATABLE READ", + "ISOLATION LEVEL READ COMMITTED", + "ISOLATION LEVEL READ UNCOMMITTED", + "ISOLATION LEVEL SERIALIZABLE", + "READ WRITE", + "READ ONLY", + } + def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): - self._match_text(target) + self._match_text_seq(target) target_id = self._parse_id_var() else: target_id = None - log = self._parse_string() if self._match_text("IN") else None + log = self._parse_string() if self._match_text_seq("IN") else None if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}: - position = self._parse_number() if self._match_text("FROM") else None + position = self._parse_number() if self._match_text_seq("FROM") else None db = None else: position = None - db = self._parse_id_var() if self._match_text("FROM") else None + db = self._parse_id_var() if self._match_text_seq("FROM") else None - channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None + channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None - like = self._parse_string() if self._match_text("LIKE") else None + like = self._parse_string() if self._match_text_seq("LIKE") else None where = self._parse_where() if this == "PROFILE": - types = self._parse_csv(self._parse_show_profile_type) - query = self._parse_number() if self._match_text("FOR", "QUERY") else None - offset = self._parse_number() if self._match_text("OFFSET") else None - limit = self._parse_number() if self._match_text("LIMIT") else None + types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES)) + query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None + offset = self._parse_number() if self._match_text_seq("OFFSET") else None + limit = self._parse_number() if self._match_text_seq("LIMIT") else None else: types, query = None, None offset, limit = self._parse_oldstyle_limit() - mutex = True if self._match_text("MUTEX") else None - mutex = False if self._match_text("STATUS") else mutex + mutex = True if self._match_text_seq("MUTEX") else None + mutex = False if self._match_text_seq("STATUS") else mutex return self.expression( exp.Show, @@ -331,16 +349,16 @@ class MySQL(Dialect): **{"global": global_}, ) - def _parse_show_profile_type(self): - for type_ in self.PROFILE_TYPES: - if self._match_text(*type_.split(" ")): - return exp.Var(this=type_) + def _parse_var_from_options(self, options): + for option in options: + if self._match_text_seq(*option.split(" ")): + return exp.Var(this=option) return None def _parse_oldstyle_limit(self): limit = None offset = None - if self._match_text("LIMIT"): + if self._match_text_seq("LIMIT"): parts = self._parse_csv(self._parse_number) if len(parts) == 1: limit = parts[0] @@ -353,6 +371,9 @@ class MySQL(Dialect): return self._parse_set_item_assignment(kind=None) def _parse_set_item_assignment(self, kind): + if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"): + return self._parse_set_transaction(global_=kind == "GLOBAL") + left = self._parse_primary() or self._parse_id_var() if not self._match(TokenType.EQ): self.raise_error("Expected =") @@ -381,7 +402,7 @@ class MySQL(Dialect): def _parse_set_item_names(self): charset = self._parse_string() or self._parse_id_var() - if self._match_text("COLLATE"): + if self._match_text_seq("COLLATE"): collate = self._parse_string() or self._parse_id_var() else: collate = None @@ -392,6 +413,18 @@ class MySQL(Dialect): kind="NAMES", ) + def _parse_set_transaction(self, global_=False): + self._match_text_seq("TRANSACTION") + characteristics = self._parse_csv( + lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS) + ) + return self.expression( + exp.SetItem, + expressions=characteristics, + kind="TRANSACTION", + **{"global": global_}, + ) + class Generator(generator.Generator): NULL_ORDERING_SUPPORTED = False @@ -411,6 +444,7 @@ class MySQL(Dialect): exp.Trim: _trim_sql, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), + exp.StrPosition: strposition_to_local_sql, } ROOT_PROPERTIES = { @@ -481,9 +515,11 @@ class MySQL(Dialect): kind = self.sql(expression, "kind") kind = f"{kind} " if kind else "" this = self.sql(expression, "this") + expressions = self.expressions(expression) collate = self.sql(expression, "collate") collate = f" COLLATE {collate}" if collate else "" - return f"{kind}{this}{collate}" + global_ = "GLOBAL " if expression.args.get("global") else "" + return f"{global_}{kind}{this}{expressions}{collate}" def set_sql(self, expression): return f"SET {self.expressions(expression)}" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 3bc1109..870d2b9 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -91,6 +91,7 @@ class Oracle(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, "NVARCHAR2": TokenType.NVARCHAR, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 553a73b..4353164 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -164,11 +164,34 @@ class Postgres(Dialect): BIT_STRINGS = [("b'", "'"), ("B'", "'")] HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] + + CREATABLES = ( + "AGGREGATE", + "CAST", + "CONVERSION", + "COLLATION", + "DEFAULT CONVERSION", + "CONSTRAINT", + "DOMAIN", + "EXTENSION", + "FOREIGN", + "FUNCTION", + "OPERATOR", + "POLICY", + "ROLE", + "RULE", + "SEQUENCE", + "TEXT", + "TRIGGER", + "TYPE", + "UNLOGGED", + "USER", + ) + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ALWAYS": TokenType.ALWAYS, "BY DEFAULT": TokenType.BY_DEFAULT, - "COMMENT ON": TokenType.COMMENT_ON, "IDENTITY": TokenType.IDENTITY, "GENERATED": TokenType.GENERATED, "DOUBLE PRECISION": TokenType.DOUBLE, @@ -176,6 +199,19 @@ class Postgres(Dialect): "SERIAL": TokenType.SERIAL, "SMALLSERIAL": TokenType.SMALLSERIAL, "UUID": TokenType.UUID, + "TEMP": TokenType.TEMPORARY, + "BEGIN TRANSACTION": TokenType.BEGIN, + "BEGIN": TokenType.COMMAND, + "COMMENT ON": TokenType.COMMAND, + "DECLARE": TokenType.COMMAND, + "DO": TokenType.COMMAND, + "REFRESH": TokenType.COMMAND, + "REINDEX": TokenType.COMMAND, + "RESET": TokenType.COMMAND, + "REVOKE": TokenType.COMMAND, + "GRANT": TokenType.COMMAND, + **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, + **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, } QUOTES = ["'", "$$"] SINGLE_TOKENS = { diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 11ea778..9d5cc11 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( struct_extract_sql, ) from sqlglot.dialects.mysql import MySQL +from sqlglot.errors import UnsupportedError from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -61,8 +62,18 @@ def _initcap_sql(self, expression): return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" +def _decode_sql(self, expression): + _ensure_utf8(expression.args.get("charset")) + return f"FROM_UTF8({self.sql(expression, 'this')})" + + +def _encode_sql(self, expression): + _ensure_utf8(expression.args.get("charset")) + return f"TO_UTF8({self.sql(expression, 'this')})" + + def _no_sort_array(self, expression): - if expression.args.get("asc") == exp.FALSE: + if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: comparator = None @@ -72,7 +83,7 @@ def _no_sort_array(self, expression): def _schema_sql(self, expression): if isinstance(expression.parent, exp.Property): - columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions) + columns = ", ".join(f"'{c.name}'" for c in expression.expressions) return f"ARRAY[{columns}]" for schema in expression.parent.find_all(exp.Schema): @@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression): return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" +def _ensure_utf8(charset): + if charset.name.lower() != "utf-8": + raise UnsupportedError(f"Unsupported charset {charset}") + + class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" @@ -115,6 +131,7 @@ class Presto(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "ROW": TokenType.STRUCT, } @@ -140,6 +157,14 @@ class Presto(Dialect): "STRPOS": exp.StrPosition.from_arg_list, "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "FROM_HEX": exp.Unhex.from_arg_list, + "TO_HEX": exp.Hex.from_arg_list, + "TO_UTF8": lambda args: exp.Encode( + this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + ), + "FROM_UTF8": lambda args: exp.Decode( + this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + ), } class Generator(generator.Generator): @@ -187,7 +212,10 @@ class Presto(Dialect): exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", + exp.Decode: _decode_sql, exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", + exp.Encode: _encode_sql, + exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, @@ -212,7 +240,13 @@ class Presto(Dialect): exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.Unhex: rename_func("FROM_HEX"), exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", } + + def transaction_sql(self, expression): + modes = expression.args.get("modes") + modes = f" {', '.join(modes)}" if modes else "" + return f"START TRANSACTION{modes}" diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index d1aaded..a96bd80 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -148,6 +148,7 @@ class Snowflake(Dialect): **parser.Parser.FUNCTION_PARSERS, "DATE_PART": _parse_date_part, } + FUNCTION_PARSERS.pop("TRIM") FUNC_TOKENS = { *parser.Parser.FUNC_TOKENS, @@ -203,6 +204,7 @@ class Snowflake(Dialect): exp.StrPosition: rename_func("POSITION"), exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", + exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", } TYPE_MAPPING = { diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 8c9fb76..87b98a5 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -63,3 +63,8 @@ class SQLite(Dialect): exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, } + + def transaction_sql(self, expression): + this = expression.this + this = f" {this}" if this else "" + return f"BEGIN{this} TRANSACTION" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a233d4b..d3b83de 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -248,7 +248,7 @@ class TSQL(Dialect): def _parse_convert(self, strict): to = self._parse_types() self._match(TokenType.COMMA) - this = self._parse_column() + this = self._parse_conjunction() # Retrieve length of datatype and override to default if not specified if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: |