summary refs log tree commit diff stats
path: root/sqlglot/dialects
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- sqlglot/dialects/__init__.py 1
-rw-r--r-- sqlglot/dialects/bigquery.py 11
-rw-r--r-- sqlglot/dialects/dialect.py 16
-rw-r--r-- sqlglot/dialects/drill.py 174
-rw-r--r-- sqlglot/dialects/duckdb.py 4
-rw-r--r-- sqlglot/dialects/hive.py 20
-rw-r--r-- sqlglot/dialects/mysql.py 76
-rw-r--r-- sqlglot/dialects/oracle.py 1
-rw-r--r-- sqlglot/dialects/postgres.py 38
-rw-r--r-- sqlglot/dialects/presto.py 38
-rw-r--r-- sqlglot/dialects/snowflake.py 2
-rw-r--r-- sqlglot/dialects/sqlite.py 5
-rw-r--r-- sqlglot/dialects/tsql.py 2
13 files changed, 356 insertions, 32 deletions
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 0816831..2e42e7d 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
+from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5bbff9d..4550d65 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -119,6 +119,8 @@ class BigQuery(Dialect):
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
+ "BEGIN": TokenType.COMMAND,
+ "BEGIN TRANSACTION": TokenType.BEGIN,
}
KEYWORDS.pop("DIV")
@@ -204,6 +206,15 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True
+ def transaction_sql(self, *_):
+ return "BEGIN TRANSACTION"
+
+ def commit_sql(self, *_):
+ return "COMMIT TRANSACTION"
+
+ def rollback_sql(self, *_):
+ return "ROLLBACK TRANSACTION"
+
def in_unnest_op(self, unnest):
return self.sql(unnest)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 3af08bb..8c497ab 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -32,6 +32,7 @@ class Dialects(str, Enum):
TRINO = "trino"
TSQL = "tsql"
DATABRICKS = "databricks"
+ DRILL = "drill"
class _Dialect(type):
@@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None):
return exp_class(this=this, expression=expression, unit=unit)
return inner_func
+
+
+def locate_to_strposition(args):
+ return exp.StrPosition(
+ this=seq_get(args, 1),
+ substr=seq_get(args, 0),
+ position=seq_get(args, 2),
+ )
+
+
+def strposition_to_local_sql(self, expression):
+ args = self.format_args(
+ expression.args.get("substr"), expression.this, expression.args.get("position")
+ )
+ return f"LOCATE({args})"
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
new file mode 100644
index 0000000..eb420aa
--- /dev/null
+++ b/sqlglot/dialects/drill.py
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+import re
+
+from sqlglot import exp, generator, parser, tokens
+from sqlglot.dialects.dialect import (
+ Dialect,
+ create_with_partitions_sql,
+ format_time_lambda,
+ no_pivot_sql,
+ no_trycast_sql,
+ rename_func,
+ str_position_sql,
+)
+from sqlglot.dialects.postgres import _lateral_sql
+
+
+def _to_timestamp(args):
+ # TO_TIMESTAMP accepts either a single double argument or (text, text)
+ if len(args) == 1 and args[0].is_number:
+ return exp.UnixToTime.from_arg_list(args)
+ return format_time_lambda(exp.StrToTime, "drill")(args)
+
+
+def _str_to_time_sql(self, expression):
+ return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
+
+
+def _ts_or_ds_to_date_sql(self, expression):
+ time_format = self.format_time(expression)
+ if time_format and time_format not in (Drill.time_format, Drill.date_format):
+ return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
+ return f"CAST({self.sql(expression, 'this')} AS DATE)"
+
+
+def _date_add_sql(kind):
+ def func(self, expression):
+ this = self.sql(expression, "this")
+ unit = expression.text("unit").upper() or "DAY"
+ expression = self.sql(expression, "expression")
+ return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
+
+ return func
+
+
+def if_sql(self, expression):
+ """
+ Drill requires backticks around certain SQL reserved words, IF being one of them. This function
+ adds the backticks around the keyword IF.
+ Args:
+ self: The Drill generator
+ expression: The input IF expression
+
+ Returns: The SQL string for the IF call, with the function name IF in backticks.
+
+ """
+ expressions = self.format_args(
+ expression.this, expression.args.get("true"), expression.args.get("false")
+ )
+ return f"`IF`({expressions})"
+
+
+def _str_to_date(self, expression):
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ if time_format == Drill.date_format:
+ return f"CAST({this} AS DATE)"
+ return f"TO_DATE({this}, {time_format})"
+
+
+class Drill(Dialect):
+ normalize_functions = None
+ null_ordering = "nulls_are_last"
+ date_format = "'yyyy-MM-dd'"
+ dateint_format = "'yyyyMMdd'"
+ time_format = "'yyyy-MM-dd HH:mm:ss'"
+
+ time_mapping = {
+ "y": "%Y",
+ "Y": "%Y",
+ "YYYY": "%Y",
+ "yyyy": "%Y",
+ "YY": "%y",
+ "yy": "%y",
+ "MMMM": "%B",
+ "MMM": "%b",
+ "MM": "%m",
+ "M": "%-m",
+ "dd": "%d",
+ "d": "%-d",
+ "HH": "%H",
+ "H": "%-H",
+ "hh": "%I",
+ "h": "%-I",
+ "mm": "%M",
+ "m": "%-M",
+ "ss": "%S",
+ "s": "%-S",
+ "SSSSSS": "%f",
+ "a": "%p",
+ "DD": "%j",
+ "D": "%-j",
+ "E": "%a",
+ "EE": "%a",
+ "EEE": "%a",
+ "EEEE": "%A",
+ "''T''": "T",
+ }
+
+ class Tokenizer(tokens.Tokenizer):
+ QUOTES = ["'"]
+ IDENTIFIERS = ["`"]
+ ESCAPES = ["\\"]
+ ENCODE = "utf-8"
+
+ class Parser(parser.Parser):
+ STRICT_CAST = False
+
+ FUNCTIONS = {
+ **parser.Parser.FUNCTIONS,
+ "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
+ "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
+ }
+
+ class Generator(generator.Generator):
+ TYPE_MAPPING = {
+ **generator.Generator.TYPE_MAPPING,
+ exp.DataType.Type.INT: "INTEGER",
+ exp.DataType.Type.SMALLINT: "INTEGER",
+ exp.DataType.Type.TINYINT: "INTEGER",
+ exp.DataType.Type.BINARY: "VARBINARY",
+ exp.DataType.Type.TEXT: "VARCHAR",
+ exp.DataType.Type.NCHAR: "VARCHAR",
+ exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+ exp.DataType.Type.DATETIME: "TIMESTAMP",
+ }
+
+ ROOT_PROPERTIES = {exp.PartitionedByProperty}
+
+ TRANSFORMS = {
+ **generator.Generator.TRANSFORMS,
+ exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+ exp.Lateral: _lateral_sql,
+ exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
+ exp.ArraySize: rename_func("REPEATED_COUNT"),
+ exp.Create: create_with_partitions_sql,
+ exp.DateAdd: _date_add_sql("ADD"),
+ exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.DateSub: _date_add_sql("SUB"),
+ exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
+ exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
+ exp.If: if_sql,
+ exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
+ exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+ exp.Pivot: no_pivot_sql,
+ exp.RegexpLike: rename_func("REGEXP_MATCHES"),
+ exp.StrPosition: str_position_sql,
+ exp.StrToDate: _str_to_date,
+ exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+ exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TryCast: no_trycast_sql,
+ exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
+ exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
+ }
+
+ def normalize_func(self, name):
+ return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 781edff..f1da72b 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -55,13 +55,13 @@ def _array_sort_sql(self, expression):
def _sort_array_sql(self, expression):
this = self.sql(expression, "this")
- if expression.args.get("asc") == exp.FALSE:
+ if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
def _sort_array_reverse(args):
- return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
+ return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
def _struct_pack_sql(self, expression):
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index ed7357c..cff7139 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import (
create_with_partitions_sql,
format_time_lambda,
if_sql,
+ locate_to_strposition,
no_ilike_sql,
no_recursive_cte_sql,
no_safe_divide_sql,
no_trycast_sql,
rename_func,
+ strposition_to_local_sql,
struct_extract_sql,
var_map_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
+from sqlglot.tokens import TokenType
# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@@ -181,6 +184,15 @@ class Hive(Dialect):
"F": "FLOAT",
"BD": "DECIMAL",
}
+ KEYWORDS = {
+ **tokens.Tokenizer.KEYWORDS,
+ "ADD ARCHIVE": TokenType.COMMAND,
+ "ADD ARCHIVES": TokenType.COMMAND,
+ "ADD FILE": TokenType.COMMAND,
+ "ADD FILES": TokenType.COMMAND,
+ "ADD JAR": TokenType.COMMAND,
+ "ADD JARS": TokenType.COMMAND,
+ }
class Parser(parser.Parser):
STRICT_CAST = False
@@ -210,11 +222,7 @@ class Hive(Dialect):
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
- "LOCATE": lambda args: exp.StrPosition(
- this=seq_get(args, 1),
- substr=seq_get(args, 0),
- position=seq_get(args, 2),
- ),
+ "LOCATE": locate_to_strposition,
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
@@ -272,7 +280,7 @@ class Hive(Dialect):
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
- exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
+ exp.StrPosition: strposition_to_local_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index e742640..93a60f4 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -5,10 +5,12 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ locate_to_strposition,
no_ilike_sql,
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
+ strposition_to_local_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -120,6 +122,7 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@@ -172,13 +175,18 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
- STRICT_CAST = False
+ FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
+ "LOCATE": locate_to_strposition,
+ "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+ "LEFT": lambda args: exp.Substring(
+ this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
+ ),
}
FUNCTION_PARSERS = {
@@ -264,6 +272,7 @@ class MySQL(Dialect):
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
+ "TRANSACTION": lambda self: self._parse_set_transaction(),
}
PROFILE_TYPES = {
@@ -278,39 +287,48 @@ class MySQL(Dialect):
"SWAPS",
}
+ TRANSACTION_CHARACTERISTICS = {
+ "ISOLATION LEVEL REPEATABLE READ",
+ "ISOLATION LEVEL READ COMMITTED",
+ "ISOLATION LEVEL READ UNCOMMITTED",
+ "ISOLATION LEVEL SERIALIZABLE",
+ "READ WRITE",
+ "READ ONLY",
+ }
+
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
- self._match_text(target)
+ self._match_text_seq(target)
target_id = self._parse_id_var()
else:
target_id = None
- log = self._parse_string() if self._match_text("IN") else None
+ log = self._parse_string() if self._match_text_seq("IN") else None
if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
- position = self._parse_number() if self._match_text("FROM") else None
+ position = self._parse_number() if self._match_text_seq("FROM") else None
db = None
else:
position = None
- db = self._parse_id_var() if self._match_text("FROM") else None
+ db = self._parse_id_var() if self._match_text_seq("FROM") else None
- channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
+ channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
- like = self._parse_string() if self._match_text("LIKE") else None
+ like = self._parse_string() if self._match_text_seq("LIKE") else None
where = self._parse_where()
if this == "PROFILE":
- types = self._parse_csv(self._parse_show_profile_type)
- query = self._parse_number() if self._match_text("FOR", "QUERY") else None
- offset = self._parse_number() if self._match_text("OFFSET") else None
- limit = self._parse_number() if self._match_text("LIMIT") else None
+ types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES))
+ query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None
+ offset = self._parse_number() if self._match_text_seq("OFFSET") else None
+ limit = self._parse_number() if self._match_text_seq("LIMIT") else None
else:
types, query = None, None
offset, limit = self._parse_oldstyle_limit()
- mutex = True if self._match_text("MUTEX") else None
- mutex = False if self._match_text("STATUS") else mutex
+ mutex = True if self._match_text_seq("MUTEX") else None
+ mutex = False if self._match_text_seq("STATUS") else mutex
return self.expression(
exp.Show,
@@ -331,16 +349,16 @@ class MySQL(Dialect):
**{"global": global_},
)
- def _parse_show_profile_type(self):
- for type_ in self.PROFILE_TYPES:
- if self._match_text(*type_.split(" ")):
- return exp.Var(this=type_)
+ def _parse_var_from_options(self, options):
+ for option in options:
+ if self._match_text_seq(*option.split(" ")):
+ return exp.Var(this=option)
return None
def _parse_oldstyle_limit(self):
limit = None
offset = None
- if self._match_text("LIMIT"):
+ if self._match_text_seq("LIMIT"):
parts = self._parse_csv(self._parse_number)
if len(parts) == 1:
limit = parts[0]
@@ -353,6 +371,9 @@ class MySQL(Dialect):
return self._parse_set_item_assignment(kind=None)
def _parse_set_item_assignment(self, kind):
+ if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
+ return self._parse_set_transaction(global_=kind == "GLOBAL")
+
left = self._parse_primary() or self._parse_id_var()
if not self._match(TokenType.EQ):
self.raise_error("Expected =")
@@ -381,7 +402,7 @@ class MySQL(Dialect):
def _parse_set_item_names(self):
charset = self._parse_string() or self._parse_id_var()
- if self._match_text("COLLATE"):
+ if self._match_text_seq("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
@@ -392,6 +413,18 @@ class MySQL(Dialect):
kind="NAMES",
)
+ def _parse_set_transaction(self, global_=False):
+ self._match_text_seq("TRANSACTION")
+ characteristics = self._parse_csv(
+ lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
+ )
+ return self.expression(
+ exp.SetItem,
+ expressions=characteristics,
+ kind="TRANSACTION",
+ **{"global": global_},
+ )
+
class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
@@ -411,6 +444,7 @@ class MySQL(Dialect):
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+ exp.StrPosition: strposition_to_local_sql,
}
ROOT_PROPERTIES = {
@@ -481,9 +515,11 @@ class MySQL(Dialect):
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
+ expressions = self.expressions(expression)
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
- return f"{kind}{this}{collate}"
+ global_ = "GLOBAL " if expression.args.get("global") else ""
+ return f"{global_}{kind}{this}{expressions}{collate}"
def set_sql(self, expression):
return f"SET {self.expressions(expression)}"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 3bc1109..870d2b9 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -91,6 +91,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 553a73b..4353164 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -164,11 +164,34 @@ class Postgres(Dialect):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
+
+ CREATABLES = (
+ "AGGREGATE",
+ "CAST",
+ "CONVERSION",
+ "COLLATION",
+ "DEFAULT CONVERSION",
+ "CONSTRAINT",
+ "DOMAIN",
+ "EXTENSION",
+ "FOREIGN",
+ "FUNCTION",
+ "OPERATOR",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SEQUENCE",
+ "TEXT",
+ "TRIGGER",
+ "TYPE",
+ "UNLOGGED",
+ "USER",
+ )
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
- "COMMENT ON": TokenType.COMMENT_ON,
"IDENTITY": TokenType.IDENTITY,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
@@ -176,6 +199,19 @@ class Postgres(Dialect):
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
+ "TEMP": TokenType.TEMPORARY,
+ "BEGIN TRANSACTION": TokenType.BEGIN,
+ "BEGIN": TokenType.COMMAND,
+ "COMMENT ON": TokenType.COMMAND,
+ "DECLARE": TokenType.COMMAND,
+ "DO": TokenType.COMMAND,
+ "REFRESH": TokenType.COMMAND,
+ "REINDEX": TokenType.COMMAND,
+ "RESET": TokenType.COMMAND,
+ "REVOKE": TokenType.COMMAND,
+ "GRANT": TokenType.COMMAND,
+ **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
+ **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 11ea778..9d5cc11 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
+from sqlglot.errors import UnsupportedError
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -61,8 +62,18 @@ def _initcap_sql(self, expression):
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
+def _decode_sql(self, expression):
+ _ensure_utf8(expression.args.get("charset"))
+ return f"FROM_UTF8({self.sql(expression, 'this')})"
+
+
+def _encode_sql(self, expression):
+ _ensure_utf8(expression.args.get("charset"))
+ return f"TO_UTF8({self.sql(expression, 'this')})"
+
+
def _no_sort_array(self, expression):
- if expression.args.get("asc") == exp.FALSE:
+ if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
comparator = None
@@ -72,7 +83,7 @@ def _no_sort_array(self, expression):
def _schema_sql(self, expression):
if isinstance(expression.parent, exp.Property):
- columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
+ columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
@@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression):
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
+def _ensure_utf8(charset):
+ if charset.name.lower() != "utf-8":
+ raise UnsupportedError(f"Unsupported charset {charset}")
+
+
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@@ -115,6 +131,7 @@ class Presto(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"ROW": TokenType.STRUCT,
}
@@ -140,6 +157,14 @@ class Presto(Dialect):
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+ "FROM_HEX": exp.Unhex.from_arg_list,
+ "TO_HEX": exp.Hex.from_arg_list,
+ "TO_UTF8": lambda args: exp.Encode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
+ "FROM_UTF8": lambda args: exp.Decode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
}
class Generator(generator.Generator):
@@ -187,7 +212,10 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
+ exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
+ exp.Encode: _encode_sql,
+ exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
@@ -212,7 +240,13 @@ class Presto(Dialect):
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
}
+
+ def transaction_sql(self, expression):
+ modes = expression.args.get("modes")
+ modes = f" {', '.join(modes)}" if modes else ""
+ return f"START TRANSACTION{modes}"
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index d1aaded..a96bd80 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -148,6 +148,7 @@ class Snowflake(Dialect):
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
}
+ FUNCTION_PARSERS.pop("TRIM")
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
@@ -203,6 +204,7 @@ class Snowflake(Dialect):
exp.StrPosition: rename_func("POSITION"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+ exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
}
TYPE_MAPPING = {
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 8c9fb76..87b98a5 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -63,3 +63,8 @@ class SQLite(Dialect):
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
}
+
+ def transaction_sql(self, expression):
+ this = expression.this
+ this = f" {this}" if this else ""
+ return f"BEGIN{this} TRANSACTION"
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index a233d4b..d3b83de 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -248,7 +248,7 @@ class TSQL(Dialect):
def _parse_convert(self, strict):
to = self._parse_types()
self._match(TokenType.COMMA)
- this = self._parse_column()
+ this = self._parse_conjunction()
# Retrieve length of datatype and override to default if not specified
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: