author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-05-03 09:12:28 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-05-03 09:12:28 +0000
commit     67c28dbe67209effad83d93b850caba5ee1e20e3 (patch)
tree       dffdfbfb4f0899c92a4c978e6eac55af2ff76367 /sqlglot
parent     Releasing debian version 11.5.2-1. (diff)
Merging upstream version 11.7.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
 sqlglot/__init__.py                  |   5
 sqlglot/dataframe/sql/column.py      |   2
 sqlglot/dataframe/sql/dataframe.py   | 151
 sqlglot/dataframe/sql/functions.py   |  15
 sqlglot/dataframe/sql/readwriter.py  |  12
 sqlglot/dialects/bigquery.py         |  61
 sqlglot/dialects/clickhouse.py       |   7
 sqlglot/dialects/databricks.py       |   8
 sqlglot/dialects/dialect.py          |  37
 sqlglot/dialects/drill.py            |   4
 sqlglot/dialects/duckdb.py           |  11
 sqlglot/dialects/hive.py             |  64
 sqlglot/dialects/mysql.py            |  63
 sqlglot/dialects/oracle.py           |  32
 sqlglot/dialects/postgres.py         |  36
 sqlglot/dialects/presto.py           |  85
 sqlglot/dialects/redshift.py         |  17
 sqlglot/dialects/snowflake.py        |  59
 sqlglot/dialects/spark.py            |  54
 sqlglot/dialects/sqlite.py           |  10
 sqlglot/dialects/starrocks.py        |   8
 sqlglot/dialects/tableau.py          |   8
 sqlglot/dialects/teradata.py         |  21
 sqlglot/dialects/tsql.py             |  37
 sqlglot/expressions.py               | 286
 sqlglot/generator.py                 | 205
 sqlglot/helper.py                    |  34
 sqlglot/lineage.py                   |  53
 sqlglot/optimizer/annotate_types.py  |   7
 sqlglot/optimizer/normalize.py       |   4
 sqlglot/optimizer/qualify_columns.py |   5
 sqlglot/optimizer/qualify_tables.py  |   3
 sqlglot/optimizer/simplify.py        |  35
 sqlglot/parser.py                    | 382
 sqlglot/schema.py                    | 156
 sqlglot/tokens.py                    | 163
 sqlglot/transforms.py                | 135
 sqlglot/trie.py                      |   5
 38 files changed, 1650 insertions(+), 630 deletions(-)
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 1feb464..42d89d1 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -21,10 +21,12 @@ from sqlglot.expressions import (
Expression as Expression,
alias_ as alias,
and_ as and_,
+ cast as cast,
column as column,
condition as condition,
except_ as except_,
from_ as from_,
+ func as func,
intersect as intersect,
maybe_parse as maybe_parse,
not_ as not_,
@@ -33,6 +35,7 @@ from sqlglot.expressions import (
subquery as subquery,
table_ as table,
to_column as to_column,
+ to_identifier as to_identifier,
to_table as to_table,
union as union,
)
@@ -47,7 +50,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
-__version__ = "11.5.2"
+__version__ = "11.7.1"
pretty = False
"""Whether to format generated SQL by default."""
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index 609b2a4..a8b89d1 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -176,7 +176,7 @@ class Column:
return isinstance(self.expression, exp.Column)
@property
- def column_expression(self) -> exp.Column:
+ def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
return self.expression.unalias()
@property
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 93bdf75..f3a6f6f 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -16,7 +16,7 @@ from sqlglot.dataframe.sql.readwriter import DataFrameWriter
from sqlglot.dataframe.sql.transforms import replace_id_value
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
-from sqlglot.helper import ensure_list, object_to_dict
+from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer import optimize as optimize_func
if t.TYPE_CHECKING:
@@ -146,9 +146,9 @@ class DataFrame:
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
- def _ensure_and_normalize_cols(self, cols):
+ def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
cols = self._ensure_list_of_columns(cols)
- normalize(self.spark, self.expression, cols)
+ normalize(self.spark, expression or self.expression, cols)
return cols
def _ensure_and_normalize_col(self, col):
@@ -355,12 +355,20 @@ class DataFrame:
cols = self._ensure_and_normalize_cols(cols)
kwargs["append"] = kwargs.get("append", False)
if self.expression.args.get("joins"):
- ambiguous_cols = [col for col in cols if not col.column_expression.table]
+ ambiguous_cols = [
+ col
+ for col in cols
+ if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
+ ]
if ambiguous_cols:
join_table_identifiers = [
x.this for x in get_tables_from_expression_with_join(self.expression)
]
cte_names_in_join = [x.this for x in join_table_identifiers]
+ # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
+ # and therefore we allow multiple columns with the same name in the result. This matches the behavior
+ # of Spark.
+ resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
for ambiguous_col in ambiguous_cols:
ctes_with_column = [
cte
@@ -368,13 +376,14 @@ class DataFrame:
if cte.alias_or_name in cte_names_in_join
and ambiguous_col.alias_or_name in cte.this.named_selects
]
- # If the select column does not specify a table and there is a join
- # then we assume they are referring to the left table
- if len(ctes_with_column) > 1:
- table_identifier = self.expression.args["from"].args["expressions"][0].this
+ # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
+ # use the same CTE we used before
+ cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
+ if cte:
+ resolved_column_position[ambiguous_col] += 1
else:
- table_identifier = ctes_with_column[0].args["alias"].this
- ambiguous_col.expression.set("table", table_identifier)
+ cte = ctes_with_column[resolved_column_position[ambiguous_col]]
+ ambiguous_col.expression.set("table", cte.alias_or_name)
return self.copy(
expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
)
@@ -416,59 +425,87 @@ class DataFrame:
**kwargs,
) -> DataFrame:
other_df = other_df._convert_leaf_to_cte()
- pre_join_self_latest_cte_name = self.latest_cte_name
- columns = self._ensure_and_normalize_cols(on)
- join_type = how.replace("_", " ")
- if isinstance(columns[0].expression, exp.Column):
- join_columns = [
- Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+ join_columns = self._ensure_list_of_columns(on)
+ # We will determine the actual "join on" expression later, so we don't provide it at first
+ join_expression = self.expression.join(
+ other_df.latest_cte_name, join_type=how.replace("_", " ")
+ )
+ join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
+ self_columns = self._get_outer_select_columns(join_expression)
+ other_columns = self._get_outer_select_columns(other_df)
+ # Determines the join clause and select columns to be used based on what type of columns were provided for
+ # the join. The columns returned change based on how the on expression is provided.
+ if isinstance(join_columns[0].expression, exp.Column):
+ """
+ Unique characteristics of join on column names only:
+ * The column names are put at the front of the select list
+ * The join column names are deduplicated across the entire select list, but only those names (other duplicates are allowed)
+ """
+ table_names = [
+ table.alias_or_name
+ for table in get_tables_from_expression_with_join(join_expression)
]
+ potential_ctes = [
+ cte
+ for cte in join_expression.ctes
+ if cte.alias_or_name in table_names
+ and cte.alias_or_name != other_df.latest_cte_name
+ ]
+ # Determine the table to reference for the left side of the join by checking each of the left side
+ # tables to see if it has the column being referenced.
+ join_column_pairs = []
+ for join_column in join_columns:
+ num_matching_ctes = 0
+ for cte in potential_ctes:
+ if join_column.alias_or_name in cte.this.named_selects:
+ left_column = join_column.copy().set_table_name(cte.alias_or_name)
+ right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
+ join_column_pairs.append((left_column, right_column))
+ num_matching_ctes += 1
+ if num_matching_ctes > 1:
+ raise ValueError(
+ f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
+ )
+ elif num_matching_ctes == 0:
+ raise ValueError(
+ f"Column {join_column.alias_or_name} does not exist in any of the tables."
+ )
join_clause = functools.reduce(
lambda x, y: x & y,
- [
- col.copy().set_table_name(pre_join_self_latest_cte_name)
- == col.copy().set_table_name(other_df.latest_cte_name)
- for col in columns
- ],
+ [left_column == right_column for left_column, right_column in join_column_pairs],
)
- else:
- if len(columns) > 1:
- columns = [functools.reduce(lambda x, y: x & y, columns)]
- join_clause = columns[0]
- join_columns = [
- Column(x).set_table_name(pre_join_self_latest_cte_name)
- if i % 2 == 0
- else Column(x).set_table_name(other_df.latest_cte_name)
- for i, x in enumerate(join_clause.expression.find_all(exp.Column))
+ join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
+ # To match Spark behavior, only the join clause gets deduplicated, and it is put at the front of the column list
+ select_column_names = [
+ column.alias_or_name
+ if not isinstance(column.expression.this, exp.Star)
+ else column.sql()
+ for column in self_columns + other_columns
]
- self_columns = [
- column.set_table_name(pre_join_self_latest_cte_name, copy=True)
- for column in self._get_outer_select_columns(self)
- ]
- other_columns = [
- column.set_table_name(other_df.latest_cte_name, copy=True)
- for column in self._get_outer_select_columns(other_df)
- ]
- column_value_mapping = {
- column.alias_or_name
- if not isinstance(column.expression.this, exp.Star)
- else column.sql(): column
- for column in other_columns + self_columns + join_columns
- }
- all_columns = [
- column_value_mapping[name]
- for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
- ]
- new_df = self.copy(
- expression=self.expression.join(
- other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
- )
- )
- new_df.expression = new_df._add_ctes_to_expression(
- new_df.expression, other_df.expression.ctes
- )
+ select_column_names = [
+ column_name
+ for column_name in select_column_names
+ if column_name not in join_column_names
+ ]
+ select_column_names = join_column_names + select_column_names
+ else:
+ """
+ Unique characteristics of join on expressions:
+ * There is no deduplication of the results.
+ * The left dataframe's columns go first and the right's come after. No sort preference is given to join columns
+ """
+ join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
+ if len(join_columns) > 1:
+ join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
+ join_clause = join_columns[0]
+ select_column_names = [column.alias_or_name for column in self_columns + other_columns]
+
+ # Update the on expression with the actual join clause to replace the dummy one from before
+ join_expression.args["joins"][-1].set("on", join_clause.expression)
+ new_df = self.copy(expression=join_expression)
+ new_df.pending_join_hints.extend(self.pending_join_hints)
new_df.pending_hints.extend(other_df.pending_hints)
- new_df = new_df.select.__wrapped__(new_df, *all_columns)
+ new_df = new_df.select.__wrapped__(new_df, *select_column_names)
return new_df
@operation(Operation.ORDER_BY)
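
A sketch of the join-on-column-names path implemented above (the table names and schema here are hypothetical):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("t1", {"id": "INT", "a": "STRING"})
    sqlglot.schema.add_table("t2", {"id": "INT", "b": "STRING"})

    spark = SparkSession()
    df = spark.read.table("t1").join(spark.read.table("t2"), on=["id"], how="inner")
    # "id" resolves against t1 on the left and t2 on the right, is deduplicated,
    # and lands at the front of the generated select list
    print(df.sql()[0])
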
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index f77b4f8..993d869 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -577,11 +577,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days)
+ return Column.invoke_expression_over_column(
+ col, expression.DateAdd, expression=days, unit=expression.Var(this="day")
+ )
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, expression.DateSub, expression=days)
+ return Column.invoke_expression_over_column(
+ col, expression.DateSub, expression=days, unit=expression.Var(this="day")
+ )
def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
@@ -695,18 +699,17 @@ def crc32(col: ColumnOrName) -> Column:
def md5(col: ColumnOrName) -> Column:
column = col if isinstance(col, Column) else lit(col)
- return Column.invoke_anonymous_function(column, "MD5")
+ return Column.invoke_expression_over_column(column, expression.MD5)
def sha1(col: ColumnOrName) -> Column:
column = col if isinstance(col, Column) else lit(col)
- return Column.invoke_anonymous_function(column, "SHA1")
+ return Column.invoke_expression_over_column(column, expression.SHA)
def sha2(col: ColumnOrName, numBits: int) -> Column:
column = col if isinstance(col, Column) else lit(col)
- num_bits = lit(numBits)
- return Column.invoke_anonymous_function(column, "SHA2", num_bits)
+ return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits))
def hash(*cols: ColumnOrName) -> Column:
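
With the unit now attached explicitly, date_add/date_sub build nodes equivalent to this sketch:

    from sqlglot import exp

    # What date_add(col, 7) now constructs under the hood
    node = exp.DateAdd(
        this=exp.column("event_date"),
        expression=exp.Literal.number(7),
        unit=exp.Var(this="day"),
    )
    print(node.sql("spark"))  # DATE_ADD(event_date, 7)
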
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index febc664..cc2f181 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -4,7 +4,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot.helper import object_to_dict
+from sqlglot.helper import object_to_dict, should_identify
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
@@ -19,9 +19,17 @@ class DataFrameReader:
from sqlglot.dataframe.sql.dataframe import DataFrame
sqlglot.schema.add_table(tableName)
+
return DataFrame(
self.spark,
- exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+ exp.Select()
+ .from_(tableName)
+ .select(
+ *(
+ column if should_identify(column, "safe") else f'"{column}"'
+ for column in sqlglot.schema.column_names(tableName)
+ )
+ ),
)
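
should_identify (imported above) decides whether a raw column name is safe to leave bare; a hypothetical probe using the "safe" mode from the call site, with its exact behavior assumed:

    from sqlglot.helper import should_identify

    for name in ("a", "Mixed Case"):
        # per the call site above, names that fail this check get double-quoted
        print(name, should_identify(name, "safe"))
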
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 701377b..1a88654 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
no_ilike_sql,
+ parse_date_delta_with_interval,
rename_func,
timestrtotime_sql,
ts_or_ds_to_date_sql,
@@ -23,18 +24,6 @@ from sqlglot.tokens import TokenType
E = t.TypeVar("E", bound=exp.Expression)
-def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
- def func(args):
- interval = seq_get(args, 1)
- return expression_class(
- this=seq_get(args, 0),
- expression=interval.this,
- unit=interval.args.get("unit"),
- )
-
- return func
-
-
def _date_add_sql(
data_type: str, kind: str
) -> t.Callable[[generator.Generator, exp.Expression], str]:
@@ -142,6 +131,7 @@ class BigQuery(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "ANY TYPE": TokenType.VARIANT,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
@@ -155,14 +145,19 @@ class BigQuery(Dialect):
KEYWORDS.pop("DIV")
class Parser(parser.Parser):
+ PREFIXED_PIVOT_COLUMNS = True
+
+ LOG_BASE_FIRST = False
+ LOG_DEFAULTS_TO_LN = True
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
this=seq_get(args, 0),
),
- "DATE_ADD": _date_add(exp.DateAdd),
- "DATETIME_ADD": _date_add(exp.DatetimeAdd),
+ "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
+ "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
@@ -174,12 +169,12 @@ class BigQuery(Dialect):
if re.compile(str(seq_get(args, 1))).groups == 1
else None,
),
- "TIME_ADD": _date_add(exp.TimeAdd),
- "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
- "DATE_SUB": _date_add(exp.DateSub),
- "DATETIME_SUB": _date_add(exp.DatetimeSub),
- "TIME_SUB": _date_add(exp.TimeSub),
- "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
+ "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
+ "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
+ "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+ "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
+ "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
+ "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(
this=seq_get(args, 1), format=seq_get(args, 0)
),
@@ -209,14 +204,17 @@ class BigQuery(Dialect):
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, # type: ignore
"NOT DETERMINISTIC": lambda self: self.expression(
- exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
+ exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
}
- LOG_BASE_FIRST = False
- LOG_DEFAULTS_TO_LN = True
-
class Generator(generator.Generator):
+ EXPLICIT_UNION = True
+ INTERVAL_ALLOWS_PLURAL_FORM = False
+ JOIN_HINTS = False
+ TABLE_HINTS = False
+ LIMIT_FETCH = "LIMIT"
+
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
@@ -236,9 +234,7 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
- exp.Select: transforms.preprocess(
- [_unqualify_unnest], transforms.delegate("select_sql")
- ),
+ exp.Select: transforms.preprocess([_unqualify_unnest]),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -253,7 +249,7 @@ class BigQuery(Dialect):
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
- exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
+ exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
@@ -261,6 +257,7 @@ class BigQuery(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
+ exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.CHAR: "STRING",
@@ -272,17 +269,19 @@ class BigQuery(Dialect):
exp.DataType.Type.NVARCHAR: "STRING",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.TIMESTAMP: "DATETIME",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARCHAR: "STRING",
+ exp.DataType.Type.VARIANT: "ANY TYPE",
}
+
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- EXPLICIT_UNION = True
- LIMIT_FETCH = "LIMIT"
-
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
if isinstance(first_arg, exp.Subqueryable):
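
With the new type mapping, a plain TIMESTAMP renders as BigQuery's DATETIME (and TIMESTAMPTZ as TIMESTAMP). Sketch:

    import sqlglot

    print(sqlglot.transpile("CAST(x AS TIMESTAMP)", write="bigquery")[0])
    # expected: CAST(x AS DATETIME)
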
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index b06462c..e91b0bf 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -144,6 +144,13 @@ class ClickHouse(Dialect):
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ }
+
+ JOIN_HINTS = False
+ TABLE_HINTS = False
EXPLICIT_UNION = True
def _param_args_sql(
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 2f93ee7..138f26c 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -9,6 +9,8 @@ from sqlglot.tokens import TokenType
class Databricks(Spark):
class Parser(Spark.Parser):
+ LOG_DEFAULTS_TO_LN = True
+
FUNCTIONS = {
**Spark.Parser.FUNCTIONS,
"DATEADD": parse_date_delta(exp.DateAdd),
@@ -16,13 +18,17 @@ class Databricks(Spark):
"DATEDIFF": parse_date_delta(exp.DateDiff),
}
- LOG_DEFAULTS_TO_LN = True
+ FACTOR = {
+ **Spark.Parser.FACTOR,
+ TokenType.COLON: exp.JSONExtract,
+ }
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
+ exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation
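
The colon operator should now parse and round-trip as JSON extraction; a sketch (the exact rendered spacing of the operator is assumed):

    import sqlglot

    sql = "SELECT c:price FROM t"
    print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
    # expected roughly: SELECT c : price FROM t
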
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 839589d..19c6f73 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -293,6 +293,13 @@ def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
return ""
+def no_comment_column_constraint_sql(
+ self: Generator, expression: exp.CommentColumnConstraint
+) -> str:
+ self.unsupported("CommentColumnConstraint unsupported")
+ return ""
+
+
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
@@ -379,15 +386,35 @@ def parse_date_delta(
) -> t.Callable[[t.Sequence], E]:
def inner_func(args: t.Sequence) -> E:
unit_based = len(args) == 3
- this = seq_get(args, 2) if unit_based else seq_get(args, 0)
- expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
- unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
- unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore
- return exp_class(this=this, expression=expression, unit=unit)
+ this = args[2] if unit_based else seq_get(args, 0)
+ unit = args[0] if unit_based else exp.Literal.string("DAY")
+ unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+ return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
return inner_func
+def parse_date_delta_with_interval(
+ expression_class: t.Type[E],
+) -> t.Callable[[t.Sequence], t.Optional[E]]:
+ def func(args: t.Sequence) -> t.Optional[E]:
+ if len(args) < 2:
+ return None
+
+ interval = args[1]
+ expression = interval.this
+ if expression and expression.is_string:
+ expression = exp.Literal.number(expression.this)
+
+ return expression_class(
+ this=args[0],
+ expression=expression,
+ unit=exp.Literal.string(interval.text("unit")),
+ )
+
+ return func
+
+
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)
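
parse_date_delta_with_interval consumes an INTERVAL-style second argument, e.g. in BigQuery (sketch):

    import sqlglot

    expr = sqlglot.parse_one("DATE_ADD(d, INTERVAL 1 DAY)", read="bigquery")
    print(type(expr).__name__)  # DateAdd
    print(expr.text("unit"))    # DAY
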
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index a33aadc..d7e2d88 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -104,6 +104,9 @@ class Drill(Dialect):
LOG_DEFAULTS_TO_LN = True
class Generator(generator.Generator):
+ JOIN_HINTS = False
+ TABLE_HINTS = False
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER",
@@ -120,6 +123,7 @@ class Drill(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TRANSFORMS = {
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index c034208..9454db6 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
datestrtodate_sql,
format_time_lambda,
+ no_comment_column_constraint_sql,
no_pivot_sql,
no_properties_sql,
no_safe_divide_sql,
@@ -23,7 +24,7 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add(self, expression):
- this = expression.args.get("this")
+ this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
@@ -139,6 +140,8 @@ class DuckDB(Dialect):
}
class Generator(generator.Generator):
+ JOIN_HINTS = False
+ TABLE_HINTS = False
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
@@ -150,6 +153,7 @@ class DuckDB(Dialect):
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
+ exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -213,6 +217,11 @@ class DuckDB(Dialect):
"except": "EXCLUDE",
}
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ }
+
LIMIT_FETCH = "LIMIT"
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
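
With no_comment_column_constraint_sql wired in, COMMENT column constraints are dropped (with an "unsupported" warning) when generating DuckDB DDL. Sketch:

    import sqlglot

    sql = "CREATE TABLE t (x INT COMMENT 'note')"
    print(sqlglot.transpile(sql, read="mysql", write="duckdb")[0])
    # expected: CREATE TABLE t (x INT)
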
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index c39656e..6746fcf 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -45,16 +45,23 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
-def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
+def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
- modified_increment = (
- int(expression.text("expression")) * multiplier
- if expression.expression.is_number
- else expression.expression
- )
- modified_increment = exp.Literal.number(modified_increment)
- return self.func(func, expression.this, modified_increment.this)
+
+ if isinstance(expression, exp.DateSub):
+ multiplier *= -1
+
+ if expression.expression.is_number:
+ modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
+ else:
+ modified_increment = expression.expression
+ if multiplier != 1:
+ modified_increment = exp.Mul( # type: ignore
+ this=modified_increment, expression=exp.Literal.number(multiplier)
+ )
+
+ return self.func(func, expression.this, modified_increment)
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
@@ -127,24 +134,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
return f"TO_DATE({this})"
-def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
- unnest = expression.this
- if isinstance(unnest, exp.Unnest):
- alias = unnest.args.get("alias")
- udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
- return "".join(
- self.sql(
- exp.Lateral(
- this=udtf(this=expression),
- view=True,
- alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
- )
- )
- for expression, column in zip(unnest.expressions, alias.columns if alias else [])
- )
- return self.join_sql(expression)
-
-
def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
@@ -195,6 +184,7 @@ class Hive(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
+ IDENTIFIER_CAN_START_WITH_DIGIT = True
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -217,9 +207,8 @@ class Hive(Dialect):
"BD": "DECIMAL",
}
- IDENTIFIER_CAN_START_WITH_DIGIT = True
-
class Parser(parser.Parser):
+ LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
FUNCTIONS = {
@@ -273,9 +262,13 @@ class Hive(Dialect):
),
}
- LOG_DEFAULTS_TO_LN = True
-
class Generator(generator.Generator):
+ LIMIT_FETCH = "LIMIT"
+ TABLESAMPLE_WITH_METHOD = False
+ TABLESAMPLE_SIZE_IS_PERCENT = True
+ JOIN_HINTS = False
+ TABLE_HINTS = False
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
@@ -289,6 +282,9 @@ class Hive(Dialect):
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
+ exp.Select: transforms.preprocess(
+ [transforms.eliminate_qualify, transforms.unnest_to_explode]
+ ),
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
@@ -298,13 +294,13 @@ class Hive(Dialect):
exp.DateAdd: _add_date_sql,
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: rename_func("TO_DATE"),
+ exp.DateSub: _add_date_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
- exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
+ exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
- exp.Join: _unnest_to_explode_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: rename_func("TO_JSON"),
@@ -354,10 +350,9 @@ class Hive(Dialect):
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- LIMIT_FETCH = "LIMIT"
-
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",
@@ -378,4 +373,5 @@ class Hive(Dialect):
expression = exp.DataType.build("text")
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
+
return super().datatype_sql(expression)
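
Since DateSub is now routed through _add_date_sql with a negated multiplier, a day-based subtraction renders as a negative DATE_ADD. Sketch:

    from sqlglot import exp

    node = exp.DateSub(
        this=exp.column("d"),
        expression=exp.Literal.number(3),
        unit=exp.Var(this="DAY"),
    )
    print(node.sql("hive"))  # DATE_ADD(d, -3)
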
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index d64efbf..666e740 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -4,6 +4,8 @@ from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
+ datestrtodate_sql,
+ format_time_lambda,
locate_to_strposition,
max_or_greatest,
min_or_least,
@@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
+ parse_date_delta_with_interval,
rename_func,
strposition_to_locate_sql,
)
@@ -76,18 +79,6 @@ def _trim_sql(self, expression):
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add(expression_class):
- def func(args):
- interval = seq_get(args, 1)
- return expression_class(
- this=seq_get(args, 0),
- expression=interval.this,
- unit=exp.Literal.string(interval.text("unit").lower()),
- )
-
- return func
-
-
def _date_add_sql(kind):
def func(self, expression):
this = self.sql(expression, "this")
@@ -115,6 +106,7 @@ class MySQL(Dialect):
"%k": "%-H",
"%l": "%-I",
"%T": "%H:%M:%S",
+ "%W": "%a",
}
class Tokenizer(tokens.Tokenizer):
@@ -127,12 +119,13 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
- "MEDIUMTEXT": TokenType.MEDIUMTEXT,
+ "CHARSET": TokenType.CHARACTER_SET,
+ "LONGBLOB": TokenType.LONGBLOB,
"LONGTEXT": TokenType.LONGTEXT,
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
- "LONGBLOB": TokenType.LONGBLOB,
- "START": TokenType.BEGIN,
+ "MEDIUMTEXT": TokenType.MEDIUMTEXT,
"SEPARATOR": TokenType.SEPARATOR,
+ "START": TokenType.BEGIN,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@@ -186,14 +179,15 @@ class MySQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
- "DATE_ADD": _date_add(exp.DateAdd),
- "DATE_SUB": _date_add(exp.DateSub),
- "STR_TO_DATE": _str_to_date,
- "LOCATE": locate_to_strposition,
+ "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
+ "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
+ "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
),
+ "LOCATE": locate_to_strposition,
+ "STR_TO_DATE": _str_to_date,
}
FUNCTION_PARSERS = {
@@ -388,32 +382,36 @@ class MySQL(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
+ JOIN_HINTS = False
+ TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentDate: no_paren_current_date_sql,
- exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
- exp.ILike: no_ilike_sql,
- exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
- exp.Max: max_or_greatest,
- exp.Min: min_or_least,
- exp.TableSample: no_tablesample_sql,
- exp.TryCast: no_trycast_sql,
+ exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.DateAdd: _date_add_sql("ADD"),
- exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
+ exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
- exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
+ exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
- exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
- exp.StrToDate: _str_to_date_sql,
- exp.StrToTime: _str_to_date_sql,
- exp.Trim: _trim_sql,
+ exp.ILike: no_ilike_sql,
+ exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.Max: max_or_greatest,
+ exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.StrPosition: strposition_to_locate_sql,
+ exp.StrToDate: _str_to_date_sql,
+ exp.StrToTime: _str_to_date_sql,
+ exp.TableSample: no_tablesample_sql,
+ exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
+ exp.Trim: _trim_sql,
+ exp.TryCast: no_trycast_sql,
+ exp.WeekOfYear: rename_func("WEEKOFYEAR"),
}
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
@@ -425,6 +423,7 @@ class MySQL(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "LIMIT"
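
DATE_FORMAT now parses into exp.TimeToStr and has a matching generator transform, so it round-trips through MySQL. Sketch:

    import sqlglot

    sql = "DATE_FORMAT(d, '%Y-%m-%d')"
    print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
    # expected round trip: DATE_FORMAT(d, '%Y-%m-%d')
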
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 3819b76..9ccd02e 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -7,11 +7,6 @@ from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sq
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
- TokenType.COLUMN,
- TokenType.RETURNING,
-}
-
def _parse_xml_table(self) -> exp.XMLTable:
this = self._parse_string()
@@ -22,9 +17,7 @@ def _parse_xml_table(self) -> exp.XMLTable:
if self._match_text_seq("PASSING"):
# The BY VALUE keywords are optional and are provided for semantic clarity
self._match_text_seq("BY", "VALUE")
- passing = self._parse_csv(
- lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS)
- )
+ passing = self._parse_csv(self._parse_column)
by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
@@ -68,6 +61,8 @@ class Oracle(Dialect):
}
class Parser(parser.Parser):
+ WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
@@ -78,6 +73,12 @@ class Oracle(Dialect):
"XMLTABLE": _parse_xml_table,
}
+ TYPE_LITERAL_PARSERS = {
+ exp.DataType.Type.DATE: lambda self, this, _: self.expression(
+ exp.DateStrToDate, this=this
+ )
+ }
+
def _parse_column(self) -> t.Optional[exp.Expression]:
column = super()._parse_column()
if column:
@@ -100,6 +101,8 @@ class Oracle(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
+ JOIN_HINTS = False
+ TABLE_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -119,6 +122,9 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
+ exp.DateStrToDate: lambda self, e: self.func(
+ "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
+ ),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -129,6 +135,12 @@ class Oracle(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+ exp.IfNull: rename_func("NVL"),
+ }
+
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
LIMIT_FETCH = "FETCH"
@@ -142,9 +154,9 @@ class Oracle(Dialect):
def xmltable_sql(self, expression: exp.XMLTable) -> str:
this = self.sql(expression, "this")
- passing = self.expressions(expression, "passing")
+ passing = self.expressions(expression, key="passing")
passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else ""
- columns = self.expressions(expression, "columns")
+ columns = self.expressions(expression, key="columns")
columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else ""
by_ref = (
f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else ""
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 31b7e45..c47ff51 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
+ datestrtodate_sql,
format_time_lambda,
max_or_greatest,
min_or_least,
@@ -19,7 +20,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
-from sqlglot.transforms import delegate, preprocess
+from sqlglot.transforms import preprocess, remove_target_from_merge
DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
@@ -239,7 +240,6 @@ class Postgres(Dialect):
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
- "UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
}
@@ -248,18 +248,25 @@ class Postgres(Dialect):
"$": TokenType.PARAMETER,
}
+ VAR_SINGLE_TOKENS = {"$"}
+
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
- "NOW": exp.CurrentTimestamp.from_arg_list,
- "TO_TIMESTAMP": _to_timestamp,
- "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
- "GENERATE_SERIES": _generate_series,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
+ "GENERATE_SERIES": _generate_series,
+ "NOW": exp.CurrentTimestamp.from_arg_list,
+ "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
+ "TO_TIMESTAMP": _to_timestamp,
+ }
+
+ FUNCTION_PARSERS = {
+ **parser.Parser.FUNCTION_PARSERS,
+ "DATE_PART": lambda self: self._parse_date_part(),
}
BITWISE = {
@@ -279,8 +286,21 @@ class Postgres(Dialect):
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
+ def _parse_date_part(self) -> exp.Expression:
+ part = self._parse_type()
+ self._match(TokenType.COMMA)
+ value = self._parse_bitwise()
+
+ if part and part.is_string:
+ part = exp.Var(this=part.name)
+
+ return self.expression(exp.Extract, this=part, expression=value)
+
class Generator(generator.Generator):
+ INTERVAL_ALLOWS_PLURAL_FORM = False
LOCKING_READS_SUPPORTED = True
+ JOIN_HINTS = False
+ TABLE_HINTS = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
@@ -301,7 +321,6 @@ class Postgres(Dialect):
_auto_increment_to_serial,
_serial_to_generated,
],
- delegate("columndef_sql"),
),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@@ -312,6 +331,7 @@ class Postgres(Dialect):
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
+ exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("-"),
exp.DateDiff: _date_diff_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
@@ -321,6 +341,7 @@ class Postgres(Dialect):
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
+ exp.Merge: preprocess([remove_target_from_merge]),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
@@ -344,4 +365,5 @@ class Postgres(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
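
The new DATE_PART function parser normalizes string parts into a Var, so the call renders as EXTRACT. Sketch:

    import sqlglot

    sql = "SELECT DATE_PART('year', d)"
    print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
    # expected: SELECT EXTRACT(year FROM d)
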
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 07e8f43..489d439 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -19,20 +21,20 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _approx_distinct_sql(self, expression):
+def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
accuracy = expression.args.get("accuracy")
accuracy = ", " + self.sql(accuracy) if accuracy else ""
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
sql = f"{sql} WITH TIME ZONE"
return sql
-def _explode_to_unnest_sql(self, expression):
+def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql(
exp.Join(
@@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression):
return self.lateral_sql(expression)
-def _initcap_sql(self, expression):
+def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
-def _decode_sql(self, expression):
- _ensure_utf8(expression.args.get("charset"))
+def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
+ _ensure_utf8(expression.args["charset"])
return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
-def _encode_sql(self, expression):
- _ensure_utf8(expression.args.get("charset"))
+def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
+ _ensure_utf8(expression.args["charset"])
return f"TO_UTF8({self.sql(expression, 'this')})"
-def _no_sort_array(self, expression):
+def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
@@ -70,49 +72,62 @@ def _no_sort_array(self, expression):
return self.func("ARRAY_SORT", expression.this, comparator)
-def _schema_sql(self, expression):
+def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
- for schema in expression.parent.find_all(exp.Schema):
- if isinstance(schema.parent, exp.Property):
- expression = expression.copy()
- expression.expressions.extend(schema.expressions)
+ if expression.parent:
+ for schema in expression.parent.find_all(exp.Schema):
+ if isinstance(schema.parent, exp.Property):
+ expression = expression.copy()
+ expression.expressions.extend(schema.expressions)
return self.schema_sql(expression)
-def _quantile_sql(self, expression):
+def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
self.unsupported("Presto does not support exact quantiles")
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
-def _str_to_time_sql(self, expression):
+def _str_to_time_sql(
+ self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
+) -> str:
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
-def _ts_or_ds_to_date_sql(self, expression):
+def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
-def _ts_or_ds_add_sql(self, expression):
+def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
+ this = expression.this
+
+ if not isinstance(this, exp.CurrentDate):
+ this = self.func(
+ "DATE_PARSE",
+ self.func(
+ "SUBSTR",
+ this if this.is_string else exp.cast(this, "VARCHAR"),
+ exp.Literal.number(1),
+ exp.Literal.number(10),
+ ),
+ Presto.date_format,
+ )
+
return self.func(
"DATE_ADD",
exp.Literal.string(expression.text("unit") or "day"),
expression.expression,
- self.func(
- "DATE_PARSE",
- self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
- Presto.date_format,
- ),
+ this,
)
-def _sequence_sql(self, expression):
+def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
@@ -135,12 +150,12 @@ def _sequence_sql(self, expression):
return self.func("SEQUENCE", start, end, step)
-def _ensure_utf8(charset):
+def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
-def _approx_percentile(args):
+def _approx_percentile(args: t.Sequence) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@@ -157,7 +172,7 @@ def _approx_percentile(args):
return exp.ApproxQuantile.from_arg_list(args)
-def _from_unixtime(args):
+def _from_unixtime(args: t.Sequence) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@@ -226,11 +241,15 @@ class Presto(Dialect):
FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
+ INTERVAL_ALLOWS_PLURAL_FORM = False
+ JOIN_HINTS = False
+ TABLE_HINTS = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
@@ -246,7 +265,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
- **transforms.ELIMINATE_QUALIFY, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@@ -284,6 +302,9 @@ class Presto(Dialect):
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
+ exp.Select: transforms.preprocess(
+ [transforms.eliminate_qualify, transforms.explode_to_unnest]
+ ),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
@@ -308,7 +329,13 @@ class Presto(Dialect):
exp.VariancePop: rename_func("VAR_POP"),
}
- def transaction_sql(self, expression):
+ def interval_sql(self, expression: exp.Interval) -> str:
+ unit = self.sql(expression, "unit")
+ if expression.this and unit.lower().startswith("week"):
+ return f"({expression.this.name} * INTERVAL '7' day)"
+ return super().interval_sql(expression)
+
+ def transaction_sql(self, expression: exp.Transaction) -> str:
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 63c14f4..a9c4f62 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -8,6 +8,10 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
+def _json_sql(self, e) -> str:
+ return f'{self.sql(e, "this")}."{e.expression.name}"'
+
+
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
@@ -56,6 +60,7 @@ class Redshift(Postgres):
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
+ "SYSDATE": TokenType.CURRENT_TIMESTAMP,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
"TOP": TokenType.TOP,
@@ -63,7 +68,14 @@ class Redshift(Postgres):
"VARBYTE": TokenType.VARBINARY,
}
+ # Redshift allows # to appear as a table identifier prefix
+ SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy()
+ SINGLE_TOKENS.pop("#")
+
class Generator(Postgres.Generator):
+ LOCKING_READS_SUPPORTED = False
+ SINGLE_STRING_INTERVAL = True
+
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BINARY: "VARBYTE",
@@ -79,6 +91,7 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
+ exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
),
@@ -87,12 +100,16 @@ class Redshift(Postgres):
),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+ exp.JSONExtract: _json_sql,
+ exp.JSONExtractScalar: _json_sql,
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
+ RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
+
def values_sql(self, expression: exp.Values) -> str:
"""
Converts `VALUES...` expression into a series of unions.
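
_json_sql renders JSON extraction as Redshift's dotted SUPER access. Sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT col -> 'f'", read="postgres", write="redshift")[0])
    # expected: SELECT col."f"
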
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 34bc3bd..0829669 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -23,14 +23,14 @@ from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
-def _check_int(s):
+def _check_int(s: str) -> bool:
if s[0] in ("-", "+"):
return s[1:].isdigit()
return s.isdigit()
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _snowflake_to_timestamp(args):
+def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@@ -69,7 +69,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime.from_arg_list(args)
-def _unix_to_time_sql(self, expression):
+def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@@ -84,8 +84,12 @@ def _unix_to_time_sql(self, expression):
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
-def _parse_date_part(self):
+def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
this = self._parse_var() or self._parse_type()
+
+ if not this:
+ return None
+
self._match(TokenType.COMMA)
expression = self._parse_bitwise()
@@ -101,7 +105,7 @@ def _parse_date_part(self):
scale = None
ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
- to_unix = self.expression(exp.TimeToUnix, this=ts)
+ to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
if scale:
to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
@@ -112,7 +116,7 @@ def _parse_date_part(self):
# https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args):
+def _div0_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -120,18 +124,18 @@ def _div0_to_if(args):
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args):
+def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _nullifzero_to_if(args):
+def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
elif expression.this == exp.DataType.Type.MAP:
@@ -155,9 +159,8 @@ class Snowflake(Dialect):
"MM": "%m",
"mm": "%m",
"DD": "%d",
- "dd": "%d",
- "d": "%-d",
- "DY": "%w",
+ "dd": "%-d",
+ "DY": "%a",
"dy": "%w",
"HH24": "%H",
"hh24": "%H",
@@ -174,6 +177,8 @@ class Snowflake(Dialect):
}
class Parser(parser.Parser):
+ QUOTED_PIVOT_COLUMNS = True
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
@@ -269,9 +274,14 @@ class Snowflake(Dialect):
"$": TokenType.PARAMETER,
}
+ VAR_SINGLE_TOKENS = {"$"}
+
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
MATCHED_BY_SOURCE = False
+ SINGLE_STRING_INTERVAL = True
+ JOIN_HINTS = False
+ TABLE_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@@ -287,26 +297,30 @@ class Snowflake(Dialect):
),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
+ exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.If: rename_func("IFF"),
- exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
- exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
- exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+ exp.LogicalOr: rename_func("BOOLOR_AGG"),
+ exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+ exp.Max: max_or_greatest,
+ exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
- exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
+ exp.TimeToStr: lambda self, e: self.func(
+ "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
+ ),
+ exp.TimestampTrunc: timestamptrunc_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
- exp.DayOfWeek: rename_func("DAYOFWEEK"),
- exp.Max: max_or_greatest,
- exp.Min: min_or_least,
+ exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
}
TYPE_MAPPING = {
@@ -322,14 +336,15 @@ class Snowflake(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- def except_op(self, expression):
+ def except_op(self, expression: exp.Except) -> str:
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
return super().except_op(expression)
- def intersect_op(self, expression):
+ def intersect_op(self, expression: exp.Intersect) -> str:
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
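
A quick sanity check of the Snowflake IF rewrites above (a minimal sketch alongside the diff, assuming sqlglot 11.7 is importable and that these helpers are wired into the dialect's FUNCTIONS map, as elsewhere in this file):

    import sqlglot

    # DIV0(a, b) parses into an exp.If, so dialects without DIV0 emit a conditional.
    print(sqlglot.transpile("SELECT DIV0(a, b)", read="snowflake", write="duckdb")[0])
    # roughly: SELECT CASE WHEN b = 0 THEN 0 ELSE a / b END

    # NULLIFZERO gets the same treatment via _nullifzero_to_if.
    print(sqlglot.transpile("SELECT NULLIFZERO(a)", read="snowflake", write="duckdb")[0])
    # roughly: SELECT CASE WHEN a = 0 THEN NULL ELSE a END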
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index c271f6f..a3e4cce 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -1,13 +1,15 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
-def _create_sql(self, e):
- kind = e.args.get("kind")
+def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
+ kind = e.args["kind"]
properties = e.args.get("properties")
if kind.upper() == "TABLE" and any(
@@ -18,13 +20,13 @@ def _create_sql(self, e):
return create_with_partitions_sql(self, e)
-def _map_sql(self, expression):
+def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
keys = self.sql(expression.args["keys"])
values = self.sql(expression.args["values"])
return f"MAP_FROM_ARRAYS({keys}, {values})"
-def _str_to_date(self, expression):
+def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.date_format:
@@ -32,7 +34,7 @@ def _str_to_date(self, expression):
return f"TO_DATE({this}, {time_format})"
-def _unix_to_time(self, expression):
+def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
@@ -75,7 +77,11 @@ class Spark(Hive):
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+ "BOOLEAN": lambda args: exp.Cast(
+ this=seq_get(args, 0), to=exp.DataType.build("boolean")
+ ),
"IIF": exp.If.from_arg_list,
+ "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
"AGGREGATE": exp.Reduce.from_arg_list,
"DAYOFWEEK": lambda args: exp.DayOfWeek(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@@ -89,11 +95,16 @@ class Spark(Hive):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
+ "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
+ "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+ "TIMESTAMP": lambda args: exp.Cast(
+ this=seq_get(args, 0), to=exp.DataType.build("timestamp")
+ ),
}
FUNCTION_PARSERS = {
@@ -108,16 +119,43 @@ class Spark(Hive):
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
- def _parse_add_column(self):
+ def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
- def _parse_drop_column(self):
+ def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
+ def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
+ # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
+ if len(pivot_columns) == 1:
+ return [""]
+
+ names = []
+ for agg in pivot_columns:
+ if isinstance(agg, exp.Alias):
+ names.append(agg.alias)
+ else:
+ """
+ This case corresponds to aggregations without aliases being used as suffixes
+ (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
+ be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
+                Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested backticks).
+
+ Moreover, function names are lowercased in order to mimic Spark's naming scheme.
+ """
+ agg_all_unquoted = agg.transform(
+ lambda node: exp.Identifier(this=node.name, quoted=False)
+ if isinstance(node, exp.Identifier)
+ else node
+ )
+ names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
+
+ return names
+
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
@@ -145,7 +183,7 @@ class Spark(Hive):
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.UnixToTime: _unix_to_time,
+ exp.UnixToTime: _unix_to_time_sql,
exp.Create: _create_sql,
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
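
The BOOLEAN/INT/DATE/STRING/TIMESTAMP entries above parse Spark's cast-style function calls into exp.Cast nodes; a minimal sketch of the effect (assuming sqlglot 11.7):

    import sqlglot

    # DATE('2023-01-01') is now a cast, so it round-trips as standard syntax.
    print(sqlglot.transpile("SELECT DATE('2023-01-01')", read="spark", write="spark")[0])
    # roughly: SELECT CAST('2023-01-01' AS DATE)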
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 4091dbb..4437f82 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -16,7 +16,7 @@ from sqlglot.tokens import TokenType
def _date_add_sql(self, expression):
modifier = expression.expression
- modifier = expression.name if modifier.is_string else self.sql(modifier)
+ modifier = modifier.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
return self.func("DATE", expression.this, modifier)
@@ -38,6 +38,9 @@ class SQLite(Dialect):
}
class Generator(generator.Generator):
+ JOIN_HINTS = False
+ TABLE_HINTS = False
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "INTEGER",
@@ -82,6 +85,11 @@ class SQLite(Dialect):
exp.TryCast: no_trycast_sql,
}
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ }
+
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast) -> str:
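
The one-line fix above reads the interval text from the modifier node itself rather than from the enclosing DATE_ADD expression; a sketch of the intended behaviour (assuming sqlglot 11.7; the unit's case may differ):

    import sqlglot

    # The amount and unit now land in SQLite's string modifier correctly.
    print(sqlglot.transpile("SELECT DATE_ADD(x, INTERVAL 1 DAY)", read="mysql", write="sqlite")[0])
    # roughly: SELECT DATE(x, '1 DAY')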
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 2ba1a92..ff19dab 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -1,7 +1,11 @@
from __future__ import annotations
from sqlglot import exp
-from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
+from sqlglot.dialects.dialect import (
+ approx_count_distinct_sql,
+ arrow_json_extract_sql,
+ rename_func,
+)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
@@ -10,6 +14,7 @@ class StarRocks(MySQL):
class Parser(MySQL.Parser): # type: ignore
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
+ "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
@@ -25,6 +30,7 @@ class StarRocks(MySQL):
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS, # type: ignore
+ exp.ApproxDistinct: approx_count_distinct_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
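
With APPROX_COUNT_DISTINCT mapped onto exp.ApproxDistinct in both the parser and the generator, the function can now cross dialects; a sketch (assuming sqlglot 11.7):

    import sqlglot

    print(sqlglot.transpile("SELECT APPROX_COUNT_DISTINCT(a) FROM t", read="starrocks", write="presto")[0])
    # roughly: SELECT APPROX_DISTINCT(a) FROM t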
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 31b1c8d..792c2b4 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -21,6 +21,9 @@ def _count_sql(self, expression):
class Tableau(Dialect):
class Generator(generator.Generator):
+ JOIN_HINTS = False
+ TABLE_HINTS = False
+
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.If: _if_sql,
@@ -28,6 +31,11 @@ class Tableau(Dialect):
exp.Count: _count_sql,
}
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ }
+
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 3d43793..331e105 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -1,7 +1,14 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens
-from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
+from sqlglot.dialects.dialect import (
+ Dialect,
+ format_time_lambda,
+ max_or_greatest,
+ min_or_least,
+)
from sqlglot.tokens import TokenType
@@ -115,7 +122,18 @@ class Teradata(Dialect):
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
+ def _parse_cast(self, strict: bool) -> exp.Expression:
+ cast = t.cast(exp.Cast, super()._parse_cast(strict))
+ if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
+ return format_time_lambda(exp.TimeToStr, "teradata")(
+ [cast.this, self._parse_string()]
+ )
+ return cast
+
class Generator(generator.Generator):
+ JOIN_HINTS = False
+ TABLE_HINTS = False
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
@@ -130,6 +148,7 @@ class Teradata(Dialect):
**generator.Generator.TRANSFORMS,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
+ exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
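
The _parse_cast override above turns Teradata's CAST ... FORMAT into an exp.TimeToStr, and the new TRANSFORMS entry generates it back; a sketch (assuming sqlglot 11.7 and that 'YYYY-MM-DD' is covered by Teradata's time mapping):

    import sqlglot

    sql = "SELECT CAST(x AS DATE FORMAT 'YYYY-MM-DD') FROM t"
    print(sqlglot.transpile(sql, read="teradata", write="teradata")[0])
    # expected to round-trip essentially unchanged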
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index b8a227b..9cf56e1 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -96,6 +96,23 @@ def _parse_eomonth(args):
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
+def _parse_hashbytes(args):
+ kind, data = args
+ kind = kind.name.upper() if kind.is_string else ""
+
+ if kind == "MD5":
+ args.pop(0)
+ return exp.MD5(this=data)
+ if kind in ("SHA", "SHA1"):
+ args.pop(0)
+ return exp.SHA(this=data)
+ if kind == "SHA2_256":
+ return exp.SHA2(this=data, length=exp.Literal.number(256))
+ if kind == "SHA2_512":
+ return exp.SHA2(this=data, length=exp.Literal.number(512))
+ return exp.func("HASHBYTES", *args)
+
+
def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return self.func(func, e.text("unit"), e.expression, e.this)
@@ -266,6 +283,7 @@ class TSQL(Dialect):
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
+ "SYSTEM_USER": TokenType.CURRENT_USER,
}
# TSQL allows @, # to appear as a variable/identifier prefix
@@ -287,6 +305,7 @@ class TSQL(Dialect):
"EOMONTH": _parse_eomonth,
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
+ "HASHBYTES": _parse_hashbytes,
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
@@ -296,6 +315,14 @@ class TSQL(Dialect):
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
"SUSER_NAME": exp.CurrentUser.from_arg_list,
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
+ "SYSTEM_USER": exp.CurrentUser.from_arg_list,
+ }
+
+ JOIN_HINTS = {
+ "LOOP",
+ "HASH",
+ "MERGE",
+ "REMOTE",
}
VAR_LENGTH_DATATYPES = {
@@ -441,11 +468,21 @@ class TSQL(Dialect):
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
exp.Max: max_or_greatest,
+ exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
+ exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
+ exp.SHA2: lambda self, e: self.func(
+ "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
+ ),
}
TRANSFORMS.pop(exp.ReturnsProperty)
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ }
+
LIMIT_FETCH = "FETCH"
def offset_sql(self, expression: exp.Offset) -> str:
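
_parse_hashbytes and the new MD5/SHA/SHA2 transforms make hashing portable in both directions; a sketch (assuming sqlglot 11.7):

    import sqlglot

    # TSQL in, portable expression out ...
    print(sqlglot.transpile("SELECT HASHBYTES('MD5', x)", read="tsql", write="duckdb")[0])
    # roughly: SELECT MD5(x)

    # ... and back again.
    print(sqlglot.transpile("SELECT MD5(x)", read="duckdb", write="tsql")[0])
    # roughly: SELECT HASHBYTES('MD5', x)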
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 9011dce..49d3ff6 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -701,6 +701,119 @@ class Condition(Expression):
"""
return not_(self)
+ def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
+ this = self
+ other = convert(other)
+ if not isinstance(this, klass) and not isinstance(other, klass):
+ this = _wrap(this, Binary)
+ other = _wrap(other, Binary)
+ if reverse:
+ return klass(this=other, expression=this)
+ return klass(this=this, expression=other)
+
+ def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
+ if isinstance(other, slice):
+ return Between(
+ this=self,
+ low=convert(other.start),
+ high=convert(other.stop),
+ )
+ return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
+
+ def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
+ return In(
+ this=self,
+ expressions=[convert(e) for e in expressions],
+ query=maybe_parse(query, **opts) if query else None,
+ )
+
+ def like(self, other: ExpOrStr) -> Like:
+ return self._binop(Like, other)
+
+ def ilike(self, other: ExpOrStr) -> ILike:
+ return self._binop(ILike, other)
+
+ def eq(self, other: ExpOrStr) -> EQ:
+ return self._binop(EQ, other)
+
+ def neq(self, other: ExpOrStr) -> NEQ:
+ return self._binop(NEQ, other)
+
+ def rlike(self, other: ExpOrStr) -> RegexpLike:
+ return self._binop(RegexpLike, other)
+
+ def __lt__(self, other: ExpOrStr) -> LT:
+ return self._binop(LT, other)
+
+ def __le__(self, other: ExpOrStr) -> LTE:
+ return self._binop(LTE, other)
+
+ def __gt__(self, other: ExpOrStr) -> GT:
+ return self._binop(GT, other)
+
+ def __ge__(self, other: ExpOrStr) -> GTE:
+ return self._binop(GTE, other)
+
+ def __add__(self, other: ExpOrStr) -> Add:
+ return self._binop(Add, other)
+
+ def __radd__(self, other: ExpOrStr) -> Add:
+ return self._binop(Add, other, reverse=True)
+
+ def __sub__(self, other: ExpOrStr) -> Sub:
+ return self._binop(Sub, other)
+
+ def __rsub__(self, other: ExpOrStr) -> Sub:
+ return self._binop(Sub, other, reverse=True)
+
+ def __mul__(self, other: ExpOrStr) -> Mul:
+ return self._binop(Mul, other)
+
+ def __rmul__(self, other: ExpOrStr) -> Mul:
+ return self._binop(Mul, other, reverse=True)
+
+ def __truediv__(self, other: ExpOrStr) -> Div:
+ return self._binop(Div, other)
+
+ def __rtruediv__(self, other: ExpOrStr) -> Div:
+ return self._binop(Div, other, reverse=True)
+
+ def __floordiv__(self, other: ExpOrStr) -> IntDiv:
+ return self._binop(IntDiv, other)
+
+ def __rfloordiv__(self, other: ExpOrStr) -> IntDiv:
+ return self._binop(IntDiv, other, reverse=True)
+
+ def __mod__(self, other: ExpOrStr) -> Mod:
+ return self._binop(Mod, other)
+
+ def __rmod__(self, other: ExpOrStr) -> Mod:
+ return self._binop(Mod, other, reverse=True)
+
+ def __pow__(self, other: ExpOrStr) -> Pow:
+ return self._binop(Pow, other)
+
+ def __rpow__(self, other: ExpOrStr) -> Pow:
+ return self._binop(Pow, other, reverse=True)
+
+ def __and__(self, other: ExpOrStr) -> And:
+ return self._binop(And, other)
+
+ def __rand__(self, other: ExpOrStr) -> And:
+ return self._binop(And, other, reverse=True)
+
+ def __or__(self, other: ExpOrStr) -> Or:
+ return self._binop(Or, other)
+
+ def __ror__(self, other: ExpOrStr) -> Or:
+ return self._binop(Or, other, reverse=True)
+
+ def __neg__(self) -> Neg:
+ return Neg(this=_wrap(self, Binary))
+
+ def __invert__(self) -> Not:
+ return not_(self)
+
class Predicate(Condition):
"""Relationships like x = y, x > 1, x >= y."""
@@ -818,7 +931,6 @@ class Create(Expression):
"properties": False,
"replace": False,
"unique": False,
- "volatile": False,
"indexes": False,
"no_schema_binding": False,
"begin": False,
@@ -1053,6 +1165,11 @@ class NotNullColumnConstraint(ColumnConstraintKind):
arg_types = {"allow_null": False}
+# https://dev.mysql.com/doc/refman/5.7/en/timestamp-initialization.html
+class OnUpdateColumnConstraint(ColumnConstraintKind):
+ pass
+
+
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
arg_types = {"desc": False}
@@ -1197,6 +1314,7 @@ class Drop(Expression):
"materialized": False,
"cascade": False,
"constraints": False,
+ "purge": False,
}
@@ -1287,6 +1405,7 @@ class Insert(Expression):
"with": False,
"this": True,
"expression": False,
+ "conflict": False,
"returning": False,
"overwrite": False,
"exists": False,
@@ -1295,6 +1414,16 @@ class Insert(Expression):
}
+class OnConflict(Expression):
+ arg_types = {
+ "duplicate": False,
+ "expressions": False,
+ "nothing": False,
+ "key": False,
+ "constraint": False,
+ }
+
+
class Returning(Expression):
arg_types = {"expressions": True}
@@ -1326,7 +1455,12 @@ class Partition(Expression):
class Fetch(Expression):
- arg_types = {"direction": False, "count": False}
+ arg_types = {
+ "direction": False,
+ "count": False,
+ "percent": False,
+ "with_ties": False,
+ }
class Group(Expression):
@@ -1374,6 +1508,7 @@ class Join(Expression):
"kind": False,
"using": False,
"natural": False,
+ "hint": False,
}
@property
@@ -1385,6 +1520,10 @@ class Join(Expression):
return self.text("side").upper()
@property
+ def hint(self):
+ return self.text("hint").upper()
+
+ @property
def alias_or_name(self):
return self.this.alias_or_name
@@ -1475,6 +1614,7 @@ class MatchRecognize(Expression):
"after": False,
"pattern": False,
"define": False,
+ "alias": False,
}
@@ -1582,6 +1722,10 @@ class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
+class InputOutputFormat(Expression):
+ arg_types = {"input_format": False, "output_format": False}
+
+
class IsolatedLoadingProperty(Property):
arg_types = {
"no": True,
@@ -1646,6 +1790,10 @@ class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False, "table": False}
+class RowFormatProperty(Property):
+ arg_types = {"this": True}
+
+
class RowFormatDelimitedProperty(Property):
# https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
arg_types = {
@@ -1683,6 +1831,10 @@ class SqlSecurityProperty(Property):
arg_types = {"definer": True}
+class StabilityProperty(Property):
+ arg_types = {"this": True}
+
+
class TableFormatProperty(Property):
arg_types = {"this": True}
@@ -1695,8 +1847,8 @@ class TransientProperty(Property):
arg_types = {"this": False}
-class VolatilityProperty(Property):
- arg_types = {"this": True}
+class VolatileProperty(Property):
+ arg_types = {"this": False}
class WithDataProperty(Property):
@@ -1726,6 +1878,7 @@ class Properties(Expression):
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"RETURNS": ReturnsProperty,
+ "ROW_FORMAT": RowFormatProperty,
"SORTKEY": SortKeyProperty,
"TABLE_FORMAT": TableFormatProperty,
}
@@ -2721,6 +2874,7 @@ class Pivot(Expression):
"expressions": True,
"field": True,
"unpivot": True,
+ "columns": False,
}
@@ -2731,6 +2885,8 @@ class Window(Expression):
"order": False,
"spec": False,
"alias": False,
+ "over": False,
+ "first": False,
}
@@ -2816,6 +2972,7 @@ class DataType(Expression):
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
+ BIGDECIMAL = auto()
BIT = auto()
BOOLEAN = auto()
JSON = auto()
@@ -2964,7 +3121,7 @@ class DropPartition(Expression):
# Binary expressions like (ADD a b)
-class Binary(Expression):
+class Binary(Condition):
arg_types = {"this": True, "expression": True}
@property
@@ -2980,7 +3137,7 @@ class Add(Binary):
pass
-class Connector(Binary, Condition):
+class Connector(Binary):
pass
@@ -3142,7 +3299,7 @@ class ArrayOverlaps(Binary):
# Unary Expressions
# (NOT a)
-class Unary(Expression):
+class Unary(Condition):
pass
@@ -3150,11 +3307,11 @@ class BitwiseNot(Unary):
pass
-class Not(Unary, Condition):
+class Not(Unary):
pass
-class Paren(Unary, Condition):
+class Paren(Unary):
arg_types = {"this": True, "with": False}
@@ -3162,7 +3319,6 @@ class Neg(Unary):
pass
-# Special Functions
class Alias(Expression):
arg_types = {"this": True, "alias": False}
@@ -3381,6 +3537,16 @@ class AnyValue(AggFunc):
class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
+ def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
+ this = self.copy() if copy else self
+ this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
+ return this
+
+ def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
+ this = self.copy() if copy else self
+ this.set("default", maybe_parse(condition, **opts))
+ return this
+
class Cast(Func):
arg_types = {"this": True, "to": True}
@@ -3719,6 +3885,10 @@ class Map(Func):
arg_types = {"keys": False, "values": False}
+class StarMap(Func):
+ pass
+
+
class VarMap(Func):
arg_types = {"keys": True, "values": True}
is_var_len_args = True
@@ -3734,6 +3904,10 @@ class Max(AggFunc):
is_var_len_args = True
+class MD5(Func):
+ _sql_names = ["MD5"]
+
+
class Min(AggFunc):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -3840,6 +4014,15 @@ class SetAgg(AggFunc):
pass
+class SHA(Func):
+ _sql_names = ["SHA", "SHA1"]
+
+
+class SHA2(Func):
+ _sql_names = ["SHA2"]
+ arg_types = {"this": True, "length": False}
+
+
class SortArray(Func):
arg_types = {"this": True, "asc": False}
@@ -4017,6 +4200,12 @@ class When(Func):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
+# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16
+class NextValueFor(Func):
+ arg_types = {"this": True, "order": False}
+
+
def _norm_arg(arg):
return arg.lower() if type(arg) is str else arg
@@ -4025,6 +4214,32 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
# Helpers
+@t.overload
+def maybe_parse(
+ sql_or_expression: ExpOrStr,
+ *,
+ into: t.Type[E],
+ dialect: DialectType = None,
+ prefix: t.Optional[str] = None,
+ copy: bool = False,
+ **opts,
+) -> E:
+ ...
+
+
+@t.overload
+def maybe_parse(
+ sql_or_expression: str | E,
+ *,
+ into: t.Optional[IntoType] = None,
+ dialect: DialectType = None,
+ prefix: t.Optional[str] = None,
+ copy: bool = False,
+ **opts,
+) -> E:
+ ...
+
+
def maybe_parse(
sql_or_expression: ExpOrStr,
*,
@@ -4200,15 +4415,15 @@ def _combine(expressions, operator, dialect=None, **opts):
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
- this = _wrap_operator(this)
+ this = _wrap(this, Connector)
for expression in expressions[1:]:
- this = operator(this=this, expression=_wrap_operator(expression))
+ this = operator(this=this, expression=_wrap(expression, Connector))
return this
-def _wrap_operator(expression):
- if isinstance(expression, (And, Or, Not)):
- expression = Paren(this=expression)
+def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
+ if isinstance(expression, kind):
+ return Paren(this=expression)
return expression
@@ -4506,7 +4721,7 @@ def not_(expression, dialect=None, **opts) -> Not:
dialect=dialect,
**opts,
)
- return Not(this=_wrap_operator(this))
+ return Not(this=_wrap(this, Connector))
def paren(expression) -> Paren:
@@ -4657,6 +4872,8 @@ def alias_(
if table:
table_alias = TableAlias(this=alias)
+
+ exp = exp.copy() if isinstance(expression, Expression) else exp
exp.set("alias", table_alias)
if not isinstance(table, bool):
@@ -4864,16 +5081,22 @@ def convert(value) -> Expression:
"""
if isinstance(value, Expression):
return value
- if value is None:
- return NULL
- if isinstance(value, bool):
- return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
- if isinstance(value, float) and math.isnan(value):
+ if isinstance(value, bool):
+ return Boolean(this=value)
+ if value is None or (isinstance(value, float) and math.isnan(value)):
return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
+ if isinstance(value, datetime.datetime):
+ datetime_literal = Literal.string(
+ (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
+ )
+ return TimeStrToTime(this=datetime_literal)
+ if isinstance(value, datetime.date):
+ date_literal = Literal.string(value.strftime("%Y-%m-%d"))
+ return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
if isinstance(value, list):
@@ -4883,14 +5106,6 @@ def convert(value) -> Expression:
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
- if isinstance(value, datetime.datetime):
- datetime_literal = Literal.string(
- (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
- )
- return TimeStrToTime(this=datetime_literal)
- if isinstance(value, datetime.date):
- date_literal = Literal.string(value.strftime("%Y-%m-%d"))
- return DateStrToDate(this=date_literal)
raise ValueError(f"Cannot convert {value}")
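
The reordering in convert above folds None and NaN into NULL together and moves the datetime branches ahead of the container branches; a sketch (assuming sqlglot 11.7; the date rendering below relies on DuckDB's DateStrToDate transform):

    import datetime
    from sqlglot.expressions import convert

    print(convert((1, "x", None)).sql())
    # roughly: (1, 'x', NULL)

    print(convert(datetime.date(2023, 5, 1)).sql(dialect="duckdb"))
    # roughly: CAST('2023-05-01' AS DATE)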
@@ -5030,7 +5245,9 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
-def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
+def expand(
+ expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True
+) -> Expression:
"""Transforms an expression by expanding all referenced sources into subqueries.
Examples:
@@ -5038,6 +5255,9 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
>>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'
+ >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y"), "y": parse_one("select * from z")}).sql()
+ 'SELECT * FROM (SELECT * FROM (SELECT * FROM z) AS y /* source: y */) AS z /* source: x */'
+
Args:
expression: The expression to expand.
sources: A dictionary of name to Subqueryables.
@@ -5054,7 +5274,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
if source:
subquery = source.subquery(node.alias or name)
subquery.comments = [f"source: {name}"]
- return subquery
+ return subquery.transform(_expand, copy=False)
return node
return expression.transform(_expand, copy=copy)
@@ -5089,8 +5309,8 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
from sqlglot.dialects.dialect import Dialect
- converted = [convert(arg) for arg in args]
- kwargs = {key: convert(value) for key, value in kwargs.items()}
+ converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args]
+ kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()}
parser = Dialect.get_or_raise(dialect)().parser()
from_args_list = parser.FUNCTIONS.get(name.upper())
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 8a49d55..bd12d54 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -76,11 +76,13 @@ class Generator:
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY",
exp.TransientProperty: lambda self, e: "TRANSIENT",
- exp.VolatilityProperty: lambda self, e: e.name,
+ exp.StabilityProperty: lambda self, e: e.name,
+ exp.VolatileProperty: lambda self, e: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
+ exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
@@ -110,8 +112,19 @@ class Generator:
# Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
- # Whether or not limit and fetch are supported
- # "ALL", "LIMIT", "FETCH"
+    # Whether or not INTERVAL expressions must be rendered as a single quoted string, e.g. INTERVAL '1 day'
+ SINGLE_STRING_INTERVAL = False
+
+ # Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs
+ INTERVAL_ALLOWS_PLURAL_FORM = True
+
+ # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
+ TABLESAMPLE_WITH_METHOD = True
+
+ # Whether or not to treat the number in TABLESAMPLE (50) as a percentage
+ TABLESAMPLE_SIZE_IS_PERCENT = False
+
+ # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
TYPE_MAPPING = {
@@ -129,6 +142,18 @@ class Generator:
"replace": "REPLACE",
}
+ TIME_PART_SINGULARS = {
+ "microseconds": "microsecond",
+ "seconds": "second",
+ "minutes": "minute",
+ "hours": "hour",
+ "days": "day",
+ "weeks": "week",
+ "months": "month",
+ "quarters": "quarter",
+ "years": "year",
+ }
+
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
@@ -168,6 +193,7 @@ class Generator:
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.Property: exp.Properties.Location.POST_WITH,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
@@ -175,15 +201,22 @@ class Generator:
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
+ exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
- exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
- WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
+ JOIN_HINTS = True
+ TABLE_HINTS = True
+
+ RESERVED_KEYWORDS: t.Set[str] = set()
+ WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
+ UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column)
+
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
@@ -322,10 +355,15 @@ class Generator:
comment = comment + " " if comment[-1].strip() else comment
return comment
- def maybe_comment(self, sql: str, expression: exp.Expression) -> str:
- comments = expression.comments if self._comments else None
+ def maybe_comment(
+ self,
+ sql: str,
+ expression: t.Optional[exp.Expression] = None,
+ comments: t.Optional[t.List[str]] = None,
+ ) -> str:
+ comments = (comments or (expression and expression.comments)) if self._comments else None # type: ignore
- if not comments:
+ if not comments or isinstance(expression, exp.Binary):
return sql
sep = "\n" if self.pretty else " "
@@ -621,7 +659,6 @@ class Generator:
replace = " OR REPLACE" if expression.args.get("replace") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
- volatile = " VOLATILE" if expression.args.get("volatile") else ""
postcreate_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_CREATE):
@@ -632,7 +669,7 @@ class Generator:
wrapped=False,
)
- modifiers = "".join((replace, unique, volatile, postcreate_props_sql))
+ modifiers = "".join((replace, unique, postcreate_props_sql))
postexpression_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
@@ -684,6 +721,9 @@ class Generator:
def hexstring_sql(self, expression: exp.HexString) -> str:
return self.sql(expression, "this")
+ def bytestring_sql(self, expression: exp.ByteString) -> str:
+ return self.sql(expression, "this")
+
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
@@ -695,9 +735,7 @@ class Generator:
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("values") is not None:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
- values = (
- f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}"
- )
+ values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}"
else:
nested = f"({interior})"
@@ -713,7 +751,7 @@ class Generator:
this = self.sql(expression, "this")
this = f" FROM {this}" if this else ""
using_sql = (
- f" USING {self.expressions(expression, 'using', sep=', USING ')}"
+ f" USING {self.expressions(expression, key='using', sep=', USING ')}"
if expression.args.get("using")
else ""
)
@@ -730,7 +768,10 @@ class Generator:
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
cascade = " CASCADE" if expression.args.get("cascade") else ""
constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
- return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}"
+ purge = " PURGE" if expression.args.get("purge") else ""
+ return (
+ f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}"
+ )
def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
@@ -746,7 +787,10 @@ class Generator:
direction = f" {direction.upper()}" if direction else ""
count = expression.args.get("count")
count = f" {count}" if count else ""
- return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY"
+ if expression.args.get("percent"):
+ count = f"{count} PERCENT"
+ with_ties_or_only = "WITH TIES" if expression.args.get("with_ties") else "ONLY"
+ return f"{self.seg('FETCH')}{direction}{count} ROWS {with_ties_or_only}"
def filter_sql(self, expression: exp.Filter) -> str:
this = self.sql(expression, "this")
@@ -766,12 +810,24 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
- text = text.lower() if self.normalize and not expression.quoted else text
+ lower = text.lower()
+ text = lower if self.normalize and not expression.quoted else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
- if expression.quoted or should_identify(text, self.identify):
+ if (
+ expression.quoted
+ or should_identify(text, self.identify)
+ or lower in self.RESERVED_KEYWORDS
+ ):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
+ def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
+ input_format = self.sql(expression, "input_format")
+ input_format = f"INPUTFORMAT {input_format}" if input_format else ""
+ output_format = self.sql(expression, "output_format")
+ output_format = f"OUTPUTFORMAT {output_format}" if output_format else ""
+ return self.sep().join((input_format, output_format))
+
def national_sql(self, expression: exp.National) -> str:
return f"N{self.sql(expression, 'this')}"
@@ -984,9 +1040,10 @@ class Generator:
self.sql(expression, "partition") if expression.args.get("partition") else ""
)
expression_sql = self.sql(expression, "expression")
+ conflict = self.sql(expression, "conflict")
returning = self.sql(expression, "returning")
sep = self.sep() if partition_sql else ""
- sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{returning}"
+ sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{conflict}{returning}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
@@ -1004,6 +1061,19 @@ class Generator:
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
+ def onconflict_sql(self, expression: exp.OnConflict) -> str:
+ conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
+ constraint = self.sql(expression, "constraint")
+ if constraint:
+ constraint = f"ON CONSTRAINT {constraint}"
+ key = self.expressions(expression, key="key", flat=True)
+ do = "" if expression.args.get("duplicate") else " DO "
+ nothing = "NOTHING" if expression.args.get("nothing") else ""
+ expressions = self.expressions(expression, flat=True)
+ if expressions:
+ expressions = f"UPDATE SET {expressions}"
+ return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"
+
def returning_sql(self, expression: exp.Returning) -> str:
return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"
@@ -1036,7 +1106,7 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
- hints = f" WITH ({hints})" if hints else ""
+ hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
@@ -1053,7 +1123,7 @@ class Generator:
this = self.sql(expression, "this")
alias = ""
method = self.sql(expression, "method")
- method = f"{method.upper()} " if method else ""
+ method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
numerator = self.sql(expression, "bucket_numerator")
denominator = self.sql(expression, "bucket_denominator")
field = self.sql(expression, "bucket_field")
@@ -1064,6 +1134,8 @@ class Generator:
rows = self.sql(expression, "rows")
rows = f"{rows} ROWS" if rows else ""
size = self.sql(expression, "size")
+ if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
+ size = f"{size} PERCENT"
seed = self.sql(expression, "seed")
seed = f" {seed_prefix} ({seed})" if seed else ""
kind = expression.args.get("kind", "TABLESAMPLE")
@@ -1154,6 +1226,7 @@ class Generator:
"NATURAL" if expression.args.get("natural") else None,
expression.side,
expression.kind,
+ expression.hint if self.JOIN_HINTS else None,
"JOIN",
)
if op
@@ -1311,16 +1384,20 @@ class Generator:
def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
partition = self.partition_by_sql(expression)
order = self.sql(expression, "order")
- measures = self.sql(expression, "measures")
- measures = self.seg(f"MEASURES {measures}") if measures else ""
+ measures = self.expressions(expression, key="measures")
+ measures = self.seg(f"MEASURES{self.seg(measures)}") if measures else ""
rows = self.sql(expression, "rows")
rows = self.seg(rows) if rows else ""
after = self.sql(expression, "after")
after = self.seg(after) if after else ""
pattern = self.sql(expression, "pattern")
pattern = self.seg(f"PATTERN ({pattern})") if pattern else ""
- define = self.sql(expression, "define")
- define = self.seg(f"DEFINE {define}") if define else ""
+ definition_sqls = [
+ f"{self.sql(definition, 'alias')} AS {self.sql(definition, 'this')}"
+ for definition in expression.args.get("define", [])
+ ]
+ definitions = self.expressions(sqls=definition_sqls)
+ define = self.seg(f"DEFINE{self.seg(definitions)}") if definitions else ""
body = "".join(
(
partition,
@@ -1332,7 +1409,9 @@ class Generator:
define,
)
)
- return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
+ alias = self.sql(expression, "alias")
+ alias = f" {alias}" if alias else ""
+ return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
limit = expression.args.get("limit")
@@ -1353,7 +1432,7 @@ class Generator:
self.sql(expression, "group"),
self.sql(expression, "having"),
self.sql(expression, "qualify"),
- self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
+ self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
@@ -1471,15 +1550,21 @@ class Generator:
partition_sql = partition + " " if partition and order else partition
spec = expression.args.get("spec")
- spec_sql = " " + self.window_spec_sql(spec) if spec else ""
+ spec_sql = " " + self.windowspec_sql(spec) if spec else ""
alias = self.sql(expression, "alias")
- this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
+ over = self.sql(expression, "over") or "OVER"
+ this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"
+
+ first = expression.args.get("first")
+ if first is not None:
+ first = " FIRST " if first else " LAST "
+ first = first or ""
if not partition and not order and not spec and alias:
return f"{this} {alias}"
- window_args = alias + partition_sql + order_sql + spec_sql
+ window_args = alias + first + partition_sql + order_sql + spec_sql
return f"{this} ({window_args.strip()})"
@@ -1487,7 +1572,7 @@ class Generator:
partition = self.expressions(expression, key="partition_by", flat=True)
return f"PARTITION BY {partition}" if partition else ""
- def window_spec_sql(self, expression: exp.WindowSpec) -> str:
+ def windowspec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = (
@@ -1508,7 +1593,7 @@ class Generator:
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
- expressions = apply_index_offset(expression.expressions, self.index_offset)
+ expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset)
expressions_sql = ", ".join(self.sql(e) for e in expressions)
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
@@ -1550,6 +1635,11 @@ class Generator:
expressions = self.expressions(expression, flat=True)
return f"CONSTRAINT {this} {expressions}"
+ def nextvaluefor_sql(self, expression: exp.NextValueFor) -> str:
+ order = expression.args.get("order")
+ order = f" OVER ({self.order_sql(order, flat=True)})" if order else ""
+ return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}"
+
def extract_sql(self, expression: exp.Extract) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
@@ -1586,7 +1676,7 @@ class Generator:
    def primarykey_sql(self, expression: exp.PrimaryKey) -> str:
expressions = self.expressions(expression, flat=True)
- options = self.expressions(expression, "options", flat=True, sep=" ")
+ options = self.expressions(expression, key="options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"PRIMARY KEY ({expressions}){options}"
@@ -1644,17 +1734,20 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
- this = expression.args.get("this")
- if this:
- this = (
- f" {this}"
- if isinstance(this, exp.Literal) or isinstance(this, exp.Paren)
- else f" ({this})"
- )
- else:
- this = ""
unit = self.sql(expression, "unit")
+ if not self.INTERVAL_ALLOWS_PLURAL_FORM:
+ unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit)
unit = f" {unit}" if unit else ""
+
+ if self.SINGLE_STRING_INTERVAL:
+ this = expression.this.name if expression.this else ""
+ return f"INTERVAL '{this}{unit}'"
+
+ this = self.sql(expression, "this")
+ if this:
+ unwrapped = isinstance(expression.this, self.UNWRAPPED_INTERVAL_VALUES)
+ this = f" {this}" if unwrapped else f" ({this})"
+
return f"INTERVAL{this}{unit}"
def return_sql(self, expression: exp.Return) -> str:
@@ -1664,7 +1757,7 @@ class Generator:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
expressions = f"({expressions})" if expressions else ""
- options = self.expressions(expression, "options", flat=True, sep=" ")
+ options = self.expressions(expression, key="options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"REFERENCES {this}{expressions}{options}"
@@ -1690,9 +1783,9 @@ class Generator:
return f"NOT {self.sql(expression, 'this')}"
def alias_sql(self, expression: exp.Alias) -> str:
- to_sql = self.sql(expression, "alias")
- to_sql = f" AS {to_sql}" if to_sql else ""
- return f"{self.sql(expression, 'this')}{to_sql}"
+ alias = self.sql(expression, "alias")
+ alias = f" AS {alias}" if alias else ""
+ return f"{self.sql(expression, 'this')}{alias}"
def aliases_sql(self, expression: exp.Aliases) -> str:
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
@@ -1712,7 +1805,11 @@ class Generator:
if not self.pretty:
return self.binary(expression, op)
- sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False))
+ sqls = tuple(
+ self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e)
+ for i, e in enumerate(expression.flatten(unnest=False))
+ )
+
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
return f"{sep}{op} ".join(sqls)
@@ -1797,13 +1894,13 @@ class Generator:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
- actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
+ actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ")
elif isinstance(actions[0], exp.Schema):
- actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
+ actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Delete):
- actions = self.expressions(expression, "actions", flat=True)
+ actions = self.expressions(expression, key="actions", flat=True)
else:
- actions = self.expressions(expression, "actions")
+ actions = self.expressions(expression, key="actions")
exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
@@ -1935,6 +2032,7 @@ class Generator:
return f"USE{kind}{this}"
def binary(self, expression: exp.Binary, op: str) -> str:
+ op = self.maybe_comment(op, comments=expression.comments)
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
def function_fallback_sql(self, expression: exp.Func) -> str:
@@ -1965,14 +2063,15 @@ class Generator:
def expressions(
self,
- expression: exp.Expression,
+ expression: t.Optional[exp.Expression] = None,
key: t.Optional[str] = None,
+ sqls: t.Optional[t.List[str]] = None,
flat: bool = False,
indent: bool = True,
sep: str = ", ",
prefix: str = "",
) -> str:
- expressions = expression.args.get(key or "expressions")
+ expressions = expression.args.get(key or "expressions") if expression else sqls
if not expressions:
return ""
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index d44d7dd..b2f0520 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -131,11 +131,16 @@ def subclasses(
]
-def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
+def apply_index_offset(
+ this: exp.Expression,
+ expressions: t.List[t.Optional[E]],
+ offset: int,
+) -> t.List[t.Optional[E]]:
"""
Applies an offset to a given integer literal expression.
Args:
+ this: the target of the index
expressions: the expression the offset will be applied to, wrapped in a list.
offset: the offset that will be applied.
@@ -148,11 +153,28 @@ def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.Lis
expression = expressions[0]
- if expression and expression.is_int:
- expression = expression.copy()
- logger.warning("Applying array index offset (%s)", offset)
- expression.args["this"] = str(int(expression.this) + offset) # type: ignore
- return [expression]
+ from sqlglot import exp
+ from sqlglot.optimizer.annotate_types import annotate_types
+ from sqlglot.optimizer.simplify import simplify
+
+ if not this.type:
+ annotate_types(this)
+
+ if t.cast(exp.DataType, this.type).this not in (
+ exp.DataType.Type.UNKNOWN,
+ exp.DataType.Type.ARRAY,
+ ):
+ return expressions
+
+ if expression:
+ if not expression.type:
+ annotate_types(expression)
+ if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
+ logger.warning("Applying array index offset (%s)", offset)
+ expression = simplify(
+ exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
+ )
+ return [expression]
return expressions
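
With the type guard above, the index offset is only applied when the indexed target is (or may be) an array, and the arithmetic now goes through simplify; a sketch of the classic 1-indexed-to-0-indexed case (assuming sqlglot 11.7):

    import sqlglot

    # Presto arrays are 1-indexed, Hive arrays are 0-indexed.
    print(sqlglot.transpile("SELECT x[1] FROM t", read="presto", write="hive")[0])
    # roughly: SELECT x[0] FROM t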
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 2e563ae..0eac870 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -20,6 +20,7 @@ class Node:
expression: exp.Expression
source: exp.Expression
downstream: t.List[Node] = field(default_factory=list)
+ alias: str = ""
def walk(self) -> t.Iterator[Node]:
yield self
@@ -69,14 +70,19 @@ def lineage(
optimized = optimize(expression, schema=schema, rules=rules)
scope = build_scope(optimized)
- tables: t.Dict[str, Node] = {}
def to_node(
column_name: str,
scope: Scope,
scope_name: t.Optional[str] = None,
upstream: t.Optional[Node] = None,
+ alias: t.Optional[str] = None,
) -> Node:
+ aliases = {
+ dt.alias: dt.comments[0].split()[1]
+ for dt in scope.derived_tables
+ if dt.comments and dt.comments[0].startswith("source: ")
+ }
if isinstance(scope.expression, exp.Union):
for scope in scope.union_scopes:
node = to_node(
@@ -84,37 +90,58 @@ def lineage(
scope=scope,
scope_name=scope_name,
upstream=upstream,
+ alias=aliases.get(scope_name),
)
return node
- select = next(select for select in scope.selects if select.alias_or_name == column_name)
- source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules)
- select = source.selects[0]
+ # Find the specific select clause that is the source of the column we want.
+ # This can either be a specific, named select or a generic `*` clause.
+ select = next(
+ (select for select in scope.selects if select.alias_or_name == column_name),
+ exp.Star() if scope.expression.is_star else None,
+ )
+ if not select:
+ raise ValueError(f"Could not find {column_name} in {scope.expression}")
+
+ if isinstance(scope.expression, exp.Select):
+ # For better ergonomics in our node labels, replace the full select with
+ # a version that has only the column we care about.
+ # "x", SELECT x, y FROM foo
+ # => "x", SELECT x FROM foo
+ source = optimize(
+ scope.expression.select(select, append=False), schema=schema, rules=rules
+ )
+ select = source.selects[0]
+ else:
+ source = scope.expression
+
+ # Create the node for this step in the lineage chain, and attach it to the previous one.
node = Node(
name=f"{scope_name}.{column_name}" if scope_name else column_name,
source=source,
expression=select,
+ alias=alias or "",
)
-
if upstream:
upstream.downstream.append(node)
+ # Find all columns that went into creating this one to list their lineage nodes.
for c in set(select.find_all(exp.Column)):
table = c.table
- source = scope.sources[table]
+ source = scope.sources.get(table)
if isinstance(source, Scope):
+ # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
to_node(
- c.name,
- scope=source,
- scope_name=table,
- upstream=node,
+ c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
)
else:
- if table not in tables:
- tables[table] = Node(name=c.sql(), source=source, expression=source)
- node.downstream.append(tables[table])
+ # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
+ # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
+ # is not passed into the `sources` map.
+ source = source or exp.Placeholder()
+ node.downstream.append(Node(name=c.sql(), source=source, expression=source))
return node
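
The reworked to_node threads source aliases through and no longer raises on unknown sources; a usage sketch (assuming sqlglot 11.7, with hypothetical tables x and y):

    from sqlglot.lineage import lineage

    node = lineage("a", "SELECT a FROM y", sources={"y": "SELECT a FROM x"})
    for n in node.walk():
        print(n.name, n.alias)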
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 99888c6..6238759 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -116,6 +116,9 @@ class TypeAnnotator:
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
+ exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+ exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
+ exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
@@ -335,7 +338,7 @@ class TypeAnnotator:
left_type = expression.left.type.this
right_type = expression.right.type.this
- if isinstance(expression, (exp.And, exp.Or)):
+ if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
@@ -344,7 +347,7 @@ class TypeAnnotator:
)
else:
expression.type = exp.DataType.Type.BOOLEAN
- elif isinstance(expression, (exp.Condition, exp.Predicate)):
+ elif isinstance(expression, exp.Predicate):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = self._maybe_coerce(left_type, right_type)
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index f2df230..40668ef 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -46,7 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
root = node is expression
original = node.copy()
try:
- node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ node = node.replace(
+ while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ )
except OptimizeError as e:
logger.info(e)
node.replace(original)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 6eae2b5..0a31246 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -93,6 +93,7 @@ def _expand_using(scope, resolver):
if column not in columns:
columns[column] = k
+ source_table = ordered[-1]
ordered.append(join_table)
join_columns = resolver.get_source_columns(join_table)
conditions = []
@@ -102,8 +103,10 @@ def _expand_using(scope, resolver):
table = columns.get(identifier)
if not table or identifier not in join_columns:
- raise OptimizeError(f"Cannot automatically join: {identifier}")
+ if columns and join_columns:
+ raise OptimizeError(f"Cannot automatically join: {identifier}")
+ table = table or source_table
conditions.append(
exp.condition(
exp.EQ(
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 93e1179..a719ebe 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -65,5 +65,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
if not table_alias.name:
table_alias.set("this", next_name())
+ if isinstance(udtf, exp.Values) and not table_alias.columns:
+ for i, e in enumerate(udtf.expressions[0].expressions):
+ table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
return expression
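
The new branch above names unaliased VALUES columns so that later passes can qualify them; a sketch (assuming sqlglot 11.7):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables

    print(qualify_tables(parse_one("SELECT * FROM (VALUES (1, 2)) AS t")).sql())
    # roughly: SELECT * FROM (VALUES (1, 2)) AS t(_col_0, _col_1)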
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 28ae86d..4e6c910 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -201,23 +201,24 @@ def _simplify_comparison(expression, left, right, or_=False):
return left if (av < bv if or_ else av >= bv) else right
# we can't ever shortcut to true because the column could be null
- if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
- if not or_ and av <= bv:
- return exp.false()
- elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
- if not or_ and av >= bv:
- return exp.false()
- elif isinstance(a, exp.EQ):
- if isinstance(b, exp.LT):
- return exp.false() if av >= bv else a
- if isinstance(b, exp.LTE):
- return exp.false() if av > bv else a
- if isinstance(b, exp.GT):
- return exp.false() if av <= bv else a
- if isinstance(b, exp.GTE):
- return exp.false() if av < bv else a
- if isinstance(b, exp.NEQ):
- return exp.false() if av == bv else a
+ if not or_:
+ if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
+ if av <= bv:
+ return exp.false()
+ elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
+ if av >= bv:
+ return exp.false()
+ elif isinstance(a, exp.EQ):
+ if isinstance(b, exp.LT):
+ return exp.false() if av >= bv else a
+ if isinstance(b, exp.LTE):
+ return exp.false() if av > bv else a
+ if isinstance(b, exp.GT):
+ return exp.false() if av <= bv else a
+ if isinstance(b, exp.GTE):
+ return exp.false() if av < bv else a
+ if isinstance(b, exp.NEQ):
+ return exp.false() if av == bv else a
return None
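
Gating the mixed-comparison shortcuts on AND keeps OR conservative, since the column could be NULL; a sketch (assuming sqlglot 11.7):

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    print(simplify(parse_one("x = 1 AND x < 0")).sql())  # roughly: FALSE
    print(simplify(parse_one("x = 1 OR x < 0")).sql())   # unchanged: x = 1 OR x < 0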
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index b3b899c..abb23ad 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -18,8 +18,13 @@ from sqlglot.trie import in_trie, new_trie
logger = logging.getLogger("sqlglot")
+E = t.TypeVar("E", bound=exp.Expression)
+
def parse_var_map(args: t.Sequence) -> exp.Expression:
+ if len(args) == 1 and args[0].is_star:
+ return exp.StarMap(this=args[0])
+
keys = []
values = []
for i in range(0, len(args), 2):
@@ -108,6 +113,8 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_USER: exp.CurrentUser,
}
+ JOIN_HINTS: t.Set[str] = set()
+
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.MAP,
@@ -145,6 +152,7 @@ class Parser(metaclass=_Parser):
TokenType.DATETIME,
TokenType.DATE,
TokenType.DECIMAL,
+ TokenType.BIGDECIMAL,
TokenType.UUID,
TokenType.GEOGRAPHY,
TokenType.GEOMETRY,
@@ -221,8 +229,10 @@ class Parser(metaclass=_Parser):
TokenType.FORMAT,
TokenType.FULL,
TokenType.IF,
+ TokenType.IS,
TokenType.ISNULL,
TokenType.INTERVAL,
+ TokenType.KEEP,
TokenType.LAZY,
TokenType.LEADING,
TokenType.LEFT,
@@ -235,6 +245,7 @@ class Parser(metaclass=_Parser):
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
+ TokenType.OVERWRITE,
TokenType.PARTITION,
TokenType.PERCENT,
TokenType.PIVOT,
@@ -266,6 +277,8 @@ class Parser(metaclass=_Parser):
*NO_PAREN_FUNCTIONS,
}
+ INTERVAL_VARS = ID_VAR_TOKENS - {TokenType.END}
+
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
TokenType.FULL,
@@ -276,6 +289,8 @@ class Parser(metaclass=_Parser):
TokenType.WINDOW,
}
+ COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS}
+
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
@@ -400,7 +415,7 @@ class Parser(metaclass=_Parser):
COLUMN_OPERATORS = {
TokenType.DOT: None,
TokenType.DCOLON: lambda self, this, to: self.expression(
- exp.Cast,
+ exp.Cast if self.STRICT_CAST else exp.TryCast,
this=this,
to=to,
),
@@ -560,7 +575,7 @@ class Parser(metaclass=_Parser):
),
"DEFINER": lambda self: self._parse_definer(),
"DETERMINISTIC": lambda self: self.expression(
- exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+ exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
@@ -571,7 +586,7 @@ class Parser(metaclass=_Parser):
"FREESPACE": lambda self: self._parse_freespace(),
"GLOBAL": lambda self: self._parse_temporary(global_=True),
"IMMUTABLE": lambda self: self.expression(
- exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+ exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"JOURNAL": lambda self: self._parse_journal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
@@ -600,20 +615,20 @@ class Parser(metaclass=_Parser):
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
+ "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
"SORTKEY": lambda self: self._parse_sortkey(),
"STABLE": lambda self: self.expression(
- exp.VolatilityProperty, this=exp.Literal.string("STABLE")
+ exp.StabilityProperty, this=exp.Literal.string("STABLE")
),
- "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
+ "STORED": lambda self: self._parse_stored(),
"TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
+ "TEMP": lambda self: self._parse_temporary(global_=False),
"TEMPORARY": lambda self: self._parse_temporary(global_=False),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
"USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
- "VOLATILE": lambda self: self.expression(
- exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
- ),
+ "VOLATILE": lambda self: self._parse_volatile_property(),
"WITH": lambda self: self._parse_with_property(),
}
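
Two related renames land in this table: IMMUTABLE/STABLE/VOLATILE function markers now build exp.StabilityProperty, while VOLATILE immediately after CREATE (or OR REPLACE / UNIQUE), as in Teradata's volatile tables, becomes a standalone exp.VolatileProperty; _parse_volatile_property below tells them apart by looking two tokens back. A hedged sketch:

    from sqlglot import exp, parse_one

    # VOLATILE right after CREATE marks the table itself as volatile.
    ddl = parse_one("CREATE VOLATILE TABLE t (c INT)")
    assert ddl.find(exp.VolatileProperty) is not None
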
@@ -648,8 +663,11 @@ class Parser(metaclass=_Parser):
"LIKE": lambda self: self._parse_create_like(),
"NOT": lambda self: self._parse_not_constraint(),
"NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
+ "ON": lambda self: self._match(TokenType.UPDATE)
+ and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()),
"PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
"PRIMARY KEY": lambda self: self._parse_primary_key(),
+ "REFERENCES": lambda self: self._parse_references(match=False),
"TITLE": lambda self: self.expression(
exp.TitleColumnConstraint, this=self._parse_var_or_string()
),
@@ -668,9 +686,14 @@ class Parser(metaclass=_Parser):
SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
NO_PAREN_FUNCTION_PARSERS = {
+ TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
TokenType.CASE: lambda self: self._parse_case(),
TokenType.IF: lambda self: self._parse_if(),
- TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
+ TokenType.NEXT_VALUE_FOR: lambda self: self.expression(
+ exp.NextValueFor,
+ this=self._parse_column(),
+ order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order),
+ ),
}
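
NEXT VALUE FOR (SQL Server's sequence syntax) is now parsed into a dedicated exp.NextValueFor node, including the OVER (ORDER BY ...) variant. A sketch, assuming the T-SQL generator round-trips the new node:

    from sqlglot import parse_one

    sql = "SELECT NEXT VALUE FOR db.seq OVER (ORDER BY id)"
    print(parse_one(sql, read="tsql").sql(dialect="tsql"))
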
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -715,6 +738,8 @@ class Parser(metaclass=_Parser):
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
+ TYPE_LITERAL_PARSERS: t.Dict[exp.DataType.Type, t.Callable] = {}
+
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
@@ -731,6 +756,7 @@ class Parser(metaclass=_Parser):
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
+ WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
@@ -738,6 +764,9 @@ class Parser(metaclass=_Parser):
CONVERT_TYPE_FIRST = False
+ QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None
+ PREFIXED_PIVOT_COLUMNS = False
+
LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
@@ -895,8 +924,8 @@ class Parser(metaclass=_Parser):
error level setting.
"""
token = token or self._curr or self._prev or Token.string("")
- start = self._find_token(token)
- end = start + len(token.text)
+ start = token.start
+ end = token.end
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
@@ -918,8 +947,8 @@ class Parser(metaclass=_Parser):
self.errors.append(error)
def expression(
- self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs
- ) -> exp.Expression:
+ self, exp_class: t.Type[E], comments: t.Optional[t.List[str]] = None, **kwargs
+ ) -> E:
"""
Creates a new, validated Expression.
@@ -958,22 +987,7 @@ class Parser(metaclass=_Parser):
self.raise_error(error_message)
def _find_sql(self, start: Token, end: Token) -> str:
- return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]
-
- def _find_token(self, token: Token) -> int:
- line = 1
- col = 1
- index = 0
-
- while line < token.line or col < token.col:
- if Tokenizer.WHITE_SPACE.get(self.sql[index]) == TokenType.BREAK:
- line += 1
- col = 1
- else:
- col += 1
- index += 1
-
- return index
+ return self.sql[start.start : end.end]
def _advance(self, times: int = 1) -> None:
self._index += times
@@ -990,7 +1004,7 @@ class Parser(metaclass=_Parser):
if index != self._index:
self._advance(index - self._index)
- def _parse_command(self) -> exp.Expression:
+ def _parse_command(self) -> exp.Command:
return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
@@ -1007,7 +1021,7 @@ class Parser(metaclass=_Parser):
if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=kind.token_type)
elif kind.token_type == TokenType.TABLE:
- this = self._parse_table()
+ this = self._parse_table(alias_tokens=self.COMMENT_TABLE_ALIAS_TOKENS)
elif kind.token_type == TokenType.COLUMN:
this = self._parse_column()
else:
@@ -1035,16 +1049,13 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(expression)
return expression
- def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
+ def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
- if default_kind:
- kind = default_kind
- else:
- return self._parse_as_command(start)
+ return self._parse_as_command(start)
return self.expression(
exp.Drop,
@@ -1055,6 +1066,7 @@ class Parser(metaclass=_Parser):
materialized=materialized,
cascade=self._match(TokenType.CASCADE),
constraints=self._match_text_seq("CONSTRAINTS"),
+ purge=self._match_text_seq("PURGE"),
)
def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
@@ -1070,7 +1082,6 @@ class Parser(metaclass=_Parser):
TokenType.OR, TokenType.REPLACE
)
unique = self._match(TokenType.UNIQUE)
- volatile = self._match(TokenType.VOLATILE)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
@@ -1179,7 +1190,6 @@ class Parser(metaclass=_Parser):
kind=create_token.text,
replace=replace,
unique=unique,
- volatile=volatile,
expression=expression,
exists=exists,
properties=properties,
@@ -1225,6 +1235,21 @@ class Parser(metaclass=_Parser):
return None
+ def _parse_stored(self) -> exp.Expression:
+ self._match(TokenType.ALIAS)
+
+ input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None
+ output_format = self._parse_string() if self._match_text_seq("OUTPUTFORMAT") else None
+
+ return self.expression(
+ exp.FileFormatProperty,
+ this=self.expression(
+ exp.InputOutputFormat, input_format=input_format, output_format=output_format
+ )
+ if input_format or output_format
+ else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
+ )
+
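
STORED AS now also understands Hive's long form with explicit input/output classes, wrapped in exp.InputOutputFormat. A hedged round trip (assuming hive's generator emits the long form back out):

    import sqlglot

    ddl = (
        "CREATE TABLE t (c INT) "
        "STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' "
        "OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'"
    )
    print(sqlglot.transpile(ddl, read="hive")[0])
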
def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
@@ -1258,6 +1283,21 @@ class Parser(metaclass=_Parser):
exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
)
+ def _parse_volatile_property(self) -> exp.Expression:
+ if self._index >= 2:
+ pre_volatile_token = self._tokens[self._index - 2]
+ else:
+ pre_volatile_token = None
+
+ if pre_volatile_token and pre_volatile_token.token_type in (
+ TokenType.CREATE,
+ TokenType.REPLACE,
+ TokenType.UNIQUE,
+ ):
+ return exp.VolatileProperty()
+
+ return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))
+
def _parse_with_property(
self,
) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
@@ -1574,11 +1614,46 @@ class Parser(metaclass=_Parser):
exists=self._parse_exists(),
partition=self._parse_partition(),
expression=self._parse_ddl_select(),
+ conflict=self._parse_on_conflict(),
returning=self._parse_returning(),
overwrite=overwrite,
alternative=alternative,
)
+ def _parse_on_conflict(self) -> t.Optional[exp.Expression]:
+ conflict = self._match_text_seq("ON", "CONFLICT")
+ duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY")
+
+ if not (conflict or duplicate):
+ return None
+
+ nothing = None
+ expressions = None
+ key = None
+ constraint = None
+
+ if conflict:
+ if self._match_text_seq("ON", "CONSTRAINT"):
+ constraint = self._parse_id_var()
+ else:
+ key = self._parse_csv(self._parse_value)
+
+ self._match_text_seq("DO")
+ if self._match_text_seq("NOTHING"):
+ nothing = True
+ else:
+ self._match(TokenType.UPDATE)
+ expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
+
+ return self.expression(
+ exp.OnConflict,
+ duplicate=duplicate,
+ expressions=expressions,
+ nothing=nothing,
+ key=key,
+ constraint=constraint,
+ )
+
def _parse_returning(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.RETURNING):
return None
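
Both Postgres/SQLite ON CONFLICT and MySQL ON DUPLICATE KEY UPDATE now land in a single exp.OnConflict node hanging off the Insert. A minimal sketch (assuming the generator renders the clause back out):

    from sqlglot import parse_one

    sql = "INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO NOTHING"
    insert = parse_one(sql, read="postgres")
    print(insert.args["conflict"])          # the parsed exp.OnConflict node
    print(insert.sql(dialect="postgres"))
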
@@ -1639,7 +1714,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Delete,
- this=self._parse_table(schema=True),
+ this=self._parse_table(),
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
where=self._parse_where(),
returning=self._parse_returning(),
@@ -1792,6 +1867,7 @@ class Parser(metaclass=_Parser):
if not skip_with_token and not self._match(TokenType.WITH):
return None
+ comments = self._prev_comments
recursive = self._match(TokenType.RECURSIVE)
expressions = []
@@ -1803,7 +1879,9 @@ class Parser(metaclass=_Parser):
else:
self._match(TokenType.WITH)
- return self.expression(exp.With, expressions=expressions, recursive=recursive)
+ return self.expression(
+ exp.With, comments=comments, expressions=expressions, recursive=recursive
+ )
def _parse_cte(self) -> exp.Expression:
alias = self._parse_table_alias()
@@ -1856,15 +1934,20 @@ class Parser(metaclass=_Parser):
table = isinstance(this, exp.Table)
while True:
- lateral = self._parse_lateral()
join = self._parse_join()
- comma = None if table else self._match(TokenType.COMMA)
- if lateral:
- this.append("laterals", lateral)
if join:
this.append("joins", join)
+
+ lateral = None
+ if not join:
+ lateral = self._parse_lateral()
+ if lateral:
+ this.append("laterals", lateral)
+
+ comma = None if table else self._match(TokenType.COMMA)
if comma:
this.args["from"].append("expressions", self._parse_table())
+
if not (lateral or join or comma):
break
@@ -1906,14 +1989,13 @@ class Parser(metaclass=_Parser):
def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.MATCH_RECOGNIZE):
return None
+
self._match_l_paren()
partition = self._parse_partition_by()
order = self._parse_order()
measures = (
- self._parse_alias(self._parse_conjunction())
- if self._match_text_seq("MEASURES")
- else None
+ self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None
)
if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
@@ -1967,8 +2049,17 @@ class Parser(metaclass=_Parser):
pattern = None
define = (
- self._parse_alias(self._parse_conjunction()) if self._match_text_seq("DEFINE") else None
+ self._parse_csv(
+ lambda: self.expression(
+ exp.Alias,
+ alias=self._parse_id_var(any_token=True),
+ this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
+ )
+ )
+ if self._match_text_seq("DEFINE")
+ else None
)
+
self._match_r_paren()
return self.expression(
@@ -1980,6 +2071,7 @@ class Parser(metaclass=_Parser):
after=after,
pattern=pattern,
define=define,
+ alias=self._parse_table_alias(),
)
def _parse_lateral(self) -> t.Optional[exp.Expression]:
@@ -2022,9 +2114,6 @@ class Parser(metaclass=_Parser):
alias=table_alias,
)
- if outer_apply or cross_apply:
- return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")
-
return expression
def _parse_join_side_and_kind(
@@ -2037,11 +2126,26 @@ class Parser(metaclass=_Parser):
)
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+ index = self._index
natural, side, kind = self._parse_join_side_and_kind()
+ hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
+ join = self._match(TokenType.JOIN)
- if not skip_join_token and not self._match(TokenType.JOIN):
+ if not skip_join_token and not join:
+ self._retreat(index)
+ kind = None
+ natural = None
+ side = None
+
+ outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False)
+ cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY, False)
+
+ if not skip_join_token and not join and not outer_apply and not cross_apply:
return None
+ if outer_apply:
+ side = Token(TokenType.LEFT, "LEFT")
+
kwargs: t.Dict[
str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
] = {"this": self._parse_table()}
@@ -2052,6 +2156,8 @@ class Parser(metaclass=_Parser):
kwargs["side"] = side.text
if kind:
kwargs["kind"] = kind.text
+ if hint:
+ kwargs["hint"] = hint
if self._match(TokenType.ON):
kwargs["on"] = self._parse_conjunction()
@@ -2179,7 +2285,7 @@ class Parser(metaclass=_Parser):
return None
expressions = self._parse_wrapped_csv(self._parse_column)
- ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
+ ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
alias = self._parse_table_alias()
if alias and self.unnest_column_only:
@@ -2191,7 +2297,7 @@ class Parser(metaclass=_Parser):
offset = None
if self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
- offset = self._parse_conjunction()
+ offset = self._parse_id_var() or exp.Identifier(this="offset")
return self.expression(
exp.Unnest,
@@ -2294,6 +2400,9 @@ class Parser(metaclass=_Parser):
else:
expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function()))
+ if not expressions:
+ self.raise_error("Failed to parse PIVOT's aggregation list")
+
if not self._match(TokenType.FOR):
self.raise_error("Expecting FOR")
@@ -2311,8 +2420,26 @@ class Parser(metaclass=_Parser):
if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
pivot.set("alias", self._parse_table_alias())
+ if not unpivot:
+ names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))
+
+ columns: t.List[exp.Expression] = []
+ for col in pivot.args["field"].expressions:
+ for name in names:
+ if self.PREFIXED_PIVOT_COLUMNS:
+ name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name
+ else:
+ name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name
+
+ columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS))
+
+ pivot.set("columns", columns)
+
return pivot
+ def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
+ return [agg.alias for agg in pivot_columns]
+
def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_where_token and not self._match(TokenType.WHERE):
return None
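
Pivot output columns now get deterministic names computed from the IN-list values and the aggregation aliases, with PREFIXED_PIVOT_COLUMNS, QUOTED_PIVOT_COLUMNS and _pivot_column_names acting as dialect hooks for ordering and quoting. A hedged sketch under the default suffix scheme:

    from sqlglot import exp, parse_one

    sql = "SELECT * FROM t PIVOT (SUM(v) AS total FOR k IN ('a', 'b'))"
    pivot = parse_one(sql).find(exp.Pivot)
    print([col.name for col in pivot.args["columns"]])  # e.g. ['a_total', 'b_total']
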
@@ -2433,10 +2560,25 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
+
count = self._parse_number()
+ percent = self._match(TokenType.PERCENT)
+
self._match_set((TokenType.ROW, TokenType.ROWS))
- self._match(TokenType.ONLY)
- return self.expression(exp.Fetch, direction=direction, count=count)
+
+ only = self._match(TokenType.ONLY)
+ with_ties = self._match_text_seq("WITH", "TIES")
+
+ if only and with_ties:
+ self.raise_error("Cannot specify both ONLY and WITH TIES in FETCH clause")
+
+ return self.expression(
+ exp.Fetch,
+ direction=direction,
+ count=count,
+ percent=percent,
+ with_ties=with_ties,
+ )
return this
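
FETCH now records PERCENT and WITH TIES instead of discarding them, and rejects the contradictory ONLY + WITH TIES combination. A sketch (assuming the generator emits the new flags):

    from sqlglot import parse_one

    sql = "SELECT * FROM t ORDER BY x FETCH FIRST 10 PERCENT ROWS WITH TIES"
    print(parse_one(sql).sql())
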
@@ -2493,7 +2635,11 @@ class Parser(metaclass=_Parser):
negate = self._match(TokenType.NOT)
if self._match_set(self.RANGE_PARSERS):
- this = self.RANGE_PARSERS[self._prev.token_type](self, this)
+ expression = self.RANGE_PARSERS[self._prev.token_type](self, this)
+ if not expression:
+ return this
+
+ this = expression
elif self._match(TokenType.ISNULL):
this = self.expression(exp.Is, this=this, expression=exp.Null())
@@ -2511,17 +2657,19 @@ class Parser(metaclass=_Parser):
return this
- def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression:
+ def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ index = self._index - 1
negate = self._match(TokenType.NOT)
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
- this = self.expression(
- exp.Is,
- this=this,
- expression=self._parse_null() or self._parse_boolean(),
- )
+ expression = self._parse_null() or self._parse_boolean()
+ if not expression:
+ self._retreat(index)
+ return None
+
+ this = self.expression(exp.Is, this=this, expression=expression)
return self.expression(exp.Not, this=this) if negate else this
def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression:
@@ -2553,6 +2701,27 @@ class Parser(metaclass=_Parser):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
+ def _parse_interval(self) -> t.Optional[exp.Expression]:
+ if not self._match(TokenType.INTERVAL):
+ return None
+
+ this = self._parse_primary() or self._parse_term()
+ unit = self._parse_function() or self._parse_var()
+
+ # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
+ # each INTERVAL expression into this canonical form so it's easy to transpile
+ if this and isinstance(this, exp.Literal):
+ if this.is_number:
+ this = exp.Literal.string(this.name)
+
+ # Try to not clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year'
+ parts = this.name.split()
+ if not unit and len(parts) <= 2:
+ this = exp.Literal.string(seq_get(parts, 0))
+ unit = self.expression(exp.Var, this=seq_get(parts, 1))
+
+ return self.expression(exp.Interval, this=this, unit=unit)
+
def _parse_bitwise(self) -> t.Optional[exp.Expression]:
this = self._parse_term()
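
Centralizing interval parsing lets every dialect canonicalize INTERVAL literals into the portable INTERVAL '<n>' <unit> shape described in the comment above. A sketch (outputs indicative):

    from sqlglot import parse_one

    print(parse_one("SELECT INTERVAL 5 day").sql())    # SELECT INTERVAL '5' day
    print(parse_one("SELECT INTERVAL '5 day'").sql())  # also split into value + unit
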
@@ -2588,20 +2757,24 @@ class Parser(metaclass=_Parser):
return self._parse_at_time_zone(self._parse_type())
def _parse_type(self) -> t.Optional[exp.Expression]:
- if self._match(TokenType.INTERVAL):
- return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_field())
+ interval = self._parse_interval()
+ if interval:
+ return interval
index = self._index
- type_token = self._parse_types(check_func=True)
+ data_type = self._parse_types(check_func=True)
this = self._parse_column()
- if type_token:
+ if data_type:
if isinstance(this, exp.Literal):
- return self.expression(exp.Cast, this=this, to=type_token)
- if not type_token.args.get("expressions"):
+ parser = self.TYPE_LITERAL_PARSERS.get(data_type.this)
+ if parser:
+ return parser(self, this, data_type)
+ return self.expression(exp.Cast, this=this, to=data_type)
+ if not data_type.args.get("expressions"):
self._retreat(index)
return self._parse_column()
- return type_token
+ return data_type
return this
@@ -2631,11 +2804,10 @@ class Parser(metaclass=_Parser):
else:
expressions = self._parse_csv(self._parse_conjunction)
- if not expressions:
+ if not expressions or not self._match(TokenType.R_PAREN):
self._retreat(index)
return None
- self._match_r_paren()
maybe_func = True
if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
@@ -2720,15 +2892,14 @@ class Parser(metaclass=_Parser):
)
def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
- if self._curr and self._curr.token_type in self.TYPE_TOKENS:
- return self._parse_types()
-
+ index = self._index
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
if not data_type:
- return None
+ self._retreat(index)
+ return self._parse_types()
return self.expression(exp.StructKwarg, this=this, expression=data_type)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
@@ -2825,6 +2996,7 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Paren, this=self._parse_set_operations(this))
self._match_r_paren()
+ comments.extend(self._prev_comments)
if this and comments:
this.comments = comments
@@ -2833,8 +3005,16 @@ class Parser(metaclass=_Parser):
return None
- def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]:
- return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
+ def _parse_field(
+ self,
+ any_token: bool = False,
+ tokens: t.Optional[t.Collection[TokenType]] = None,
+ ) -> t.Optional[exp.Expression]:
+ return (
+ self._parse_primary()
+ or self._parse_function()
+ or self._parse_id_var(any_token=any_token, tokens=tokens)
+ )
def _parse_function(
self, functions: t.Optional[t.Dict[str, t.Callable]] = None
@@ -3079,12 +3259,10 @@ class Parser(metaclass=_Parser):
return None
def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
- this = self._parse_references()
- if this:
- return this
-
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
+ else:
+ this = None
if self._match_texts(self.CONSTRAINT_PARSERS):
return self.expression(
@@ -3164,8 +3342,8 @@ class Parser(metaclass=_Parser):
return options
- def _parse_references(self) -> t.Optional[exp.Expression]:
- if not self._match(TokenType.REFERENCES):
+ def _parse_references(self, match=True) -> t.Optional[exp.Expression]:
+ if match and not self._match(TokenType.REFERENCES):
return None
expressions = None
@@ -3234,7 +3412,7 @@ class Parser(metaclass=_Parser):
elif not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
else:
- expressions = apply_index_offset(expressions, -self.index_offset)
+ expressions = apply_index_offset(this, expressions, -self.index_offset)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
@@ -3279,7 +3457,13 @@ class Parser(metaclass=_Parser):
self.validate_expression(this, args)
self._match_r_paren()
else:
+ index = self._index - 1
condition = self._parse_conjunction()
+
+ if not condition:
+ self._retreat(index)
+ return None
+
self._match(TokenType.THEN)
true = self._parse_conjunction()
false = self._parse_conjunction() if self._match(TokenType.ELSE) else None
@@ -3591,14 +3775,24 @@ class Parser(metaclass=_Parser):
# bigquery select from window x AS (partition by ...)
if alias:
+ over = None
self._match(TokenType.ALIAS)
- elif not self._match(TokenType.OVER):
+ elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS):
return this
+ else:
+ over = self._prev.text.upper()
if not self._match(TokenType.L_PAREN):
- return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
+ return self.expression(
+ exp.Window, this=this, alias=self._parse_id_var(False), over=over
+ )
window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)
+
+ first = self._match(TokenType.FIRST)
+ if self._match_text_seq("LAST"):
+ first = False
+
partition = self._parse_partition_by()
order = self._parse_order()
kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text
@@ -3629,6 +3823,8 @@ class Parser(metaclass=_Parser):
order=order,
spec=spec,
alias=window_alias,
+ over=over,
+ first=first,
)
def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
@@ -3886,7 +4082,10 @@ class Parser(metaclass=_Parser):
return expression
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
- return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
+ drop = self._match(TokenType.DROP) and self._parse_drop()
+ if drop and not isinstance(drop, exp.Command):
+ drop.set("kind", drop.args.get("kind", "COLUMN"))
+ return drop
# https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html
def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression:
@@ -4010,7 +4209,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.INSERT):
_this = self._parse_star()
if _this:
- then = self.expression(exp.Insert, this=_this)
+ then: t.Optional[exp.Expression] = self.expression(exp.Insert, this=_this)
else:
then = self.expression(
exp.Insert,
@@ -4239,5 +4438,8 @@ class Parser(metaclass=_Parser):
break
parent = parent.parent
else:
- column.replace(dot_or_id)
+ if column is node:
+ node = dot_or_id
+ else:
+ column.replace(dot_or_id)
return node
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 8e39c7f..5d60eb9 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,7 +5,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot.errors import SchemaError
+from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
@@ -75,12 +75,11 @@ class AbstractMappingSchema(t.Generic[T]):
mapping: dict | None = None,
) -> None:
self.mapping = mapping or {}
- self.mapping_trie = self._build_trie(self.mapping)
+ self.mapping_trie = new_trie(
+ tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
+ )
self._supported_table_args: t.Tuple[str, ...] = tuple()
- def _build_trie(self, schema: t.Dict) -> t.Dict:
- return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
-
def _depth(self) -> int:
return dict_depth(self.mapping)
@@ -179,6 +178,64 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
}
)
+ def add_table(
+ self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+ ) -> None:
+ """
+ Register or update a table. Updates are only performed if a new column mapping is provided.
+
+ Args:
+ table: the `Table` expression instance or string representing the table.
+ column_mapping: a column mapping that describes the structure of the table.
+ """
+ normalized_table = self._normalize_table(self._ensure_table(table))
+ normalized_column_mapping = {
+ self._normalize_name(key): value
+ for key, value in ensure_column_mapping(column_mapping).items()
+ }
+
+ schema = self.find(normalized_table, raise_on_missing=False)
+ if schema and not normalized_column_mapping:
+ return
+
+ parts = self.table_parts(normalized_table)
+
+ _nested_set(
+ self.mapping,
+ tuple(reversed(parts)),
+ normalized_column_mapping,
+ )
+ new_trie([parts], self.mapping_trie)
+
+ def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
+ table_ = self._normalize_table(self._ensure_table(table))
+ schema = self.find(table_)
+
+ if schema is None:
+ return []
+
+ if not only_visible or not self.visible:
+ return list(schema)
+
+ visible = self._nested_get(self.table_parts(table_), self.visible)
+ return [col for col in schema if col in visible] # type: ignore
+
+ def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
+ column_name = self._normalize_name(column if isinstance(column, str) else column.this)
+ table_ = self._normalize_table(self._ensure_table(table))
+
+ table_schema = self.find(table_, raise_on_missing=False)
+ if table_schema:
+ column_type = table_schema.get(column_name)
+
+ if isinstance(column_type, exp.DataType):
+ return column_type
+ elif isinstance(column_type, str):
+ return self._to_data_type(column_type.upper())
+ raise SchemaError(f"Unknown column type '{column_type}'")
+
+ return exp.DataType.build("unknown")
+
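
The rewritten MappingSchema methods normalize identifiers up front and, via the new trie argument to new_trie, extend the existing trie on add_table instead of rebuilding it. A small usage sketch:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"t": {"id": "INT"}}})
    schema.add_table("db.s", {"name": "VARCHAR"})   # incremental registration

    print(schema.column_names("db.s"))              # ['name']
    print(schema.get_column_type("db.t", "id"))     # DataType node for INT
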
def _normalize(self, schema: t.Dict) -> t.Dict:
"""
Converts all identifiers in the schema into lowercase, unless they're quoted.
@@ -206,84 +263,37 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return normalized_mapping
- def add_table(
- self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
- ) -> None:
- """
- Register or update a table. Updates are only performed if a new column mapping is provided.
+ def _normalize_table(self, table: exp.Table) -> exp.Table:
+ normalized_table = table.copy()
+ for arg in TABLE_ARGS:
+ value = normalized_table.args.get(arg)
+ if isinstance(value, (str, exp.Identifier)):
+ normalized_table.set(arg, self._normalize_name(value))
- Args:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- """
- table_ = self._ensure_table(table)
- column_mapping = ensure_column_mapping(column_mapping)
- schema = self.find(table_, raise_on_missing=False)
-
- if schema and not column_mapping:
- return
-
- _nested_set(
- self.mapping,
- list(reversed(self.table_parts(table_))),
- column_mapping,
- )
- self.mapping_trie = self._build_trie(self.mapping)
+ return normalized_table
- def _normalize_name(self, name: str) -> str:
+ def _normalize_name(self, name: str | exp.Identifier) -> str:
try:
- identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
- name, read=self.dialect, into=exp.Identifier
- )
- except:
- identifier = exp.to_identifier(name)
- assert isinstance(identifier, exp.Identifier)
+ identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
+ except ParseError:
+ return name if isinstance(name, str) else name.name
- if identifier.quoted:
- return identifier.name
- return identifier.name.lower()
+ return identifier.name if identifier.quoted else identifier.name.lower()
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
- table_ = exp.to_table(table)
+ if isinstance(table, exp.Table):
+ return table
+ table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
if not table_:
raise SchemaError(f"Not a valid table '{table}'")
return table_
- def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
- table_ = self._ensure_table(table)
- schema = self.find(table_)
-
- if schema is None:
- return []
-
- if not only_visible or not self.visible:
- return list(schema)
-
- visible = self._nested_get(self.table_parts(table_), self.visible)
- return [col for col in schema if col in visible] # type: ignore
-
- def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
- column_name = column if isinstance(column, str) else column.name
- table_ = exp.to_table(table)
- if table_:
- table_schema = self.find(table_, raise_on_missing=False)
- if table_schema:
- column_type = table_schema.get(column_name)
-
- if isinstance(column_type, exp.DataType):
- return column_type
- elif isinstance(column_type, str):
- return self._to_data_type(column_type.upper())
- raise SchemaError(f"Unknown column type '{column_type}'")
- return exp.DataType(this=exp.DataType.Type.UNKNOWN)
- raise SchemaError(f"Could not convert table '{table}'")
-
def _to_data_type(self, schema_type: str) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
@@ -313,7 +323,7 @@ def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
return MappingSchema(schema, dialect=dialect)
-def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
+def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
if isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
@@ -371,7 +381,7 @@ def _nested_get(
return d
-def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
+def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
"""
In-place set a value for a nested dictionary
@@ -384,11 +394,11 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
Args:
d: dictionary to update.
- keys: the keys that makeup the path to `value`.
- value: the value to set in the dictionary for the given key path.
+ keys: the keys that makeup the path to `value`.
+ value: the value to set in the dictionary for the given key path.
- Returns:
- The (possibly) updated dictionary.
+ Returns:
+ The (possibly) updated dictionary.
"""
if not keys:
return d
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index cf2e31f..64c1f92 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -87,6 +87,7 @@ class TokenType(AutoName):
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
+ BIGDECIMAL = auto()
CHAR = auto()
NCHAR = auto()
VARCHAR = auto()
@@ -214,6 +215,7 @@ class TokenType(AutoName):
ISNULL = auto()
JOIN = auto()
JOIN_MARKER = auto()
+ KEEP = auto()
LANGUAGE = auto()
LATERAL = auto()
LAZY = auto()
@@ -231,6 +233,7 @@ class TokenType(AutoName):
MOD = auto()
NATURAL = auto()
NEXT = auto()
+ NEXT_VALUE_FOR = auto()
NO_ACTION = auto()
NOTNULL = auto()
NULL = auto()
@@ -315,7 +318,7 @@ class TokenType(AutoName):
class Token:
- __slots__ = ("token_type", "text", "line", "col", "comments")
+ __slots__ = ("token_type", "text", "line", "col", "end", "comments")
@classmethod
def number(cls, number: int) -> Token:
@@ -343,22 +346,29 @@ class Token:
text: str,
line: int = 1,
col: int = 1,
+ end: int = 0,
comments: t.List[str] = [],
) -> None:
self.token_type = token_type
self.text = text
self.line = line
- self.col = col - len(text)
- self.col = self.col if self.col > 1 else 1
+ size = len(text)
+ self.col = col
+ self.end = end if end else size
self.comments = comments
+ @property
+ def start(self) -> int:
+ """Returns the start of the token."""
+ return self.end - len(self.text)
+
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"
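
Tokens now carry absolute character offsets, which is what lets the parser drop its O(n) _find_token scan in favor of constant-time start/end lookups. Sketch:

    from sqlglot.tokens import Tokenizer

    token = Tokenizer().tokenize("SELECT foo")[1]
    print(token.text, token.start, token.end)  # foo 7 10
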
class _Tokenizer(type):
- def __new__(cls, clsname, bases, attrs): # type: ignore
+ def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = {
@@ -433,25 +443,25 @@ class Tokenizer(metaclass=_Tokenizer):
"#": TokenType.HASH,
}
- QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
-
BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
- HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
-
+ HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
-
- STRING_ESCAPES = ["'"]
-
- _STRING_ESCAPES: t.Set[str] = set()
-
IDENTIFIER_ESCAPES = ['"']
+ QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
+ STRING_ESCAPES = ["'"]
+ VAR_SINGLE_TOKENS: t.Set[str] = set()
+ _COMMENTS: t.Dict[str, str] = {}
+ _BIT_STRINGS: t.Dict[str, str] = {}
+ _BYTE_STRINGS: t.Dict[str, str] = {}
+ _HEX_STRINGS: t.Dict[str, str] = {}
+ _IDENTIFIERS: t.Dict[str, str] = {}
_IDENTIFIER_ESCAPES: t.Set[str] = set()
+ _QUOTES: t.Dict[str, str] = {}
+ _STRING_ESCAPES: t.Set[str] = set()
- KEYWORDS = {
+ KEYWORDS: t.Dict[t.Optional[str], TokenType] = {
**{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
**{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
"{{+": TokenType.BLOCK_START,
@@ -553,6 +563,7 @@ class Tokenizer(metaclass=_Tokenizer):
"IS": TokenType.IS,
"ISNULL": TokenType.ISNULL,
"JOIN": TokenType.JOIN,
+ "KEEP": TokenType.KEEP,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING,
@@ -565,6 +576,7 @@ class Tokenizer(metaclass=_Tokenizer):
"MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
+ "NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR,
"NO ACTION": TokenType.NO_ACTION,
"NOT": TokenType.NOT,
"NOTNULL": TokenType.NOTNULL,
@@ -632,6 +644,7 @@ class Tokenizer(metaclass=_Tokenizer):
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
+ "UUID": TokenType.UUID,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE,
@@ -661,6 +674,8 @@ class Tokenizer(metaclass=_Tokenizer):
"INT8": TokenType.BIGINT,
"DEC": TokenType.DECIMAL,
"DECIMAL": TokenType.DECIMAL,
+ "BIGDECIMAL": TokenType.BIGDECIMAL,
+ "BIGNUMERIC": TokenType.BIGDECIMAL,
"MAP": TokenType.MAP,
"NULLABLE": TokenType.NULLABLE,
"NUMBER": TokenType.DECIMAL,
@@ -742,7 +757,7 @@ class Tokenizer(metaclass=_Tokenizer):
ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
- KEYWORD_TRIE = None # autofilled
+ KEYWORD_TRIE: t.Dict = {} # autofilled
IDENTIFIER_CAN_START_WITH_DIGIT = False
@@ -776,19 +791,28 @@ class Tokenizer(metaclass=_Tokenizer):
self._col = 1
self._comments: t.List[str] = []
- self._char = None
- self._end = None
- self._peek = None
+ self._char = ""
+ self._end = False
+ self._peek = ""
self._prev_token_line = -1
self._prev_token_comments: t.List[str] = []
- self._prev_token_type = None
+ self._prev_token_type: t.Optional[TokenType] = None
def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
self.reset()
self.sql = sql
self.size = len(sql)
- self._scan()
+ try:
+ self._scan()
+ except Exception as e:
+ start = self._current - 50
+ end = self._current + 50
+ start = start if start > 0 else 0
+ end = end if end < self.size else self.size - 1
+ context = self.sql[start:end]
+ raise ValueError(f"Error tokenizing '{context}'") from e
+
return self.tokens
def _scan(self, until: t.Optional[t.Callable] = None) -> None:
@@ -810,9 +834,12 @@ class Tokenizer(metaclass=_Tokenizer):
if until and until():
break
+ if self.tokens:
+ self.tokens[-1].comments.extend(self._comments)
+
def _chars(self, size: int) -> str:
if size == 1:
- return self._char # type: ignore
+ return self._char
start = self._current - 1
end = start + size
if end <= self.size:
@@ -821,17 +848,15 @@ class Tokenizer(metaclass=_Tokenizer):
def _advance(self, i: int = 1) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
- self._set_new_line()
+ self._col = 1
+ self._line += 1
+ else:
+ self._col += i
- self._col += i
self._current += i
- self._end = self._current >= self.size # type: ignore
- self._char = self.sql[self._current - 1] # type: ignore
- self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore
-
- def _set_new_line(self) -> None:
- self._col = 1
- self._line += 1
+ self._end = self._current >= self.size
+ self._char = self.sql[self._current - 1]
+ self._peek = "" if self._end else self.sql[self._current]
@property
def _text(self) -> str:
@@ -840,13 +865,14 @@ class Tokenizer(metaclass=_Tokenizer):
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self._prev_token_comments = self._comments
- self._prev_token_type = token_type # type: ignore
+ self._prev_token_type = token_type
self.tokens.append(
Token(
token_type,
self._text if text is None else text,
self._line,
self._col,
+ self._current,
self._comments,
)
)
@@ -881,7 +907,7 @@ class Tokenizer(metaclass=_Tokenizer):
if skip:
result = 1
else:
- result, trie = in_trie(trie, char.upper()) # type: ignore
+ result, trie = in_trie(trie, char.upper())
if result == 0:
break
@@ -910,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
if not word:
if self._char in self.SINGLE_TOKENS:
- self._add(self.SINGLE_TOKENS[self._char], text=self._char) # type: ignore
+ self._add(self.SINGLE_TOKENS[self._char], text=self._char)
return
self._scan_var()
return
@@ -927,29 +953,31 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(self.KEYWORDS[word], text=word)
def _scan_comment(self, comment_start: str) -> bool:
- if comment_start not in self._COMMENTS: # type: ignore
+ if comment_start not in self._COMMENTS:
return False
comment_start_line = self._line
comment_start_size = len(comment_start)
- comment_end = self._COMMENTS[comment_start] # type: ignore
+ comment_end = self._COMMENTS[comment_start]
if comment_end:
- comment_end_size = len(comment_end)
+ # Skip the comment's start delimiter
+ self._advance(comment_start_size)
+ comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
- self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
+ self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
self._advance()
- self._comments.append(self._text[comment_start_size:]) # type: ignore
+ self._comments.append(self._text[comment_start_size:])
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
# Multiple consecutive comments are preserved by appending them to the current comments list.
- if comment_start_line == self._prev_token_line or self._end:
+ if comment_start_line == self._prev_token_line:
self.tokens[-1].comments.extend(self._comments)
self._comments = []
self._prev_token_line = self._line
@@ -958,7 +986,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _scan_number(self) -> None:
if self._char == "0":
- peek = self._peek.upper() # type: ignore
+ peek = self._peek.upper()
if peek == "B":
return self._scan_bits()
elif peek == "X":
@@ -968,7 +996,7 @@ class Tokenizer(metaclass=_Tokenizer):
scientific = 0
while True:
- if self._peek.isdigit(): # type: ignore
+ if self._peek.isdigit():
self._advance()
elif self._peek == "." and not decimal:
decimal = True
@@ -976,24 +1004,23 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
- elif self._peek.upper() == "E" and not scientific: # type: ignore
+ elif self._peek.upper() == "E" and not scientific:
scientific += 1
self._advance()
- elif self._peek.isidentifier(): # type: ignore
+ elif self._peek.isidentifier():
number_text = self._text
- literal = []
+ literal = ""
- while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: # type: ignore
- literal.append(self._peek.upper()) # type: ignore
+ while self._peek.strip() and self._peek not in self.SINGLE_TOKENS:
+ literal += self._peek.upper()
self._advance()
- literal = "".join(literal) # type: ignore
- token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
+ token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
if token_type:
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
- return self._add(token_type, literal) # type: ignore
+ return self._add(token_type, literal)
elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
@@ -1020,7 +1047,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _extract_value(self) -> str:
while True:
- char = self._peek.strip() # type: ignore
+ char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
@@ -1029,35 +1056,35 @@ class Tokenizer(metaclass=_Tokenizer):
return self._text
def _scan_string(self, quote: str) -> bool:
- quote_end = self._QUOTES.get(quote) # type: ignore
+ quote_end = self._QUOTES.get(quote)
if quote_end is None:
return False
self._advance(len(quote))
text = self._extract_string(quote_end)
- text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
+ text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
# X'1234, b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start: str) -> bool:
- if string_start in self._HEX_STRINGS: # type: ignore
- delimiters = self._HEX_STRINGS # type: ignore
+ if string_start in self._HEX_STRINGS:
+ delimiters = self._HEX_STRINGS
token_type = TokenType.HEX_STRING
base = 16
- elif string_start in self._BIT_STRINGS: # type: ignore
- delimiters = self._BIT_STRINGS # type: ignore
+ elif string_start in self._BIT_STRINGS:
+ delimiters = self._BIT_STRINGS
token_type = TokenType.BIT_STRING
base = 2
- elif string_start in self._BYTE_STRINGS: # type: ignore
- delimiters = self._BYTE_STRINGS # type: ignore
+ elif string_start in self._BYTE_STRINGS:
+ delimiters = self._BYTE_STRINGS
token_type = TokenType.BYTE_STRING
base = None
else:
return False
self._advance(len(string_start))
- string_end = delimiters.get(string_start)
+ string_end = delimiters[string_start]
text = self._extract_string(string_end)
if base is None:
@@ -1083,20 +1110,20 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
if self._char == identifier_end:
if identifier_end_is_escape and self._peek == identifier_end:
- text += identifier_end # type: ignore
+ text += identifier_end
self._advance()
continue
break
- text += self._char # type: ignore
+ text += self._char
self._add(TokenType.IDENTIFIER, text)
def _scan_var(self) -> None:
while True:
- char = self._peek.strip() # type: ignore
- if char and char not in self.SINGLE_TOKENS:
+ char = self._peek.strip()
+ if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
self._advance()
else:
break
@@ -1115,9 +1142,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._peek == delimiter or self._peek in self._STRING_ESCAPES
):
if self._peek == delimiter:
- text += self._peek # type: ignore
+ text += self._peek
else:
- text += self._char + self._peek # type: ignore
+ text += self._char + self._peek
if self._current + 1 < self.size:
self._advance(2)
@@ -1131,7 +1158,7 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
- text += self._char # type: ignore
+ text += self._char
self._advance()
return text
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 62728d5..00f278e 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -103,7 +103,11 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
expression.select(exp.alias_(expr.copy(), alias), copy=False)
- expr.replace(exp.column(alias))
+ column = exp.column(alias)
+ if isinstance(expr.parent, exp.Qualify):
+ qualify_filters = column
+ else:
+ expr.replace(column)
elif expr.name not in expression.named_selects:
expression.select(expr.copy(), copy=False)
@@ -133,9 +137,111 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
)
+def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
+ """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
+ if isinstance(expression, exp.Select):
+ for join in expression.args.get("joins") or []:
+ unnest = join.this
+
+ if isinstance(unnest, exp.Unnest):
+ alias = unnest.args.get("alias")
+ udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
+
+ expression.args["joins"].remove(join)
+
+ for e, column in zip(unnest.expressions, alias.columns if alias else []):
+ expression.append(
+ "laterals",
+ exp.Lateral(
+ this=udtf(this=e),
+ view=True,
+ alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
+ ),
+ )
+ return expression
+
+
+def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
+ """Convert explode/posexplode into unnest (used in hive -> presto)."""
+ if isinstance(expression, exp.Select):
+ from sqlglot.optimizer.scope import build_scope
+
+ taken_select_names = set(expression.named_selects)
+ taken_source_names = set(build_scope(expression).selected_sources)
+
+ for select in expression.selects:
+ to_replace = select
+
+ pos_alias = ""
+ explode_alias = ""
+
+ if isinstance(select, exp.Alias):
+ explode_alias = select.alias
+ select = select.this
+ elif isinstance(select, exp.Aliases):
+ pos_alias = select.aliases[0].name
+ explode_alias = select.aliases[1].name
+ select = select.this
+
+ if isinstance(select, (exp.Explode, exp.Posexplode)):
+ is_posexplode = isinstance(select, exp.Posexplode)
+
+ explode_arg = select.this
+ unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
+
+ # This ensures that we won't use [POS]EXPLODE's argument as a new selection
+ if isinstance(explode_arg, exp.Column):
+ taken_select_names.add(explode_arg.output_name)
+
+ unnest_source_alias = find_new_name(taken_source_names, "_u")
+ taken_source_names.add(unnest_source_alias)
+
+ if not explode_alias:
+ explode_alias = find_new_name(taken_select_names, "col")
+ taken_select_names.add(explode_alias)
+
+ if is_posexplode:
+ pos_alias = find_new_name(taken_select_names, "pos")
+ taken_select_names.add(pos_alias)
+
+ if is_posexplode:
+ column_names = [explode_alias, pos_alias]
+ to_replace.pop()
+ expression.select(pos_alias, explode_alias, copy=False)
+ else:
+ column_names = [explode_alias]
+ to_replace.replace(exp.column(explode_alias))
+
+ unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
+
+ if not expression.args.get("from"):
+ expression.from_(unnest, copy=False)
+ else:
+ expression.join(unnest, join_type="CROSS", copy=False)
+
+ return expression
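
Together with unnest_to_explode above, this transform powers the hive <-> presto round trip for lateral views, per the two docstrings. A hedged end-to-end sketch (the generated aliases such as _u and col may vary):

    import sqlglot

    # hive -> presto: EXPLODE becomes CROSS JOIN UNNEST,
    # e.g. SELECT col FROM t CROSS JOIN UNNEST(ary) AS _u(col)
    print(sqlglot.transpile("SELECT EXPLODE(ary) FROM t", read="hive", write="presto")[0])

    # presto -> hive: CROSS JOIN UNNEST becomes LATERAL VIEW EXPLODE
    sql = "SELECT c FROM t CROSS JOIN UNNEST(ary) AS u(c)"
    print(sqlglot.transpile(sql, read="presto", write="hive")[0])
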
+
+
+def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
+ """Remove table refs from columns in when statements."""
+ if isinstance(expression, exp.Merge):
+ alias = expression.this.args.get("alias")
+ targets = {expression.this.this}
+ if alias:
+ targets.add(alias.this)
+
+ for when in expression.expressions:
+ when.transform(
+ lambda node: exp.column(node.name)
+ if isinstance(node, exp.Column) and node.args.get("table") in targets
+ else node,
+ copy=False,
+ )
+ return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
- to_sql: t.Callable[[Generator, exp.Expression], str],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
@@ -143,36 +249,23 @@ def preprocess(
Args:
transforms: sequence of transform functions. These will be called in order.
- to_sql: final transform that converts the resulting expression to a SQL string.
Returns:
Function that can be used as a generator transform.
"""
- def _to_sql(self, expression):
+ def _to_sql(self, expression: exp.Expression) -> str:
expression = transforms[0](expression.copy())
for t in transforms[1:]:
expression = t(expression)
- return to_sql(self, expression)
+ return getattr(self, expression.key + "_sql")(expression)
return _to_sql
-def delegate(attr: str) -> t.Callable:
- """
- Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS`
- functions that delegate to existing generator methods.
- """
-
- def _transform(self, *args, **kwargs):
- return getattr(self, attr)(*args, **kwargs)
-
- return _transform
-
-
-UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
-ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
-ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify], delegate("select_sql"))}
+UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])}
+ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])}
+ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
- exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
+ exp.Cast: preprocess([remove_precision_parameterized_types])
}
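
preprocess no longer needs an explicit to_sql callback: after running the chained transforms it dispatches on expression.key, so an exp.Select ends up in select_sql automatically. Registering a transform therefore shrinks to a one-liner; a hedged sketch for a hypothetical custom dialect:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect
    from sqlglot.generator import Generator
    from sqlglot.transforms import eliminate_distinct_on, preprocess

    class MyDialect(Dialect):
        class Generator(Generator):
            # eliminate_distinct_on rewrites the Select, then the output SQL
            # is produced by the method named after expression.key ("select").
            TRANSFORMS = {exp.Select: preprocess([eliminate_distinct_on])}
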
diff --git a/sqlglot/trie.py b/sqlglot/trie.py
index f3b1c38..eba91b9 100644
--- a/sqlglot/trie.py
+++ b/sqlglot/trie.py
@@ -3,7 +3,7 @@ import typing as t
key = t.Sequence[t.Hashable]
-def new_trie(keywords: t.Iterable[key]) -> t.Dict:
+def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Dict:
"""
Creates a new trie out of a collection of keywords.
@@ -16,11 +16,12 @@ def new_trie(keywords: t.Iterable[key]) -> t.Dict:
Args:
keywords: the keywords to create the trie from.
+ trie: a trie to mutate instead of creating a new one
Returns:
The trie corresponding to `keywords`.
"""
- trie: t.Dict = {}
+ trie = {} if trie is None else trie
for key in keywords:
current = trie
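
Allowing new_trie to mutate an existing trie is what makes MappingSchema.add_table incremental: one key is inserted instead of rebuilding the whole structure. Sketch:

    from sqlglot.trie import in_trie, new_trie

    trie = new_trie([("c", "a", "t")])
    new_trie([("c", "a", "r")], trie)      # extends the same trie in place

    print(in_trie(trie, ("c", "a", "r")))  # (2, ...), an exact match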