summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--sqlglot/dialects/bigquery.py205
-rw-r--r--sqlglot/dialects/dialect.py5
-rw-r--r--sqlglot/dialects/mysql.py15
-rw-r--r--sqlglot/dialects/postgres.py10
-rw-r--r--sqlglot/dialects/presto.py13
-rw-r--r--sqlglot/dialects/redshift.py8
-rw-r--r--sqlglot/dialects/snowflake.py31
-rw-r--r--sqlglot/dialects/spark.py1
-rw-r--r--sqlglot/dialects/spark2.py4
-rw-r--r--sqlglot/dialects/sqlite.py4
-rw-r--r--sqlglot/dialects/tsql.py1
11 files changed, 269 insertions, 28 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 52d4a88..8786063 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import logging
import re
import typing as t
@@ -21,6 +22,8 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
+logger = logging.getLogger("sqlglot")
+
def _date_add_sql(
data_type: str, kind: str
@@ -104,12 +107,70 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
return expression
+# https://issuetracker.google.com/issues/162294746
+# workaround for bigquery bug when grouping by an expression and then ordering
+# WITH x AS (SELECT 1 y)
+# SELECT y + 1 z
+# FROM x
+# GROUP BY x + 1
+# ORDER by z
+def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.Select):
+ group = expression.args.get("group")
+ order = expression.args.get("order")
+
+ if group and order:
+ aliases = {
+ select.this: select.args["alias"]
+ for select in expression.selects
+ if isinstance(select, exp.Alias)
+ }
+
+ for e in group.expressions:
+ alias = aliases.get(e)
+
+ if alias:
+ e.replace(exp.column(alias))
+
+ return expression
+
+
+def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
+ """BigQuery doesn't allow column names when defining a CTE, so we try to push them down."""
+ if isinstance(expression, exp.CTE) and expression.alias_column_names:
+ cte_query = expression.this
+
+ if cte_query.is_star:
+ logger.warning(
+ "Can't push down CTE column names for star queries. Run the query through"
+ " the optimizer or use 'qualify' to expand the star projections first."
+ )
+ return expression
+
+ column_names = expression.alias_column_names
+ expression.args["alias"].set("columns", None)
+
+ for name, select in zip(column_names, cte_query.selects):
+ to_replace = select
+
+ if isinstance(select, exp.Alias):
+ select = select.this
+
+ # Inner aliases are shadowed by the CTE column names
+ to_replace.replace(exp.alias_(select, name))
+
+ return expression
+
+
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+ # bigquery udfs are case sensitive
+ NORMALIZE_FUNCTIONS = False
+
TIME_MAPPING = {
"%D": "%m/%d/%y",
}
@@ -135,12 +196,16 @@ class BigQuery(Dialect):
# In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
# The following check is essentially a heuristic to detect tables based on whether or
# not they're qualified.
- if (
- isinstance(expression, exp.Identifier)
- and not (isinstance(expression.parent, exp.Table) and expression.parent.db)
- and not expression.meta.get("is_table")
- ):
- expression.set("this", expression.this.lower())
+ if isinstance(expression, exp.Identifier):
+ parent = expression.parent
+
+ while isinstance(parent, exp.Dot):
+ parent = parent.parent
+
+ if not (isinstance(parent, exp.Table) and parent.db) and not expression.meta.get(
+ "is_table"
+ ):
+ expression.set("this", expression.this.lower())
return expression
@@ -298,10 +363,8 @@ class BigQuery(Dialect):
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
- exp.AtTimeZone: lambda self, e: self.func(
- "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
- ),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
+ exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -325,7 +388,12 @@ class BigQuery(Dialect):
),
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.Select: transforms.preprocess(
- [_unqualify_unnest, transforms.eliminate_distinct_on]
+ [
+ transforms.explode_to_unnest,
+ _unqualify_unnest,
+ transforms.eliminate_distinct_on,
+ _alias_ordered_group,
+ ]
),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@@ -334,7 +402,6 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
- exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -378,7 +445,121 @@ class BigQuery(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"}
+ # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
+ RESERVED_KEYWORDS = {
+ *generator.Generator.RESERVED_KEYWORDS,
+ "all",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "assert_rows_modified",
+ "at",
+ "between",
+ "by",
+ "case",
+ "cast",
+ "collate",
+ "contains",
+ "create",
+ "cross",
+ "cube",
+ "current",
+ "default",
+ "define",
+ "desc",
+ "distinct",
+ "else",
+ "end",
+ "enum",
+ "escape",
+ "except",
+ "exclude",
+ "exists",
+ "extract",
+ "false",
+ "fetch",
+ "following",
+ "for",
+ "from",
+ "full",
+ "group",
+ "grouping",
+ "groups",
+ "hash",
+ "having",
+ "if",
+ "ignore",
+ "in",
+ "inner",
+ "intersect",
+ "interval",
+ "into",
+ "is",
+ "join",
+ "lateral",
+ "left",
+ "like",
+ "limit",
+ "lookup",
+ "merge",
+ "natural",
+ "new",
+ "no",
+ "not",
+ "null",
+ "nulls",
+ "of",
+ "on",
+ "or",
+ "order",
+ "outer",
+ "over",
+ "partition",
+ "preceding",
+ "proto",
+ "qualify",
+ "range",
+ "recursive",
+ "respect",
+ "right",
+ "rollup",
+ "rows",
+ "select",
+ "set",
+ "some",
+ "struct",
+ "tablesample",
+ "then",
+ "to",
+ "treat",
+ "true",
+ "unbounded",
+ "union",
+ "unnest",
+ "using",
+ "when",
+ "where",
+ "window",
+ "with",
+ "within",
+ }
+
+ def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
+ if not isinstance(expression.parent, exp.Cast):
+ return self.func(
+ "TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone"))
+ )
+ return super().attimezone_sql(expression)
+
+ def trycast_sql(self, expression: exp.TryCast) -> str:
+ return self.cast_sql(expression, safe_prefix="SAFE_")
+
+ def cte_sql(self, expression: exp.CTE) -> str:
+ if expression.alias_column_names:
+ self.unsupported("Column names in CTE definition are not supported.")
+ return super().cte_sql(expression)
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 0e25b9b..d258826 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -388,6 +388,11 @@ def no_comment_column_constraint_sql(
return ""
+def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
+ self.unsupported("MAP_FROM_ENTRIES unsupported")
+ return ""
+
+
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 1dd2096..5f743ee 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -132,6 +132,10 @@ class MySQL(Dialect):
"SEPARATOR": TokenType.SEPARATOR,
"ENUM": TokenType.ENUM,
"START": TokenType.BEGIN,
+ "SIGNED": TokenType.BIGINT,
+ "SIGNED INTEGER": TokenType.BIGINT,
+ "UNSIGNED": TokenType.UBIGINT,
+ "UNSIGNED INTEGER": TokenType.UBIGINT,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@@ -441,6 +445,17 @@ class MySQL(Dialect):
LIMIT_FETCH = "LIMIT"
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+ """(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
+ if expression.to.this == exp.DataType.Type.BIGINT:
+ to = "SIGNED"
+ elif expression.to.this == exp.DataType.Type.UBIGINT:
+ to = "UNSIGNED"
+ else:
+ return super().cast_sql(expression)
+
+ return f"CAST({self.sql(expression, 'this')} AS {to})"
+
def show_sql(self, expression: exp.Show) -> str:
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 8c2a4ab..766b584 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
max_or_greatest,
min_or_least,
+ no_map_from_entries_sql,
no_paren_current_date_sql,
no_pivot_sql,
no_tablesample_sql,
@@ -346,6 +347,7 @@ class Postgres(Dialect):
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Max: max_or_greatest,
+ exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
@@ -378,3 +380,11 @@ class Postgres(Dialect):
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+
+ def bracket_sql(self, expression: exp.Bracket) -> str:
+ """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY."""
+ if isinstance(expression.this, exp.Array):
+ expression = expression.copy()
+ expression.set("this", exp.paren(expression.this, copy=False))
+
+ return super().bracket_sql(expression)
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 265780e..24c439b 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -20,7 +20,7 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
-from sqlglot.helper import seq_get
+from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
@@ -154,6 +154,13 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
+def _parse_element_at(args: t.List) -> exp.SafeBracket:
+ this = seq_get(args, 0)
+ index = seq_get(args, 1)
+ assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
+ return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1))
+
+
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Table):
if isinstance(expression.this, exp.GenerateSeries):
@@ -201,6 +208,7 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
+ "ELEMENT_AT": _parse_element_at,
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
@@ -285,6 +293,9 @@ class Presto(Dialect):
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
exp.Right: right_to_substring_sql,
+ exp.SafeBracket: lambda self, e: self.func(
+ "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0)
+ ),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index db6cc3f..87be42c 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -41,8 +41,6 @@ class Redshift(Postgres):
"STRTOL": exp.FromBase.from_arg_list,
}
- CONVERT_TYPE_FIRST = True
-
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
@@ -58,6 +56,12 @@ class Redshift(Postgres):
return this
+ def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+ to = self._parse_types()
+ self._match(TokenType.COMMA)
+ this = self._parse_bitwise()
+ return self.expression(exp.TryCast, this=this, to=to)
+
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 1f620df..a2dbfd9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -258,14 +258,29 @@ class Snowflake(Dialect):
ALTER_PARSERS = {
**parser.Parser.ALTER_PARSERS,
- "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
- "SET": lambda self: self._parse_alter_table_set_tag(),
+ "SET": lambda self: self._parse_set(tag=self._match_text_seq("TAG")),
+ "UNSET": lambda self: self.expression(
+ exp.Set,
+ tag=self._match_text_seq("TAG"),
+ expressions=self._parse_csv(self._parse_id_var),
+ unset=True,
+ ),
}
- def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression:
- self._match_text_seq("TAG")
- parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction)
- return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset)
+ def _parse_id_var(
+ self,
+ any_token: bool = True,
+ tokens: t.Optional[t.Collection[TokenType]] = None,
+ ) -> t.Optional[exp.Expression]:
+ if self._match_text_seq("IDENTIFIER", "("):
+ identifier = (
+ super()._parse_id_var(any_token=any_token, tokens=tokens)
+ or self._parse_string()
+ )
+ self._match_r_paren()
+ return self.expression(exp.Anonymous, this="IDENTIFIER", expressions=[identifier])
+
+ return super()._parse_id_var(any_token=any_token, tokens=tokens)
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
@@ -380,10 +395,6 @@ class Snowflake(Dialect):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
- def settag_sql(self, expression: exp.SetTag) -> str:
- action = "UNSET" if expression.args.get("unset") else "SET"
- return f"{action} TAG {self.expressions(expression)}"
-
def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index b7d1641..7a7ee01 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -43,6 +43,7 @@ class Spark(Spark2):
class Generator(Spark2.Generator):
TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
TRANSFORMS.pop(exp.DateDiff)
+ TRANSFORMS.pop(exp.Group)
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = self.sql(expression, "unit")
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index ed6992d..3720b8d 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -231,14 +231,14 @@ class Spark2(Hive):
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
- def cast_sql(self, expression: exp.Cast) -> str:
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"):
schema = f"'{self.sql(expression, 'to')}'"
return self.func("FROM_JSON", expression.this.this, schema)
if expression.is_type("json"):
return self.func("TO_JSON", expression.this)
- return super(Hive.Generator, self).cast_sql(expression)
+ return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix)
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
return super().columndef_sql(
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 803f361..519e62a 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -133,7 +135,7 @@ class SQLite(Dialect):
LIMIT_FETCH = "LIMIT"
- def cast_sql(self, expression: exp.Cast) -> str:
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if expression.is_type("date"):
return self.func("DATE", expression.this)
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 6d674f5..f671630 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -166,6 +166,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
class TSQL(Dialect):
+ RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
NULL_ORDERING = "nulls_are_small"
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"