From 639a208fa57ea674d165c4837e96f3ae4d7e3e61 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 19 Feb 2023 14:45:09 +0100 Subject: Merging upstream version 11.1.3. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dataframe/sql/column.py | 1 + sqlglot/dataframe/sql/functions.py | 235 ++++++++++-------- sqlglot/dialects/bigquery.py | 30 ++- sqlglot/dialects/databricks.py | 9 + sqlglot/dialects/dialect.py | 22 +- sqlglot/dialects/drill.py | 21 +- sqlglot/dialects/duckdb.py | 20 +- sqlglot/dialects/hive.py | 14 +- sqlglot/dialects/mysql.py | 12 +- sqlglot/dialects/oracle.py | 62 ++++- sqlglot/dialects/postgres.py | 17 +- sqlglot/dialects/presto.py | 7 +- sqlglot/dialects/redshift.py | 12 +- sqlglot/dialects/snowflake.py | 36 ++- sqlglot/dialects/spark.py | 17 +- sqlglot/dialects/sqlite.py | 2 +- sqlglot/dialects/teradata.py | 45 +++- sqlglot/dialects/tsql.py | 6 +- sqlglot/executor/__init__.py | 4 +- sqlglot/executor/python.py | 2 +- sqlglot/expressions.py | 233 +++++++++++++++--- sqlglot/generator.py | 264 +++++++++++--------- sqlglot/optimizer/annotate_types.py | 3 + sqlglot/optimizer/eliminate_subqueries.py | 6 +- sqlglot/optimizer/optimizer.py | 31 ++- sqlglot/optimizer/pushdown_projections.py | 76 +++--- sqlglot/optimizer/qualify_columns.py | 48 ++-- sqlglot/optimizer/qualify_tables.py | 2 +- sqlglot/optimizer/scope.py | 169 ++++++++----- sqlglot/parser.py | 390 +++++++++++++++++++++--------- sqlglot/tokens.py | 40 ++- 32 files changed, 1213 insertions(+), 625 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 7b07ae1..c17a703 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -40,7 +40,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "11.0.1" +__version__ = "11.1.3" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index f5b0974..609b2a4 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -143,6 +143,7 @@ class Column: if is_iterable(v) else Column.ensure_col(v).expression for k, v in kwargs.items() + if v is not None } new_expression = ( callable_expression(**ensure_expression_values) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 47d5e7b..0262d54 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import expressions as glotexp +from sqlglot import exp as expression from sqlglot.dataframe.sql.column import Column from sqlglot.helper import ensure_list from sqlglot.helper import flatten as _flatten @@ -18,25 +18,29 @@ def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column: def lit(value: t.Optional[t.Any] = None) -> Column: if isinstance(value, str): - return Column(glotexp.Literal.string(str(value))) + return Column(expression.Literal.string(str(value))) return Column(value) def greatest(*cols: ColumnOrName) -> Column: if len(cols) > 1: - return Column.invoke_expression_over_column(cols[0], glotexp.Greatest, expressions=cols[1:]) - return Column.invoke_expression_over_column(cols[0], glotexp.Greatest) + return Column.invoke_expression_over_column( + cols[0], expression.Greatest, expressions=cols[1:] + ) + return Column.invoke_expression_over_column(cols[0], expression.Greatest) def least(*cols: ColumnOrName) -> Column: if len(cols) > 1: - return 
Column.invoke_expression_over_column(cols[0], glotexp.Least, expressions=cols[1:]) - return Column.invoke_expression_over_column(cols[0], glotexp.Least) + return Column.invoke_expression_over_column(cols[0], expression.Least, expressions=cols[1:]) + return Column.invoke_expression_over_column(cols[0], expression.Least) def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: columns = [Column.ensure_col(x) for x in [col] + list(cols)] - return Column(glotexp.Count(this=glotexp.Distinct(expressions=[x.expression for x in columns]))) + return Column( + expression.Count(this=expression.Distinct(expressions=[x.expression for x in columns])) + ) def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: @@ -46,8 +50,8 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: def when(condition: Column, value: t.Any) -> Column: true_value = value if isinstance(value, Column) else lit(value) return Column( - glotexp.Case( - ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)] + expression.Case( + ifs=[expression.If(this=condition.column_expression, true=true_value.column_expression)] ) ) @@ -65,19 +69,19 @@ def broadcast(df: DataFrame) -> DataFrame: def sqrt(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Sqrt) + return Column.invoke_expression_over_column(col, expression.Sqrt) def abs(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Abs) + return Column.invoke_expression_over_column(col, expression.Abs) def max(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Max) + return Column.invoke_expression_over_column(col, expression.Max) def min(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Min) + return Column.invoke_expression_over_column(col, expression.Min) def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column: @@ -89,15 +93,15 @@ def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column: def count(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Count) + return Column.invoke_expression_over_column(col, expression.Count) def sum(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Sum) + return Column.invoke_expression_over_column(col, expression.Sum) def avg(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Avg) + return Column.invoke_expression_over_column(col, expression.Avg) def mean(col: ColumnOrName) -> Column: @@ -149,7 +153,7 @@ def cbrt(col: ColumnOrName) -> Column: def ceil(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Ceil) + return Column.invoke_expression_over_column(col, expression.Ceil) def cos(col: ColumnOrName) -> Column: @@ -169,7 +173,7 @@ def csc(col: ColumnOrName) -> Column: def exp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Exp) + return Column.invoke_expression_over_column(col, expression.Exp) def expm1(col: ColumnOrName) -> Column: @@ -177,11 +181,11 @@ def expm1(col: ColumnOrName) -> Column: def floor(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Floor) + return Column.invoke_expression_over_column(col, expression.Floor) def log10(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Log10) + return Column.invoke_expression_over_column(col, expression.Log10) def 
log1p(col: ColumnOrName) -> Column: @@ -189,13 +193,13 @@ def log1p(col: ColumnOrName) -> Column: def log2(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Log2) + return Column.invoke_expression_over_column(col, expression.Log2) def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column: if arg2 is None: - return Column.invoke_expression_over_column(arg1, glotexp.Ln) - return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=arg2) + return Column.invoke_expression_over_column(arg1, expression.Ln) + return Column.invoke_expression_over_column(arg1, expression.Log, expression=arg2) def rint(col: ColumnOrName) -> Column: @@ -247,7 +251,7 @@ def bitwiseNOT(col: ColumnOrName) -> Column: def bitwise_not(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.BitwiseNot) + return Column.invoke_expression_over_column(col, expression.BitwiseNot) def asc_nulls_first(col: ColumnOrName) -> Column: @@ -267,27 +271,27 @@ def desc_nulls_last(col: ColumnOrName) -> Column: def stddev(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Stddev) + return Column.invoke_expression_over_column(col, expression.Stddev) def stddev_samp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.StddevSamp) + return Column.invoke_expression_over_column(col, expression.StddevSamp) def stddev_pop(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.StddevPop) + return Column.invoke_expression_over_column(col, expression.StddevPop) def variance(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Variance) + return Column.invoke_expression_over_column(col, expression.Variance) def var_samp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Variance) + return Column.invoke_expression_over_column(col, expression.Variance) def var_pop(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.VariancePop) + return Column.invoke_expression_over_column(col, expression.VariancePop) def skewness(col: ColumnOrName) -> Column: @@ -299,11 +303,11 @@ def kurtosis(col: ColumnOrName) -> Column: def collect_list(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.ArrayAgg) + return Column.invoke_expression_over_column(col, expression.ArrayAgg) def collect_set(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.SetAgg) + return Column.invoke_expression_over_column(col, expression.SetAgg) def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: @@ -311,27 +315,27 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float] def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_expression_over_column(col1, glotexp.Pow, expression=col2) + return Column.invoke_expression_over_column(col1, expression.Pow, expression=col2) def row_number() -> Column: - return Column(glotexp.Anonymous(this="ROW_NUMBER")) + return Column(expression.Anonymous(this="ROW_NUMBER")) def dense_rank() -> Column: - return Column(glotexp.Anonymous(this="DENSE_RANK")) + return Column(expression.Anonymous(this="DENSE_RANK")) def rank() -> Column: - return Column(glotexp.Anonymous(this="RANK")) + return Column(expression.Anonymous(this="RANK")) def 
cume_dist() -> Column: - return Column(glotexp.Anonymous(this="CUME_DIST")) + return Column(expression.Anonymous(this="CUME_DIST")) def percent_rank() -> Column: - return Column(glotexp.Anonymous(this="PERCENT_RANK")) + return Column(expression.Anonymous(this="PERCENT_RANK")) def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: @@ -340,14 +344,16 @@ def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Col def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: if rsd is None: - return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct) - return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=rsd) + return Column.invoke_expression_over_column(col, expression.ApproxDistinct) + return Column.invoke_expression_over_column(col, expression.ApproxDistinct, accuracy=rsd) def coalesce(*cols: ColumnOrName) -> Column: if len(cols) > 1: - return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce, expressions=cols[1:]) - return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce) + return Column.invoke_expression_over_column( + cols[0], expression.Coalesce, expressions=cols[1:] + ) + return Column.invoke_expression_over_column(cols[0], expression.Coalesce) def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: @@ -409,10 +415,10 @@ def percentile_approx( ) -> Column: if accuracy: return Column.invoke_expression_over_column( - col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy + col, expression.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy ) return Column.invoke_expression_over_column( - col, glotexp.ApproxQuantile, quantile=lit(percentage) + col, expression.ApproxQuantile, quantile=lit(percentage) ) @@ -426,8 +432,8 @@ def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column: def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: if scale is not None: - return Column.invoke_expression_over_column(col, glotexp.Round, decimals=scale) - return Column.invoke_expression_over_column(col, glotexp.Round) + return Column.invoke_expression_over_column(col, expression.Round, decimals=scale) + return Column.invoke_expression_over_column(col, expression.Round) def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: @@ -437,7 +443,9 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: def shiftleft(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=numBits) + return Column.invoke_expression_over_column( + col, expression.BitwiseLeftShift, expression=numBits + ) def shiftLeft(col: ColumnOrName, numBits: int) -> Column: @@ -445,7 +453,9 @@ def shiftLeft(col: ColumnOrName, numBits: int) -> Column: def shiftright(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=numBits) + return Column.invoke_expression_over_column( + col, expression.BitwiseRightShift, expression=numBits + ) def shiftRight(col: ColumnOrName, numBits: int) -> Column: @@ -466,7 +476,7 @@ def expr(str: str) -> Column: def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column: columns = ensure_list(col) + list(cols) - return Column.invoke_expression_over_column(None, glotexp.Struct, expressions=columns) + return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns) def conv(col: 
ColumnOrName, fromBase: int, toBase: int) -> Column: @@ -512,19 +522,19 @@ def ntile(n: int) -> Column: def current_date() -> Column: - return Column.invoke_expression_over_column(None, glotexp.CurrentDate) + return Column.invoke_expression_over_column(None, expression.CurrentDate) def current_timestamp() -> Column: - return Column.invoke_expression_over_column(None, glotexp.CurrentTimestamp) + return Column.invoke_expression_over_column(None, expression.CurrentTimestamp) def date_format(col: ColumnOrName, format: str) -> Column: - return Column.invoke_expression_over_column(col, glotexp.TimeToStr, format=lit(format)) + return Column.invoke_expression_over_column(col, expression.TimeToStr, format=lit(format)) def year(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Year) + return Column.invoke_expression_over_column(col, expression.Year) def quarter(col: ColumnOrName) -> Column: @@ -532,19 +542,19 @@ def quarter(col: ColumnOrName) -> Column: def month(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Month) + return Column.invoke_expression_over_column(col, expression.Month) def dayofweek(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DayOfWeek) + return Column.invoke_expression_over_column(col, expression.DayOfWeek) def dayofmonth(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DayOfMonth) + return Column.invoke_expression_over_column(col, expression.DayOfMonth) def dayofyear(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DayOfYear) + return Column.invoke_expression_over_column(col, expression.DayOfYear) def hour(col: ColumnOrName) -> Column: @@ -560,7 +570,7 @@ def second(col: ColumnOrName) -> Column: def weekofyear(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.WeekOfYear) + return Column.invoke_expression_over_column(col, expression.WeekOfYear) def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column: @@ -568,15 +578,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days) + return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days) def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days) + return Column.invoke_expression_over_column(col, expression.DateSub, expression=days) def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start) + return Column.invoke_expression_over_column(end, expression.DateDiff, expression=start) def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column: @@ -593,8 +603,10 @@ def months_between( def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate, format=lit(format)) - return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate) + return Column.invoke_expression_over_column( + col, expression.TsOrDsToDate, format=lit(format) + ) + return Column.invoke_expression_over_column(col, expression.TsOrDsToDate) def to_timestamp(col: 
ColumnOrName, format: t.Optional[str] = None) -> Column: @@ -604,11 +616,13 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: def trunc(col: ColumnOrName, format: str) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format)) + return Column.invoke_expression_over_column(col, expression.DateTrunc, unit=lit(format)) def date_trunc(format: str, timestamp: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format)) + return Column.invoke_expression_over_column( + timestamp, expression.TimestampTrunc, unit=lit(format) + ) def next_day(col: ColumnOrName, dayOfWeek: str) -> Column: @@ -621,8 +635,8 @@ def last_day(col: ColumnOrName) -> Column: def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_expression_over_column(col, glotexp.UnixToStr, format=lit(format)) - return Column.invoke_expression_over_column(col, glotexp.UnixToStr) + return Column.invoke_expression_over_column(col, expression.UnixToStr, format=lit(format)) + return Column.invoke_expression_over_column(col, expression.UnixToStr) def unix_timestamp( @@ -630,9 +644,9 @@ def unix_timestamp( ) -> Column: if format is not None: return Column.invoke_expression_over_column( - timestamp, glotexp.StrToUnix, format=lit(format) + timestamp, expression.StrToUnix, format=lit(format) ) - return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix) + return Column.invoke_expression_over_column(timestamp, expression.StrToUnix) def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: @@ -719,11 +733,11 @@ def raise_error(errorMsg: ColumnOrName) -> Column: def upper(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Upper) + return Column.invoke_expression_over_column(col, expression.Upper) def lower(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Lower) + return Column.invoke_expression_over_column(col, expression.Lower) def ascii(col: ColumnOrLiteral) -> Column: @@ -747,24 +761,24 @@ def rtrim(col: ColumnOrName) -> Column: def trim(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Trim) + return Column.invoke_expression_over_column(col, expression.Trim) def concat_ws(sep: str, *cols: ColumnOrName) -> Column: return Column.invoke_expression_over_column( - None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols) + None, expression.ConcatWs, expressions=[lit(sep)] + list(cols) ) def decode(col: ColumnOrName, charset: str) -> Column: return Column.invoke_expression_over_column( - col, glotexp.Decode, charset=glotexp.Literal.string(charset) + col, expression.Decode, charset=expression.Literal.string(charset) ) def encode(col: ColumnOrName, charset: str) -> Column: return Column.invoke_expression_over_column( - col, glotexp.Encode, charset=glotexp.Literal.string(charset) + col, expression.Encode, charset=expression.Literal.string(charset) ) @@ -816,16 +830,16 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column: def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right) + return Column.invoke_expression_over_column(left, expression.Levenshtein, expression=right) def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: substr_col = lit(substr) if pos 
is not None: return Column.invoke_expression_over_column( - str, glotexp.StrPosition, substr=substr_col, position=pos + str, expression.StrPosition, substr=substr_col, position=pos ) - return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col) + return Column.invoke_expression_over_column(str, expression.StrPosition, substr=substr_col) def lpad(col: ColumnOrName, len: int, pad: str) -> Column: @@ -837,21 +851,26 @@ def rpad(col: ColumnOrName, len: int, pad: str) -> Column: def repeat(col: ColumnOrName, n: int) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Repeat, times=lit(n)) + return Column.invoke_expression_over_column(col, expression.Repeat, times=lit(n)) def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column: if limit is not None: return Column.invoke_expression_over_column( - str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=limit + str, expression.RegexpSplit, expression=lit(pattern).expression, limit=limit ) - return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern)) + return Column.invoke_expression_over_column( + str, expression.RegexpSplit, expression=lit(pattern) + ) def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column: - if idx is not None: - return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern), idx) - return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern)) + return Column.invoke_expression_over_column( + str, + expression.RegexpExtract, + expression=lit(pattern), + group=idx, + ) def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: @@ -859,7 +878,7 @@ def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: def initcap(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Initcap) + return Column.invoke_expression_over_column(col, expression.Initcap) def soundex(col: ColumnOrName) -> Column: @@ -871,15 +890,15 @@ def bin(col: ColumnOrName) -> Column: def hex(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Hex) + return Column.invoke_expression_over_column(col, expression.Hex) def unhex(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Unhex) + return Column.invoke_expression_over_column(col, expression.Unhex) def length(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Length) + return Column.invoke_expression_over_column(col, expression.Length) def octet_length(col: ColumnOrName) -> Column: @@ -896,27 +915,27 @@ def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols - return Column.invoke_expression_over_column(None, glotexp.Array, expressions=columns) + return Column.invoke_expression_over_column(None, expression.Array, expressions=columns) def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore return Column.invoke_expression_over_column( None, - glotexp.VarMap, + expression.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression, ) def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return 
Column.invoke_expression_over_column(None, glotexp.Map, keys=col1, values=col2) + return Column.invoke_expression_over_column(None, expression.Map, keys=col1, values=col2) def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: value_col = value if isinstance(value, Column) else lit(value) return Column.invoke_expression_over_column( - col, glotexp.ArrayContains, expression=value_col.expression + col, expression.ArrayContains, expression=value_col.expression ) @@ -943,7 +962,7 @@ def array_join( def concat(*cols: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols) + return Column.invoke_expression_over_column(None, expression.Concat, expressions=cols) def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: @@ -978,11 +997,11 @@ def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column: def explode(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Explode) + return Column.invoke_expression_over_column(col, expression.Explode) def posexplode(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Posexplode) + return Column.invoke_expression_over_column(col, expression.Posexplode) def explode_outer(col: ColumnOrName) -> Column: @@ -994,7 +1013,7 @@ def posexplode_outer(col: ColumnOrName) -> Column: def get_json_object(col: ColumnOrName, path: str) -> Column: - return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path)) + return Column.invoke_expression_over_column(col, expression.JSONExtract, path=lit(path)) def json_tuple(col: ColumnOrName, *fields: str) -> Column: @@ -1042,7 +1061,7 @@ def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> C def size(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.ArraySize) + return Column.invoke_expression_over_column(col, expression.ArraySize) def array_min(col: ColumnOrName) -> Column: @@ -1055,8 +1074,8 @@ def array_max(col: ColumnOrName) -> Column: def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column: if asc is not None: - return Column.invoke_expression_over_column(col, glotexp.SortArray, asc=asc) - return Column.invoke_expression_over_column(col, glotexp.SortArray) + return Column.invoke_expression_over_column(col, expression.SortArray, asc=asc) + return Column.invoke_expression_over_column(col, expression.SortArray) def array_sort( @@ -1065,8 +1084,10 @@ def array_sort( ) -> Column: if comparator is not None: f_expression = _get_lambda_from_func(comparator) - return Column.invoke_expression_over_column(col, glotexp.ArraySort, expression=f_expression) - return Column.invoke_expression_over_column(col, glotexp.ArraySort) + return Column.invoke_expression_over_column( + col, expression.ArraySort, expression=f_expression + ) + return Column.invoke_expression_over_column(col, expression.ArraySort) def shuffle(col: ColumnOrName) -> Column: @@ -1146,13 +1167,13 @@ def aggregate( finish_exp = _get_lambda_from_func(finish) return Column.invoke_expression_over_column( col, - glotexp.Reduce, + expression.Reduce, initial=initialValue, merge=Column(merge_exp), finish=Column(finish_exp), ) return Column.invoke_expression_over_column( - col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp) + col, expression.Reduce, initial=initialValue, merge=Column(merge_exp) ) @@ -1179,7 +1200,9 @@ def filter( f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], 
Column]], ) -> Column: f_expression = _get_lambda_from_func(f) - return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression) + return Column.invoke_expression_over_column( + col, expression.ArrayFilter, expression=f_expression + ) def zip_with( @@ -1219,10 +1242,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]: def _get_lambda_from_func(lambda_expression: t.Callable): variables = [ - glotexp.to_identifier(x, quoted=_lambda_quoted(x)) + expression.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames ] - return glotexp.Lambda( + return expression.Lambda( this=lambda_expression(*[Column(x) for x in variables]).expression, expressions=variables, ) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 6a19b46..7fd9e35 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re import typing as t from sqlglot import exp, generator, parser, tokens, transforms @@ -31,13 +32,6 @@ def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]: return func -def _date_trunc(args: t.Sequence) -> exp.Expression: - unit = seq_get(args, 1) - if isinstance(unit, exp.Column): - unit = exp.Var(this=unit.name) - return exp.DateTrunc(this=seq_get(args, 0), expression=unit) - - def _date_add_sql( data_type: str, kind: str ) -> t.Callable[[generator.Generator, exp.Expression], str]: @@ -158,11 +152,23 @@ class BigQuery(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "DATE_TRUNC": _date_trunc, + "DATE_TRUNC": lambda args: exp.DateTrunc( + unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore + this=seq_get(args, 0), + ), "DATE_ADD": _date_add(exp.DateAdd), "DATETIME_ADD": _date_add(exp.DatetimeAdd), "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), + expression=seq_get(args, 1), + position=seq_get(args, 2), + occurrence=seq_get(args, 3), + group=exp.Literal.number(1) + if re.compile(str(seq_get(args, 1))).groups == 1 + else None, + ), "TIME_ADD": _date_add(exp.TimeAdd), "TIMESTAMP_ADD": _date_add(exp.TimestampAdd), "DATE_SUB": _date_add(exp.DateSub), @@ -214,6 +220,7 @@ class BigQuery(Dialect): exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.DateStrToDate: datestrtodate_sql, + exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")), exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), @@ -226,11 +233,12 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.VariancePop: rename_func("VAR_POP"), exp.Values: _derived_table_values_to_unnest, exp.ReturnsProperty: _returnsproperty_sql, exp.Create: _create_sql, - exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", + exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", @@ -251,6 +259,10 @@ class 
BigQuery(Dialect): exp.DataType.Type.VARCHAR: "STRING", exp.DataType.Type.NVARCHAR: "STRING", } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + } EXPLICIT_UNION = True diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 2498c62..2e058e8 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -4,6 +4,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import parse_date_delta from sqlglot.dialects.spark import Spark from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql +from sqlglot.tokens import TokenType class Databricks(Spark): @@ -21,3 +22,11 @@ class Databricks(Spark): exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, } + + PARAMETER_TOKEN = "$" + + class Tokenizer(Spark.Tokenizer): + SINGLE_TOKENS = { + **Spark.Tokenizer.SINGLE_TOKENS, + "$": TokenType.PARAMETER, + } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 176a8ce..f4e8fd4 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -215,24 +215,19 @@ DialectType = t.Union[str, Dialect, t.Type[Dialect], None] def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: - def _rename(self, expression): - args = flatten(expression.args.values()) - return f"{self.normalize_func(name)}({self.format_args(*args)})" - - return _rename + return lambda self, expression: self.func(name, *flatten(expression.args.values())) def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: if expression.args.get("accuracy"): self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") - return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})" + return self.func("APPROX_COUNT_DISTINCT", expression.this) def if_sql(self: Generator, expression: exp.If) -> str: - expressions = self.format_args( - expression.this, expression.args.get("true"), expression.args.get("false") + return self.func( + "IF", expression.this, expression.args.get("true"), expression.args.get("false") ) - return f"IF({expressions})" def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: @@ -318,13 +313,13 @@ def var_map_sql( if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): self.unsupported("Cannot convert array columns into map.") - return f"{map_func_name}({self.format_args(keys, values)})" + return self.func(map_func_name, keys, values) args = [] for key, value in zip(keys.expressions, values.expressions): args.append(self.sql(key)) args.append(self.sql(value)) - return f"{map_func_name}({self.format_args(*args)})" + return self.func(map_func_name, *args) def format_time_lambda( @@ -400,10 +395,9 @@ def locate_to_strposition(args: t.Sequence) -> exp.Expression: def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: - args = self.format_args( - expression.args.get("substr"), expression.this, expression.args.get("position") + return self.func( + "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") ) - return f"LOCATE({args})" def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 1730eaf..e9c42e1 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -39,23 +39,6 @@ def _date_add_sql(kind: str) -> 
t.Callable[[generator.Generator, exp.DateAdd | e return func -def if_sql(self: generator.Generator, expression: exp.If) -> str: - """ - Drill requires backticks around certain SQL reserved words, IF being one of them, This function - adds the backticks around the keyword IF. - Args: - self: The Drill dialect - expression: The input IF expression - - Returns: The expression with IF in backticks. - - """ - expressions = self.format_args( - expression.this, expression.args.get("true"), expression.args.get("false") - ) - return f"`IF`({expressions})" - - def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) @@ -134,7 +117,7 @@ class Drill(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, } TRANSFORMS = { @@ -148,7 +131,7 @@ class Drill(Dialect): exp.DateSub: _date_add_sql("SUB"), exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})", - exp.If: if_sql, + exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})", exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 959e5e2..cfec9a4 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -73,11 +73,24 @@ def _datatype_sql(self, expression): return self.datatype_sql(expression) +def _regexp_extract_sql(self, expression): + bad_args = list(filter(expression.args.get, ("position", "occurrence"))) + if bad_args: + self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}") + return self.func( + "REGEXP_EXTRACT", + expression.args.get("this"), + expression.args.get("expression"), + expression.args.get("group"), + ) + + class DuckDB(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, ":=": TokenType.EQ, + "ATTACH": TokenType.COMMAND, "CHARACTER VARYING": TokenType.VARCHAR, } @@ -117,7 +130,7 @@ class DuckDB(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0]) if isinstance(seq_get(e.expressions, 0), exp.Select) else rename_func("LIST_VALUE")(self, e), exp.ArraySize: rename_func("ARRAY_LENGTH"), @@ -125,7 +138,9 @@ class DuckDB(Dialect): exp.ArraySum: rename_func("LIST_SUM"), exp.DataType: _datatype_sql, exp.DateAdd: _date_add, - exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""", + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this + ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", @@ -137,6 +152,7 
@@ class DuckDB(Dialect): exp.LogicalOr: rename_func("BOOL_OR"), exp.Pivot: no_pivot_sql, exp.Properties: no_properties_sql, + exp.RegexpExtract: _regexp_extract_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index c558b70..ea1191e 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -43,7 +43,7 @@ def _add_date_sql(self, expression): else expression.expression ) modified_increment = exp.Literal.number(modified_increment) - return f"{func}({self.format_args(expression.this, modified_increment.this)})" + return self.func(func, expression.this, modified_increment.this) def _date_diff_sql(self, expression): @@ -66,7 +66,7 @@ def _property_sql(self, expression): def _str_to_unix(self, expression): - return f"UNIX_TIMESTAMP({self.format_args(expression.this, _time_format(self, expression))})" + return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression)) def _str_to_date(self, expression): @@ -312,7 +312,9 @@ class Hive(Dialect): exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.TsOrDsToDate: _to_date_sql, exp.TryCast: no_trycast_sql, - exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})", + exp.UnixToStr: lambda self, e: self.func( + "FROM_UNIXTIME", e.this, _time_format(self, e) + ), exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", @@ -324,9 +326,9 @@ class Hive(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore - exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, } def with_properties(self, properties): diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index c2c2c8c..235eb77 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + rename_func, strposition_to_locate_sql, ) from sqlglot.helper import seq_get @@ -22,9 +23,8 @@ def _show_parser(*args, **kwargs): def _date_trunc_sql(self, expression): - unit = expression.name.lower() - - expr = self.sql(expression.expression) + expr = self.sql(expression, "this") + unit = expression.text("unit") if unit == "day": return f"DATE({expr})" @@ -42,7 +42,7 @@ def _date_trunc_sql(self, expression): concat = f"CONCAT(YEAR({expr}), ' 1 1')" date_format = "%Y %c %e" else: - self.unsupported("Unexpected interval unit: {unit}") + self.unsupported(f"Unexpected interval unit: {unit}") return f"DATE({expr})" return f"STR_TO_DATE({concat}, '{date_format}')" @@ -443,6 +443,10 @@ class MySQL(Dialect): exp.DateAdd: _date_add_sql("ADD"), exp.DateSub: _date_add_sql("SUB"), exp.DateTrunc: _date_trunc_sql, + exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.DayOfMonth: rename_func("DAYOFMONTH"), + exp.DayOfYear: rename_func("DAYOFYEAR"), + exp.WeekOfYear: rename_func("WEEKOFYEAR"), exp.GroupConcat: lambda self, e: 
f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index fde845e..74baa8a 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -1,15 +1,49 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql from sqlglot.helper import csv from sqlglot.tokens import TokenType +PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { + TokenType.COLUMN, + TokenType.RETURNING, +} + def _limit_sql(self, expression): return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression)) +def _parse_xml_table(self) -> exp.XMLTable: + this = self._parse_string() + + passing = None + columns = None + + if self._match_text_seq("PASSING"): + # The BY VALUE keywords are optional and are provided for semantic clarity + self._match_text_seq("BY", "VALUE") + passing = self._parse_csv( + lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS) + ) + + by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") + + if self._match_text_seq("COLUMNS"): + columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True))) + + return self.expression( + exp.XMLTable, + this=this, + passing=passing, + columns=columns, + by_ref=by_ref, + ) + + class Oracle(Dialect): # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes @@ -43,6 +77,11 @@ class Oracle(Dialect): "DECODE": exp.Matches.from_arg_list, } + FUNCTION_PARSERS: t.Dict[str, t.Callable] = { + **parser.Parser.FUNCTION_PARSERS, + "XMLTABLE": _parse_xml_table, + } + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True @@ -74,7 +113,7 @@ class Oracle(Dialect): exp.Substring: rename_func("SUBSTR"), } - def query_modifiers(self, expression, *sqls): + def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, *[self.sql(sql) for sql in expression.args.get("joins") or []], @@ -97,19 +136,32 @@ class Oracle(Dialect): sep="", ) - def offset_sql(self, expression): + def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" - def table_sql(self, expression): - return super().table_sql(expression, sep=" ") + def table_sql(self, expression: exp.Table, sep: str = " ") -> str: + return super().table_sql(expression, sep=sep) + + def xmltable_sql(self, expression: exp.XMLTable) -> str: + this = self.sql(expression, "this") + passing = self.expressions(expression, "passing") + passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else "" + columns = self.expressions(expression, "columns") + columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else "" + by_ref = ( + f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else "" + ) + return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}" class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "COLUMNS": TokenType.COLUMN, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, + "NVARCHAR2": TokenType.NVARCHAR, + "RETURNING": TokenType.RETURNING, "START": TokenType.BEGIN, "TOP": TokenType.TOP, "VARCHAR2": 
TokenType.VARCHAR, - "NVARCHAR2": TokenType.NVARCHAR, } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index c709665..7612330 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -58,17 +58,17 @@ def _date_diff_sql(self, expression): age = f"AGE({end}, {start})" if unit == "WEEK": - extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7" + unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7" elif unit == "MONTH": - extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" + unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" elif unit == "QUARTER": - extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3" + unit = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3" elif unit == "YEAR": - extract = f"EXTRACT(year FROM {age})" + unit = f"EXTRACT(year FROM {age})" else: - self.unsupported(f"Unsupported DATEDIFF unit {unit}") + unit = age - return f"CAST({extract} AS BIGINT)" + return f"CAST({unit} AS BIGINT)" def _substring_sql(self, expression): @@ -206,6 +206,8 @@ class Postgres(Dialect): } class Tokenizer(tokens.Tokenizer): + QUOTES = ["'", "$$"] + BIT_STRINGS = [("b'", "'"), ("B'", "'")] HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] @@ -236,7 +238,7 @@ class Postgres(Dialect): "UUID": TokenType.UUID, "CSTRING": TokenType.PSEUDO_TYPE, } - QUOTES = ["'", "$$"] + SINGLE_TOKENS = { **tokens.Tokenizer.SINGLE_TOKENS, "$": TokenType.PARAMETER, @@ -265,6 +267,7 @@ class Postgres(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True + PARAMETER_TOKEN = "$" TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 6c1a474..aef9de3 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -52,7 +52,7 @@ def _initcap_sql(self, expression): def _decode_sql(self, expression): _ensure_utf8(expression.args.get("charset")) - return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})" + return self.func("FROM_UTF8", expression.this, expression.args.get("replace")) def _encode_sql(self, expression): @@ -65,8 +65,7 @@ def _no_sort_array(self, expression): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: comparator = None - args = self.format_args(expression.this, comparator) - return f"ARRAY_SORT({args})" + return self.func("ARRAY_SORT", expression.this, comparator) def _schema_sql(self, expression): @@ -125,7 +124,7 @@ def _sequence_sql(self, expression): else: start = exp.Cast(this=start, to=to) - return f"SEQUENCE({self.format_args(start, end, step)})" + return self.func("SEQUENCE", start, end, step) def _ensure_utf8(charset): diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 813ee5f..b4268e6 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, transforms from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.postgres import Postgres +from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -19,6 +20,11 @@ class Redshift(Postgres): class Parser(Postgres.Parser): FUNCTIONS = { **Postgres.Parser.FUNCTIONS, # type: ignore + "DATEDIFF": lambda args: exp.DateDiff( + this=seq_get(args, 2), + expression=seq_get(args, 1), + 
unit=seq_get(args, 0), + ), "DECODE": exp.Matches.from_arg_list, "NVL": exp.Coalesce.from_arg_list, } @@ -41,7 +47,6 @@ class Redshift(Postgres): KEYWORDS = { **Postgres.Tokenizer.KEYWORDS, # type: ignore - "ENCODE": TokenType.ENCODE, "GEOMETRY": TokenType.GEOMETRY, "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, @@ -62,12 +67,15 @@ class Redshift(Postgres): PROPERTIES_LOCATION = { **Postgres.Generator.PROPERTIES_LOCATION, # type: ignore - exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.LikeProperty: exp.Properties.Location.POST_WITH, } TRANSFORMS = { **Postgres.Generator.TRANSFORMS, # type: ignore **transforms.ELIMINATE_DISTINCT_ON, # type: ignore + exp.DateDiff: lambda self, e: self.func( + "DATEDIFF", e.args.get("unit") or "day", e.expression, e.this + ), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 55a6bd3..bb46135 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -178,18 +178,25 @@ class Snowflake(Dialect): ), } + RANGE_PARSERS = { + **parser.Parser.RANGE_PARSERS, # type: ignore + TokenType.LIKE_ANY: lambda self, this: self._parse_escape( + self.expression(exp.LikeAny, this=this, expression=self._parse_bitwise()) + ), + TokenType.ILIKE_ANY: lambda self, this: self._parse_escape( + self.expression(exp.ILikeAny, this=this, expression=self._parse_bitwise()) + ), + } + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] STRING_ESCAPES = ["\\", "'"] - SINGLE_TOKENS = { - **tokens.Tokenizer.SINGLE_TOKENS, - "$": TokenType.PARAMETER, - } - KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "EXCLUDE": TokenType.EXCEPT, + "ILIKE ANY": TokenType.ILIKE_ANY, + "LIKE ANY": TokenType.LIKE_ANY, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "PUT": TokenType.COMMAND, "RENAME": TokenType.REPLACE, @@ -201,8 +208,14 @@ class Snowflake(Dialect): "SAMPLE": TokenType.TABLE_SAMPLE, } + SINGLE_TOKENS = { + **tokens.Tokenizer.SINGLE_TOKENS, + "$": TokenType.PARAMETER, + } + class Generator(generator.Generator): CREATE_TRANSIENT = True + PARAMETER_TOKEN = "$" TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -214,14 +227,15 @@ class Snowflake(Dialect): exp.If: rename_func("IFF"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), - exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Matches: rename_func("DECODE"), - exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})", + exp.StrPosition: lambda self, e: self.func( + "POSITION", e.args.get("substr"), e.this, e.args.get("position") + ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", - exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", + exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.UnixToTime: _unix_to_time_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), } @@ -236,6 +250,12 @@ class Snowflake(Dialect): "replace": "RENAME", } + def 
ilikeany_sql(self, expression: exp.ILikeAny) -> str: + return self.binary(expression, "ILIKE ANY") + + def likeany_sql(self, expression: exp.LikeAny) -> str: + return self.binary(expression, "LIKE ANY") + def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 03ec211..dd3e0c8 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -86,6 +86,11 @@ class Spark(Hive): "WEEKOFYEAR": lambda args: exp.WeekOfYear( this=exp.TsOrDsToDate(this=seq_get(args, 0)), ), + "DATE_TRUNC": lambda args: exp.TimestampTrunc( + this=seq_get(args, 1), + unit=exp.var(seq_get(args, 0)), + ), + "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), } FUNCTION_PARSERS = { @@ -133,7 +138,7 @@ class Spark(Hive): exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), - exp.DateTrunc: rename_func("TRUNC"), + exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")), exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", @@ -142,7 +147,9 @@ class Spark(Hive): exp.Map: _map_sql, exp.Reduce: rename_func("AGGREGATE"), exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", - exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})", + exp.TimestampTrunc: lambda self, e: self.func( + "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this + ), exp.Trim: trim_sql, exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), @@ -157,16 +164,16 @@ class Spark(Hive): TRANSFORMS.pop(exp.ILike) WRAP_DERIVED_VALUES = False - CREATE_FUNCTION_AS = False + CREATE_FUNCTION_RETURN_AS = False def cast_sql(self, expression: exp.Cast) -> str: if isinstance(expression.this, exp.Cast) and expression.this.is_type( exp.DataType.Type.JSON ): schema = f"'{self.sql(expression, 'to')}'" - return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})" + return self.func("FROM_JSON", expression.this.this, schema) if expression.to.is_type(exp.DataType.Type.JSON): - return f"TO_JSON({self.sql(expression, 'this')})" + return self.func("TO_JSON", expression.this) return super(Spark.Generator, self).cast_sql(expression) diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index a428dd5..86603b5 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -39,7 +39,7 @@ def _date_add_sql(self, expression): modifier = expression.name if modifier.is_string else self.sql(modifier) unit = expression.args.get("unit") modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'" - return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})" + return self.func("DATE", expression.this, modifier) class SQLite(Dialect): diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 123da04..e3eec71 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,11 +1,33 @@ from __future__ import annotations -from sqlglot import exp, generator, parser +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect from 
sqlglot.tokens import TokenType class Teradata(Dialect): + class Tokenizer(tokens.Tokenizer): + # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance + KEYWORDS = { + **tokens.Tokenizer.KEYWORDS, + "BYTEINT": TokenType.SMALLINT, + "SEL": TokenType.SELECT, + "INS": TokenType.INSERT, + "MOD": TokenType.MOD, + "LT": TokenType.LT, + "LE": TokenType.LTE, + "GT": TokenType.GT, + "GE": TokenType.GTE, + "^=": TokenType.NEQ, + "NE": TokenType.NEQ, + "NOT=": TokenType.NEQ, + "ST_GEOMETRY": TokenType.GEOMETRY, + } + + # teradata does not support % for modulus + SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS} + SINGLE_TOKENS.pop("%") + class Parser(parser.Parser): CHARSET_TRANSLATORS = { "GRAPHIC_TO_KANJISJIS", @@ -42,6 +64,14 @@ class Teradata(Dialect): "UNICODE_TO_UNICODE_NFKD", } + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS} + FUNC_TOKENS.remove(TokenType.REPLACE) + + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, # type: ignore + TokenType.REPLACE: lambda self: self._parse_create(), + } + FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, # type: ignore "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST), @@ -76,6 +106,11 @@ class Teradata(Dialect): ) class Generator(generator.Generator): + TYPE_MAPPING = { + **generator.Generator.TYPE_MAPPING, # type: ignore + exp.DataType.Type.GEOMETRY: "ST_GEOMETRY", + } + PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, @@ -93,3 +128,11 @@ class Teradata(Dialect): where_sql = self.sql(expression, "where") sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}" return self.prepend_ctes(expression, sql) + + def mod_sql(self, expression: exp.Mod) -> str: + return self.binary(expression, "MOD") + + def datatype_sql(self, expression: exp.DataType) -> str: + type_sql = super().datatype_sql(expression) + prefix_sql = expression.args.get("prefix") + return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 05ba53a..b9f932b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -92,7 +92,7 @@ def _parse_eomonth(args): def generate_date_delta_with_unit_sql(self, e): func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" - return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" + return self.func(func, e.text("unit"), e.expression, e.this) def _format_sql(self, e): @@ -101,7 +101,7 @@ def _format_sql(self, e): if isinstance(e, exp.NumberToStr) else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping)) ) - return f"FORMAT({self.format_args(e.this, fmt)})" + return self.func("FORMAT", e.this, fmt) def _string_agg_sql(self, e): @@ -408,7 +408,7 @@ class TSQL(Dialect): ): return this - expressions = self._parse_csv(self._parse_udf_kwarg) + expressions = self._parse_csv(self._parse_function_parameter) return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) class Generator(generator.Generator): diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index c3d2701..a676e7d 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -62,10 +62,8 @@ def execute( if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args: raise ExecuteError("Tables must support the same table args as 
schema") - expression = maybe_parse(sql, dialect=read) - now = time.time() - expression = optimize(expression, schema, leave_tables_isolated=True) + expression = optimize(sql, schema, leave_tables_isolated=True, dialect=read) logger.debug("Optimization finished: %f", time.time() - now) logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index de570b0..d417328 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -376,7 +376,7 @@ def _rename(self, e): this = self.sql(e, "this") this = f"{this}, " if this else "" return f"{e.key.upper()}({this}{self.expressions(e)})" - return f"{e.key.upper()}({self.format_args(*e.args.values())})" + return self.func(e.key, *e.args.values()) except Exception as ex: raise Exception(f"Could not rename {repr(e)}") from ex diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 6800cd5..42652a6 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -128,7 +128,7 @@ class Expression(metaclass=_Expression): """ return self.args.get("expressions") or [] - def text(self, key): + def text(self, key) -> str: """ Returns a textual representation of the argument corresponding to "key". This can only be used for args that are strings or leaf Expression instances, such as identifiers and literals. @@ -143,21 +143,21 @@ class Expression(metaclass=_Expression): return "" @property - def is_string(self): + def is_string(self) -> bool: """ Checks whether a Literal expression is a string. """ return isinstance(self, Literal) and self.args["is_string"] @property - def is_number(self): + def is_number(self) -> bool: """ Checks whether a Literal expression is a number. """ return isinstance(self, Literal) and not self.args["is_string"] @property - def is_int(self): + def is_int(self) -> bool: """ Checks whether a Literal expression is an integer. """ @@ -170,7 +170,12 @@ class Expression(metaclass=_Expression): return False @property - def alias(self): + def is_star(self) -> bool: + """Checks whether an expression is a star.""" + return isinstance(self, Star) or (isinstance(self, Column) and isinstance(self.this, Star)) + + @property + def alias(self) -> str: """ Returns the alias of the expression, or an empty string if it's not aliased. 
""" @@ -825,10 +830,6 @@ class UserDefinedFunction(Expression): arg_types = {"this": True, "expressions": False, "wrapped": False} -class UserDefinedFunctionKwarg(Expression): - arg_types = {"this": True, "kind": True, "default": False} - - class CharacterSet(Expression): arg_types = {"this": True, "default": False} @@ -870,14 +871,22 @@ class ByteString(Condition): class Column(Condition): - arg_types = {"this": True, "table": False} + arg_types = {"this": True, "table": False, "db": False, "catalog": False} @property - def table(self): + def table(self) -> str: return self.text("table") @property - def output_name(self): + def db(self) -> str: + return self.text("db") + + @property + def catalog(self) -> str: + return self.text("catalog") + + @property + def output_name(self) -> str: return self.name @@ -917,6 +926,14 @@ class AutoIncrementColumnConstraint(ColumnConstraintKind): pass +class CaseSpecificColumnConstraint(ColumnConstraintKind): + arg_types = {"not_": True} + + +class CharacterSetColumnConstraint(ColumnConstraintKind): + arg_types = {"this": True} + + class CheckColumnConstraint(ColumnConstraintKind): pass @@ -929,6 +946,10 @@ class CommentColumnConstraint(ColumnConstraintKind): pass +class DateFormatColumnConstraint(ColumnConstraintKind): + arg_types = {"this": True} + + class DefaultColumnConstraint(ColumnConstraintKind): pass @@ -939,7 +960,14 @@ class EncodeColumnConstraint(ColumnConstraintKind): class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT - arg_types = {"this": False, "start": False, "increment": False} + arg_types = { + "this": False, + "start": False, + "increment": False, + "minvalue": False, + "maxvalue": False, + "cycle": False, + } class NotNullColumnConstraint(ColumnConstraintKind): @@ -950,7 +978,19 @@ class PrimaryKeyColumnConstraint(ColumnConstraintKind): arg_types = {"desc": False} +class TitleColumnConstraint(ColumnConstraintKind): + pass + + class UniqueColumnConstraint(ColumnConstraintKind): + arg_types: t.Dict[str, t.Any] = {} + + +class UppercaseColumnConstraint(ColumnConstraintKind): + arg_types: t.Dict[str, t.Any] = {} + + +class PathColumnConstraint(ColumnConstraintKind): pass @@ -1063,6 +1103,7 @@ class Insert(Expression): "overwrite": False, "exists": False, "partition": False, + "alternative": False, } @@ -1438,6 +1479,16 @@ class IsolatedLoadingProperty(Property): } +class LockingProperty(Property): + arg_types = { + "this": False, + "kind": True, + "for_or_in": True, + "lock_type": True, + "override": False, + } + + class Properties(Expression): arg_types = {"expressions": True} @@ -1463,12 +1514,26 @@ class Properties(Expression): PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} + # CREATE property locations + # Form: schema specified + # create [POST_CREATE] + # table a [POST_NAME] + # (b int) [POST_SCHEMA] + # with ([POST_WITH]) + # index (b) [POST_INDEX] + # + # Form: alias selection + # create [POST_CREATE] + # table a [POST_NAME] + # as [POST_ALIAS] (select * from b) + # index (c) [POST_INDEX] class Location(AutoName): POST_CREATE = auto() - PRE_SCHEMA = auto() + POST_NAME = auto() + POST_SCHEMA = auto() + POST_WITH = auto() + POST_ALIAS = auto() POST_INDEX = auto() - POST_SCHEMA_ROOT = auto() - POST_SCHEMA_WITH = auto() UNSUPPORTED = auto() @classmethod @@ -1633,6 +1698,14 @@ class Table(Expression): "system_time": False, } + @property + def db(self) -> str: + return self.text("db") + + @property + def catalog(self) -> str: + return self.text("catalog") + 
# See the TSQL "Querying data in a system-versioned temporal table" page class SystemTime(Expression): @@ -1678,6 +1751,40 @@ class Union(Subqueryable): .limit(expression, dialect=dialect, copy=False, **opts) ) + def select( + self, + *expressions: str | Expression, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Union: + """Append to or set the SELECT of the union recursively. + + Example: + >>> from sqlglot import parse_one + >>> parse_one("select a from x union select a from y union select a from z").select("b").sql() + 'SELECT a, b FROM x UNION SELECT a, b FROM y UNION SELECT a, b FROM z' + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Union: the modified expression. + """ + this = self.copy() if copy else self + this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts) + this.expression.unnest().select( + *expressions, append=append, dialect=dialect, copy=False, **opts + ) + return this + @property def named_selects(self): return self.this.unnest().named_selects @@ -1985,7 +2092,14 @@ class Select(Subqueryable): **opts, ) - def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def select( + self, + *expressions: str | Expression, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Append to or set the SELECT expressions. @@ -1994,13 +2108,13 @@ class Select(Subqueryable): 'SELECT x, y' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this resets the expressions. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: Select: the modified expression. 
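A hedged sketch of the recursive Union.select shown above (it mirrors the method's own doctest and assumes parse_one): a single call rewrites the projections on both sides of the union.

    from sqlglot import parse_one

    union = parse_one("SELECT a FROM x UNION SELECT a FROM y")
    # copy=True by default, so the original tree is left untouched.
    assert union.select("b").sql() == "SELECT a, b FROM x UNION SELECT a, b FROM y"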
@@ -2399,7 +2513,7 @@ class Star(Expression): class Parameter(Expression): - pass + arg_types = {"this": True, "wrapped": False} class SessionParameter(Expression): @@ -2428,6 +2542,7 @@ class DataType(Expression): "expressions": False, "nested": False, "values": False, + "prefix": False, } class Type(AutoName): @@ -2693,6 +2808,10 @@ class ILike(Binary, Predicate): pass +class ILikeAny(Binary, Predicate): + pass + + class IntDiv(Binary): pass @@ -2709,6 +2828,10 @@ class Like(Binary, Predicate): pass +class LikeAny(Binary, Predicate): + pass + + class LT(Binary, Predicate): pass @@ -3042,7 +3165,7 @@ class DateDiff(Func, TimeUnit): class DateTrunc(Func): - arg_types = {"this": True, "expression": True, "zone": False} + arg_types = {"unit": True, "this": True, "zone": False} class DatetimeAdd(Func, TimeUnit): @@ -3330,6 +3453,16 @@ class Reduce(Func): arg_types = {"this": True, "initial": True, "merge": True, "finish": False} +class RegexpExtract(Func): + arg_types = { + "this": True, + "expression": True, + "position": False, + "occurrence": False, + "group": False, + } + + class RegexpLike(Func): arg_types = {"this": True, "expression": True, "flag": False} @@ -3519,6 +3652,10 @@ class Week(Func): arg_types = {"this": True, "mode": False} +class XMLTable(Func): + arg_types = {"this": True, "passing": False, "columns": False, "by_ref": False} + + class Year(Func): pass @@ -3566,6 +3703,7 @@ def maybe_parse( into: t.Optional[IntoType] = None, dialect: DialectType = None, prefix: t.Optional[str] = None, + copy: bool = False, **opts, ) -> Expression: """Gracefully handle a possible string or expression. @@ -3583,6 +3721,7 @@ def maybe_parse( input expression is a SQL string). prefix: a string to prefix the sql with before it gets parsed (automatically includes a space) + copy: whether or not to copy the expression. **opts: other options to use to parse the input expressions (again, in the case that an input expression is a SQL string). @@ -3590,6 +3729,8 @@ def maybe_parse( Expression: the parsed or given expression. """ if isinstance(sql_or_expression, Expression): + if copy: + return sql_or_expression.copy() return sql_or_expression import sqlglot @@ -3818,7 +3959,7 @@ def except_(left, right, distinct=True, dialect=None, **opts): return Except(this=left, expression=right, distinct=distinct) -def select(*expressions, dialect=None, **opts) -> Select: +def select(*expressions: str | Expression, dialect: DialectType = None, **opts) -> Select: """ Initializes a syntax tree from one or multiple SELECT expressions. @@ -3827,9 +3968,9 @@ def select(*expressions, dialect=None, **opts) -> Select: 'SELECT col1, col2 FROM tbl' Args: - *expressions (str | Expression): the SQL code string to parse as the expressions of a + *expressions: the SQL code string to parse as the expressions of a SELECT statement. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expressions (in the case that an + dialect: the dialect used to parse the input expressions (in the case that an input expression is a SQL string). **opts: other options to use to parse the input expressions (again, in the case that an input expression is a SQL string). 
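A small sketch of the new copy flag on maybe_parse (illustrative; exp.column is sqlglot's own builder): when handed an Expression, copy=True returns a defensive copy instead of the original node, which the optimizer entry point further below relies on.

    from sqlglot import exp
    from sqlglot.expressions import maybe_parse

    node = exp.column("a")
    assert maybe_parse(node) is node                  # default: pass-through
    assert maybe_parse(node, copy=True) is not node   # defensive copy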
@@ -4219,19 +4360,27 @@ def subquery(expression, alias=None, dialect=None, **opts): return Select().from_(expression, dialect=dialect, **opts) -def column(col, table=None, quoted=None) -> Column: +def column( + col: str | Identifier, + table: t.Optional[str | Identifier] = None, + schema: t.Optional[str | Identifier] = None, + quoted: t.Optional[bool] = None, +) -> Column: """ Build a Column. Args: - col (str | Expression): column name - table (str | Expression): table name + col: column name + table: table name + schema: schema name + quoted: whether or not to force quote each part Returns: Column: column instance """ return Column( this=to_identifier(col, quoted=quoted), table=to_identifier(table, quoted=quoted), + schema=to_identifier(schema, quoted=quoted), ) @@ -4314,6 +4463,30 @@ def values( ) +def var(name: t.Optional[str | Expression]) -> Var: + """Build a SQL variable. + + Example: + >>> repr(var('x')) + '(VAR this: x)' + + >>> repr(var(column('x', table='y'))) + '(VAR this: x)' + + Args: + name: The name of the var or an expression whose name will become the var. + + Returns: + The new variable node. + """ + if not name: + raise ValueError("Cannot convert empty name into var.") + + if isinstance(name, Expression): + name = name.name + return Var(this=name) + + def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable: """Build ALTER TABLE... RENAME... expression diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 0d72fe3..1479e28 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,19 +1,16 @@ from __future__ import annotations import logging -import re import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages -from sqlglot.helper import apply_index_offset, csv +from sqlglot.helper import apply_index_offset, csv, seq_get from sqlglot.time import format_time from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") -BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)") - class Generator: """ @@ -59,10 +56,14 @@ class Generator: """ TRANSFORMS = { - exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", - exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", - exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", - exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})", + exp.DateAdd: lambda self, e: self.func( + "DATE_ADD", e.this, e.expression, e.args.get("unit") + ), + exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), + exp.TsOrDsAdd: lambda self, e: self.func( + "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit") + ), + exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), @@ -72,6 +73,17 @@ class Generator: exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", + exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", + 
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", + exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", + exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE", + exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", + exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", + exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})", + exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", + exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", + exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", + exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}", } # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed @@ -89,8 +101,8 @@ class Generator: # Wrap derived values in parens, usually standard but spark doesn't support it WRAP_DERIVED_VALUES = True - # Whether or not create function uses an AS before the def. - CREATE_FUNCTION_AS = True + # Whether or not create function uses an AS before the RETURN + CREATE_FUNCTION_RETURN_AS = True TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", @@ -110,42 +122,46 @@ class Generator: STRUCT_DELIMITER = ("<", ">") + PARAMETER_TOKEN = "@" + PROPERTIES_LOCATION = { - exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA, + exp.AfterJournalProperty: exp.Properties.Location.POST_NAME, exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, - exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA, - exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA, - exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA, + exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA, + exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME, + exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA, + exp.ChecksumProperty: exp.Properties.Location.POST_NAME, + exp.CollateProperty: exp.Properties.Location.POST_SCHEMA, + exp.Cluster: exp.Properties.Location.POST_SCHEMA, + exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME, exp.DefinerProperty: exp.Properties.Location.POST_CREATE, - exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA, - exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, - exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA, - exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA, - exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA, - exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.LogProperty: exp.Properties.Location.PRE_SCHEMA, - exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH, - exp.Property: exp.Properties.Location.POST_SCHEMA_WITH, - exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.RowFormatDelimitedProperty: 
exp.Properties.Location.POST_SCHEMA_ROOT, - exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA, + exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA, + exp.EngineProperty: exp.Properties.Location.POST_SCHEMA, + exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA, + exp.FallbackProperty: exp.Properties.Location.POST_NAME, + exp.FileFormatProperty: exp.Properties.Location.POST_WITH, + exp.FreespaceProperty: exp.Properties.Location.POST_NAME, + exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME, + exp.JournalProperty: exp.Properties.Location.POST_NAME, + exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA, + exp.LikeProperty: exp.Properties.Location.POST_SCHEMA, + exp.LocationProperty: exp.Properties.Location.POST_SCHEMA, + exp.LockingProperty: exp.Properties.Location.POST_ALIAS, + exp.LogProperty: exp.Properties.Location.POST_NAME, + exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME, + exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, + exp.Property: exp.Properties.Location.POST_WITH, + exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, + exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA, + exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, + exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, + exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA, + exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, - exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, - exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA, + exp.TableFormatProperty: exp.Properties.Location.POST_WITH, + exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA, + exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, } WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) @@ -173,7 +189,6 @@ class Generator: "null_ordering", "max_unsupported", "_indent", - "_replace_backslash", "_escaped_quote_end", "_escaped_identifier_end", "_leading_comma", @@ -230,7 +245,6 @@ class Generator: self.max_unsupported = max_unsupported self.null_ordering = null_ordering self._indent = indent - self._replace_backslash = self.string_escape == "\\" self._escaped_quote_end = self.string_escape + self.quote_end self._escaped_identifier_end = self.identifier_escape + self.identifier_end self._leading_comma = leading_comma @@ -403,12 +417,13 @@ class Generator: def column_sql(self, expression: exp.Column) -> str: return ".".join( - part - for part in [ - self.sql(expression, "db"), - self.sql(expression, "table"), - self.sql(expression, "this"), - ] + self.sql(part) + for part in ( + expression.args.get("catalog"), + expression.args.get("db"), + expression.args.get("table"), + expression.args.get("this"), + ) if part ) @@ -430,26 +445,6 @@ class Generator: def autoincrementcolumnconstraint_sql(self, _) -> str: return self.token_sql(TokenType.AUTO_INCREMENT) - def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str: - this = self.sql(expression, "this") - return f"CHECK ({this})" - - def commentcolumnconstraint_sql(self, expression: 
exp.CommentColumnConstraint) -> str: - comment = self.sql(expression, "this") - return f"COMMENT {comment}" - - def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str: - collate = self.sql(expression, "this") - return f"COLLATE {collate}" - - def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str: - encode = self.sql(expression, "this") - return f"ENCODE {encode}" - - def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str: - default = self.sql(expression, "this") - return f"DEFAULT {default}" - def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: @@ -459,10 +454,19 @@ class Generator: start = expression.args.get("start") start = f"START WITH {start}" if start else "" increment = expression.args.get("increment") - increment = f"INCREMENT BY {increment}" if increment else "" + increment = f" INCREMENT BY {increment}" if increment else "" + minvalue = expression.args.get("minvalue") + minvalue = f" MINVALUE {minvalue}" if minvalue else "" + maxvalue = expression.args.get("maxvalue") + maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" + cycle = expression.args.get("cycle") + cycle_sql = "" + if cycle is not None: + cycle_sql = f"{' NO' if not cycle else ''} CYCLE" + cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql sequence_opts = "" - if start or increment: - sequence_opts = f"{start} {increment}" + if start or increment or cycle_sql: + sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}" sequence_opts = f" ({sequence_opts.strip()})" return f"GENERATED{this}AS IDENTITY{sequence_opts}" @@ -483,22 +487,22 @@ class Generator: properties = expression.args.get("properties") properties_exp = expression.copy() properties_locs = self.locate_properties(properties) if properties else {} - if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get( - exp.Properties.Location.POST_SCHEMA_WITH + if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get( + exp.Properties.Location.POST_WITH ): properties_exp.set( "properties", exp.Properties( expressions=[ - *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT], - *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH], + *properties_locs[exp.Properties.Location.POST_SCHEMA], + *properties_locs[exp.Properties.Location.POST_WITH], ] ), ) - if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA): + if kind == "TABLE" and properties_locs.get(exp.Properties.Location.POST_NAME): this_name = self.sql(expression.this, "this") this_properties = self.properties( - exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]), + exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_NAME]), wrapped=False, ) this_schema = f"({self.expressions(expression.this)})" @@ -512,8 +516,17 @@ class Generator: if expression_sql: expression_sql = f"{begin}{self.sep()}{expression_sql}" - if self.CREATE_FUNCTION_AS or kind != "FUNCTION": - expression_sql = f" AS{expression_sql}" + if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return): + if properties_locs.get(exp.Properties.Location.POST_ALIAS): + postalias_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[exp.Properties.Location.POST_ALIAS] + ), + wrapped=False, + ) + expression_sql = f" AS {postalias_props_sql}{expression_sql}" + else: + expression_sql = f" 
AS{expression_sql}" temporary = " TEMPORARY" if expression.args.get("temporary") else "" transient = ( @@ -736,9 +749,9 @@ class Generator: for p in expression.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] - if p_loc == exp.Properties.Location.POST_SCHEMA_WITH: + if p_loc == exp.Properties.Location.POST_WITH: with_properties.append(p) - elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: + elif p_loc == exp.Properties.Location.POST_SCHEMA: root_properties.append(p) return self.root_properties( @@ -776,16 +789,18 @@ class Generator: for p in properties.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] - if p_loc == exp.Properties.Location.PRE_SCHEMA: - properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p) + if p_loc == exp.Properties.Location.POST_NAME: + properties_locs[exp.Properties.Location.POST_NAME].append(p) elif p_loc == exp.Properties.Location.POST_INDEX: properties_locs[exp.Properties.Location.POST_INDEX].append(p) - elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: - properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p) - elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH: - properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p) + elif p_loc == exp.Properties.Location.POST_SCHEMA: + properties_locs[exp.Properties.Location.POST_SCHEMA].append(p) + elif p_loc == exp.Properties.Location.POST_WITH: + properties_locs[exp.Properties.Location.POST_WITH].append(p) elif p_loc == exp.Properties.Location.POST_CREATE: properties_locs[exp.Properties.Location.POST_CREATE].append(p) + elif p_loc == exp.Properties.Location.POST_ALIAS: + properties_locs[exp.Properties.Location.POST_ALIAS].append(p) elif p_loc == exp.Properties.Location.UNSUPPORTED: self.unsupported(f"Unsupported property {p.key}") @@ -899,6 +914,14 @@ class Generator: for_ = " FOR NONE" return f"WITH{no}{concurrent} ISOLATED LOADING{for_}" + def lockingproperty_sql(self, expression: exp.LockingProperty) -> str: + kind = expression.args.get("kind") + this: str = f" {self.sql(expression, 'this')}" if expression.this else "" + for_or_in = expression.args.get("for_or_in") + lock_type = expression.args.get("lock_type") + override = " OVERRIDE" if expression.args.get("override") else "" + return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}" + def insert_sql(self, expression: exp.Insert) -> str: overwrite = expression.args.get("overwrite") @@ -907,14 +930,17 @@ class Generator: else: this = "OVERWRITE TABLE " if overwrite else "INTO " + alternative = expression.args.get("alternative") + alternative = f" OR {alternative} " if alternative else " " this = f"{this}{self.sql(expression, 'this')}" + exists = " IF EXISTS " if expression.args.get("exists") else " " partition_sql = ( self.sql(expression, "partition") if expression.args.get("partition") else "" ) expression_sql = self.sql(expression, "expression") sep = self.sep() if partition_sql else "" - sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" + sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: @@ -1046,21 +1072,26 @@ class Generator: f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" ) - cube = expression.args.get("cube") - if cube is True: - cube = self.seg("WITH CUBE") + cube = expression.args.get("cube", []) + if seq_get(cube, 0) is True: + return f"{group_by}{self.seg('WITH CUBE')}" else: - cube = self.expressions(expression, key="cube", 
indent=False) - cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" + cube_sql = self.expressions(expression, key="cube", indent=False) + cube_sql = f"{self.seg('CUBE')} {self.wrap(cube_sql)}" if cube_sql else "" - rollup = expression.args.get("rollup") - if rollup is True: - rollup = self.seg("WITH ROLLUP") + rollup = expression.args.get("rollup", []) + if seq_get(rollup, 0) is True: + return f"{group_by}{self.seg('WITH ROLLUP')}" else: - rollup = self.expressions(expression, key="rollup", indent=False) - rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" + rollup_sql = self.expressions(expression, key="rollup", indent=False) + rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else "" + + groupings = csv(grouping_sets, cube_sql, rollup_sql, sep=",") - return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}" + if expression.args.get("expressions") and groupings: + group_by = f"{group_by}," + + return f"{group_by}{groupings}" def having_sql(self, expression: exp.Having) -> str: this = self.indent(self.sql(expression, "this")) @@ -1139,8 +1170,6 @@ class Generator: def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: - if self._replace_backslash: - text = BACKSLASH_RE.sub(r"\\\\", text) text = text.replace(self.quote_end, self._escaped_quote_end) if self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) @@ -1291,7 +1320,9 @@ class Generator: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" def parameter_sql(self, expression: exp.Parameter) -> str: - return f"@{self.sql(expression, 'this')}" + this = self.sql(expression, "this") + this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}" + return f"{self.PARAMETER_TOKEN}{this}" def sessionparameter_sql(self, expression: exp.SessionParameter) -> str: this = self.sql(expression, "this") @@ -1405,7 +1436,10 @@ class Generator: return f"ALL {self.wrap(expression)}" def any_sql(self, expression: exp.Any) -> str: - return f"ANY {self.wrap(expression)}" + this = self.sql(expression, "this") + if isinstance(expression.this, exp.Subqueryable): + this = self.wrap(this) + return f"ANY {this}" def exists_sql(self, expression: exp.Exists) -> str: return f"EXISTS{self.wrap(expression)}" @@ -1444,11 +1478,11 @@ class Generator: trim_type = self.sql(expression, "position") if trim_type == "LEADING": - return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})" + return self.func("LTRIM", expression.this) elif trim_type == "TRAILING": - return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})" + return self.func("RTRIM", expression.this) else: - return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})" + return self.func("TRIM", expression.this, expression.expression) def concat_sql(self, expression: exp.Concat) -> str: if len(expression.expressions) == 1: @@ -1530,8 +1564,7 @@ class Generator: return f"REFERENCES {this}{expressions}{options}" def anonymous_sql(self, expression: exp.Anonymous) -> str: - args = self.format_args(*expression.expressions) - return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" + return self.func(expression.name, *expression.expressions) def paren_sql(self, expression: exp.Paren) -> str: if isinstance(expression.unnest(), exp.Select): @@ -1792,7 +1825,10 @@ class Generator: else: args.append(arg_value) - return 
f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})" + return self.func(expression.sql_name(), *args) + + def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str: + return f"{self.normalize_func(name)}({self.format_args(*args)})" def format_args(self, *args: t.Optional[str | exp.Expression]) -> str: arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None) @@ -1848,6 +1884,7 @@ class Generator: return self.indent(result_sql, skip_first=False) if indent else result_sql def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str: + flat = flat or isinstance(expression.parent, exp.Properties) expressions_sql = self.expressions(expression, flat=flat) if flat: return f"{op} {expressions_sql}" @@ -1880,11 +1917,6 @@ class Generator: ) return f"{this}{expressions}" - def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str: - this = self.sql(expression, "this") - kind = self.sql(expression, "kind") - return f"{this} {kind}" - def joinhint_sql(self, expression: exp.JoinHint) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 66f97a9..be65ab9 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -280,6 +280,9 @@ class TypeAnnotator: } # First annotate the current scope's column references for col in scope.columns: + if not col.table: + continue + source = scope.sources.get(col.table) if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index c6bea5a..6f9db82 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -81,9 +81,7 @@ def eliminate_subqueries(expression): new_ctes.append(cte_scope.expression.parent) # Now append the rest - for scope in itertools.chain( - root.union_scopes, root.subquery_scopes, root.derived_table_scopes - ): + for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes): for child_scope in scope.traverse(): new_cte = _eliminate(child_scope, existing_ctes, taken) if new_cte: @@ -99,7 +97,7 @@ def _eliminate(scope, existing_ctes, taken): if scope.is_union: return _eliminate_union(scope, existing_ctes, taken) - if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF): + if scope.is_derived_table: return _eliminate_derived_table(scope, existing_ctes, taken) if scope.is_cte: diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 96fd56b..d9d04be 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,4 +1,10 @@ +from __future__ import annotations + +import typing as t + import sqlglot +from sqlglot import Schema, exp +from sqlglot.dialects.dialect import DialectType from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.canonicalize import canonicalize from sqlglot.optimizer.eliminate_ctes import eliminate_ctes @@ -24,8 +30,8 @@ RULES = ( isolate_table_selects, qualify_columns, expand_laterals, - validate_qualify_columns, pushdown_projections, + validate_qualify_columns, normalize, unnest_subqueries, expand_multi_table_selects, @@ -40,22 +46,31 @@ RULES = ( ) -def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs): +def optimize( + expression: str | 
exp.Expression, + schema: t.Optional[dict | Schema] = None, + db: t.Optional[str] = None, + catalog: t.Optional[str] = None, + dialect: DialectType = None, + rules: t.Sequence[t.Callable] = RULES, + **kwargs, +): """ Rewrite a sqlglot AST into an optimized form. Args: - expression (sqlglot.Expression): expression to optimize - schema (dict|sqlglot.optimizer.Schema): database schema. + expression: expression to optimize + schema: database schema. This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of the following forms: 1. {table: {col: type}} 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} If no schema is provided then the default schema defined at `sqlglot.schema` will be used - db (str): specify the default database, as might be set by a `USE DATABASE db` statement - catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement - rules (sequence): sequence of optimizer rules to use. + db: specify the default database, as might be set by a `USE DATABASE db` statement + catalog: specify the default catalog, as might be set by a `USE CATALOG c` statement + dialect: The dialect to parse the sql string. + rules: sequence of optimizer rules to use. Many of the rules require tables and columns to be qualified. Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know what you're doing! @@ -65,7 +80,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar """ schema = ensure_schema(schema or sqlglot.schema) possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} - expression = expression.copy() + expression = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 54c5021..3f360f9 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -1,7 +1,10 @@ from collections import defaultdict from sqlglot import alias, exp +from sqlglot.helper import flatten +from sqlglot.optimizer.qualify_columns import Resolver from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import ensure_schema # Sentinel value that means an outer query selecting ALL columns SELECT_ALL = object() @@ -10,7 +13,7 @@ SELECT_ALL = object() DEFAULT_SELECTION = lambda: alias("1", "_") -def pushdown_projections(expression): +def pushdown_projections(expression, schema=None): """ Rewrite sqlglot AST to remove unused column projections. @@ -27,9 +30,9 @@ def pushdown_projections(expression): sqlglot.Expression: optimized expression """ # Map of Scope to all columns being selected by outer queries. + schema = ensure_schema(schema) referenced_columns = defaultdict(set) - left_union = None - right_union = None + # We build the scope tree (which is traversed in DFS postorder), then iterate # over the result in reverse order. This should ensure that the set of selected # columns for a particular scope are completely built by the time we get to it. 
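A usage sketch for the widened optimize() entry point above (the one-table schema here is a made-up example): a raw SQL string plus a dialect now parse and copy internally via exp.maybe_parse, so callers no longer need to pre-parse.

    from sqlglot.optimizer import optimize

    expression = optimize(
        "SELECT a FROM x",
        schema={"x": {"a": "INT"}},
        dialect="spark",
    )
    print(expression.sql())  # qualified, optimized SQL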
@@ -41,16 +44,20 @@ def pushdown_projections(expression): parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): - left_union, right_union = scope.union_scopes - referenced_columns[left_union] = parent_selections - referenced_columns[right_union] = parent_selections + left, right = scope.union_scopes + referenced_columns[left] = parent_selections + + if any(select.is_star for select in right.selects): + referenced_columns[right] = parent_selections + elif not any(select.is_star for select in left.selects): + referenced_columns[right] = [ + right.selects[i].alias_or_name + for i, select in enumerate(left.selects) + if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections + ] - if isinstance(scope.expression, exp.Select) and scope != right_union: - removed_indexes = _remove_unused_selections(scope, parent_selections) - # The left union is used for column names to select and if we remove columns from the left - # we need to also remove those same columns in the right that were at the same position - if scope is left_union: - _remove_indexed_selections(right_union, removed_indexes) + if isinstance(scope.expression, exp.Select): + _remove_unused_selections(scope, parent_selections, schema) # Group columns by source name selects = defaultdict(set) @@ -68,8 +75,7 @@ def pushdown_projections(expression): return expression -def _remove_unused_selections(scope, parent_selections): - removed_indexes = [] +def _remove_unused_selections(scope, parent_selections, schema): order = scope.expression.args.get("order") if order: @@ -78,33 +84,33 @@ def _remove_unused_selections(scope, parent_selections): else: order_refs = set() - new_selections = [] + new_selections = defaultdict(list) removed = False - for i, selection in enumerate(scope.selects): - if ( - SELECT_ALL in parent_selections - or selection.alias_or_name in parent_selections - or selection.alias_or_name in order_refs - ): - new_selections.append(selection) + star = False + for selection in scope.selects: + name = selection.alias_or_name + + if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs: + new_selections[name].append(selection) else: - removed_indexes.append(i) + if selection.is_star: + star = True removed = True + if star: + resolver = Resolver(scope, schema) + + for name in sorted(parent_selections): + if name not in new_selections: + new_selections[name].append( + alias(exp.column(name, table=resolver.get_table(name)), name) + ) + # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION()) + new_selections[""].append(DEFAULT_SELECTION()) + + scope.expression.select(*flatten(new_selections.values()), append=False, copy=False) - scope.expression.set("expressions", new_selections) if removed: scope.clear_cache() - return removed_indexes - - -def _remove_indexed_selections(scope, indexes_to_remove): - new_selections = [ - selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove - ] - if not new_selections: - new_selections.append(DEFAULT_SELECTION()) - scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ab13d01..a7bd9b5 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -27,17 +27,16 @@ def qualify_columns(expression, schema): schema = ensure_schema(schema) for scope in traverse_scope(expression): - resolver = 
_Resolver(scope, schema) + resolver = Resolver(scope, schema) _pop_table_column_aliases(scope.ctes) _pop_table_column_aliases(scope.derived_tables) _expand_using(scope, resolver) - _expand_group_by(scope, resolver) _qualify_columns(scope, resolver) - _expand_order_by(scope) if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver) _qualify_outputs(scope) - + _expand_group_by(scope, resolver) + _expand_order_by(scope) return expression @@ -48,7 +47,8 @@ def validate_qualify_columns(expression): if isinstance(scope.expression, exp.Select): unqualified_columns.extend(scope.unqualified_columns) if scope.external_columns and not scope.is_correlated_subquery: - raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}") + column = scope.external_columns[0] + raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'") if unqualified_columns: raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") @@ -62,8 +62,6 @@ def _pop_table_column_aliases(derived_tables): (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)) """ for derived_table in derived_tables: - if isinstance(derived_table.unnest(), exp.UDTF): - continue table_alias = derived_table.args.get("alias") if table_alias: table_alias.args.pop("columns", None) @@ -206,7 +204,7 @@ def _qualify_columns(scope, resolver): if column_table and column_table in scope.sources: source_columns = resolver.get_source_columns(column_table) - if source_columns and column_name not in source_columns: + if source_columns and column_name not in source_columns and "*" not in source_columns: raise OptimizeError(f"Unknown column: {column_name}") if not column_table: @@ -256,7 +254,7 @@ def _expand_stars(scope, resolver): tables = list(scope.selected_sources) _add_except_columns(expression, tables, except_columns) _add_replace_columns(expression, tables, replace_columns) - elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star): + elif expression.is_star: tables = [expression.table] _add_except_columns(expression.this, tables, except_columns) _add_replace_columns(expression.this, tables, replace_columns) @@ -268,17 +266,16 @@ def _expand_stars(scope, resolver): if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") columns = resolver.get_source_columns(table, only_visible=True) - if not columns: - raise OptimizeError( - f"Table has no schema/columns. Cannot expand star for table: {table}." 
- ) - table_id = id(table) - for name in columns: - if name not in except_columns.get(table_id, set()): - alias_ = replace_columns.get(table_id, {}).get(name, name) - column = exp.column(name, table) - new_selections.append(alias(column, alias_) if alias_ != name else column) + if columns and "*" not in columns: + table_id = id(table) + for name in columns: + if name not in except_columns.get(table_id, set()): + alias_ = replace_columns.get(table_id, {}).get(name, name) + column = exp.column(name, table) + new_selections.append(alias(column, alias_) if alias_ != name else column) + else: + return scope.expression.set("expressions", new_selections) @@ -316,7 +313,7 @@ def _qualify_outputs(scope): if isinstance(selection, exp.Subquery): if not selection.output_name: selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) - elif not isinstance(selection, exp.Alias): + elif not isinstance(selection, exp.Alias) and not selection.is_star: alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}") alias_.set("this", selection) selection = alias_ @@ -329,7 +326,7 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) -class _Resolver: +class Resolver: """ Helper for resolving columns. @@ -361,7 +358,9 @@ class _Resolver: if not table: sources_without_schema = tuple( - source for source, columns in self._get_all_source_columns().items() if not columns + source + for source, columns in self._get_all_source_columns().items() + if not columns or "*" in columns ) if len(sources_without_schema) == 1: return sources_without_schema[0] @@ -397,7 +396,8 @@ class _Resolver: def _get_all_source_columns(self): if self._source_columns is None: self._source_columns = { - k: self.get_source_columns(k) for k in self.scope.selected_sources + k: self.get_source_columns(k) + for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources) } return self._source_columns @@ -436,7 +436,7 @@ class _Resolver: Find the unique columns in a list of columns. Example: - >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"])) + >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"])) ['a', 'c'] This is necessary because duplicate column names are ambiguous. diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 65593bd..6e50182 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -28,7 +28,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): next_name = lambda: f"_q_{next(sequence)}" for scope in traverse_scope(expression): - for derived_table in scope.ctes + scope.derived_tables: + for derived_table in itertools.chain(scope.ctes, scope.derived_tables): if not derived_table.args.get("alias"): alias_ = f"_q_{next(sequence)}" derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 8565c64..335ff3e 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -26,6 +26,10 @@ class Scope: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} + lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals + For example: + SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; + The LATERAL VIEW EXPLODE gets x as a source. 
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list of its alias of this scope, this is that list of columns. For example: @@ -34,8 +38,10 @@ class Scope: parent (Scope): Parent scope scope_type (ScopeType): Type of this scope, relative to its parent subquery_scopes (list[Scope]): List of all child scopes for subqueries - cte_scopes = (list[Scope]) List of all child scopes for CTEs - derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables + cte_scopes (list[Scope]): List of all child scopes for CTEs + derived_table_scopes (list[Scope]): List of all child scopes for derived_tables + udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions + table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes. """ def __init__( self, expression, sources=None, outer_column_list=None, parent=None, scope_type=ScopeType.ROOT, + lateral_sources=None, ): self.expression = expression self.sources = sources or {} + self.lateral_sources = lateral_sources.copy() if lateral_sources else {} + self.sources.update(self.lateral_sources) self.outer_column_list = outer_column_list or [] self.parent = parent self.scope_type = scope_type self.subquery_scopes = [] self.derived_table_scopes = [] + self.table_scopes = [] self.cte_scopes = [] self.union_scopes = [] + self.udtf_scopes = [] self.clear_cache() def clear_cache(self): self._collected = False self._raw_columns = None self._derived_tables = None + self._udtfs = None self._tables = None self._ctes = None self._subqueries = None @@ -86,6 +98,7 @@ class Scope: self._ctes = [] self._subqueries = [] self._derived_tables = [] + self._udtfs = [] self._raw_columns = [] self._join_hints = [] @@ -99,7 +112,7 @@ class Scope: elif isinstance(node, exp.JoinHint): self._join_hints.append(node) elif isinstance(node, exp.UDTF): - self._derived_tables.append(node) + self._udtfs.append(node) elif isinstance(node, exp.CTE): self._ctes.append(node) elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): @@ -199,6 +212,17 @@ class Scope: self._ensure_collected() return self._derived_tables + @property + def udtfs(self): + """ + List of "User Defined Tabular Functions" in this scope. 
+ + Returns: + list[exp.UDTF]: UDTFs + """ + self._ensure_collected() + return self._udtfs + @property def subqueries(self): """ @@ -227,7 +251,9 @@ class Scope: columns = self._raw_columns external_columns = [ - column for scope in self.subquery_scopes for column in scope.external_columns + column + for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) + for column in scope.external_columns ] named_selects = set(self.expression.named_selects) @@ -262,9 +288,8 @@ class Scope: for table in self.tables: referenced_names.append((table.alias_or_name, table)) - for derived_table in self.derived_tables: - referenced_names.append((derived_table.alias, derived_table.unnest())) - + for expression in itertools.chain(self.derived_tables, self.udtfs): + referenced_names.append((expression.alias, expression.unnest())) result = {} for name, node in referenced_names: @@ -414,7 +439,7 @@ class Scope: Scope: scope instances in depth-first-search post-order """ for child_scope in itertools.chain( - self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes + self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes ): yield from child_scope.traverse() yield self @@ -480,24 +505,23 @@ def _traverse_scope(scope): yield from _traverse_select(scope) elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) - elif isinstance(scope.expression, exp.UDTF): - _set_udtf_scope(scope) elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) + elif isinstance(scope.expression, exp.UDTF): + pass else: raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}") yield scope def _traverse_select(scope): - yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) - yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) + yield from _traverse_ctes(scope) + yield from _traverse_tables(scope) yield from _traverse_subqueries(scope) - _add_table_sources(scope) def _traverse_union(scope): - yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE) + yield from _traverse_ctes(scope) # The last scope to be yielded should be the topmost scope left = None @@ -511,82 +535,98 @@ def _traverse_union(scope): scope.union_scopes = [left, right] -def _set_udtf_scope(scope): - parent = scope.expression.parent - from_ = parent.args.get("from") - - if not from_: - return - - for table in from_.expressions: - if isinstance(table, exp.Table): - scope.tables.append(table) - elif isinstance(table, exp.Subquery): - scope.subqueries.append(table) - _add_table_sources(scope) - _traverse_subqueries(scope) - - -def _traverse_derived_tables(derived_tables, scope, scope_type): +def _traverse_ctes(scope): sources = {} - is_cte = scope_type == ScopeType.CTE - for derived_table in derived_tables: + for cte in scope.ctes: recursive_scope = None # if the scope is a recursive cte, it must be in the form of # base_case UNION recursive. Thus the recursive scope is the first # section of the union. 
- if is_cte and scope.expression.args["with"].recursive: - union = derived_table.this + if scope.expression.args["with"].recursive: + union = cte.this if isinstance(union, exp.Union): recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) for child_scope in _traverse_scope( scope.branch( - derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, - chain_sources=sources if scope_type == ScopeType.CTE else None, - outer_column_list=derived_table.alias_column_names, - scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type, + cte.this, + chain_sources=sources, + outer_column_list=cte.alias_column_names, + scope_type=ScopeType.CTE, ) ): yield child_scope - # Tables without aliases will be set as "" - # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. - # Until then, this means that only a single, unaliased derived table is allowed (rather, - # the latest one wins. - alias = derived_table.alias + alias = cte.alias sources[alias] = child_scope if recursive_scope: child_scope.add_source(alias, recursive_scope) # append the final child_scope yielded - if is_cte: - scope.cte_scopes.append(child_scope) - else: - scope.derived_table_scopes.append(child_scope) + scope.cte_scopes.append(child_scope) scope.sources.update(sources) -def _add_table_sources(scope): +def _traverse_tables(scope): sources = {} - for table in scope.tables: - table_name = table.name - if table.alias: - source_name = table.alias - else: - source_name = table_name + # Traverse FROMs, JOINs, and LATERALs in the order they are defined + expressions = [] + from_ = scope.expression.args.get("from") + if from_: + expressions.extend(from_.expressions) - if table_name in scope.sources: - # This is a reference to a parent source (e.g. a CTE), not an actual table. - scope.sources[source_name] = scope.sources[table_name] + for join in scope.expression.args.get("joins") or []: + expressions.append(join.this) + + expressions.extend(scope.expression.args.get("laterals") or []) + + for expression in expressions: + if isinstance(expression, exp.Table): + table_name = expression.name + source_name = expression.alias_or_name + + if table_name in scope.sources: + # This is a reference to a parent source (e.g. a CTE), not an actual table. + sources[source_name] = scope.sources[table_name] + else: + sources[source_name] = expression + continue + + if isinstance(expression, exp.UDTF): + lateral_sources = sources + scope_type = ScopeType.UDTF + scopes = scope.udtf_scopes else: - sources[source_name] = table + lateral_sources = None + scope_type = ScopeType.DERIVED_TABLE + scopes = scope.derived_table_scopes + + for child_scope in _traverse_scope( + scope.branch( + expression, + lateral_sources=lateral_sources, + outer_column_list=expression.alias_column_names, + scope_type=scope_type, + ) + ): + yield child_scope + + # Tables without aliases will be set as "" + # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. + # Until then, this means that only a single, unaliased derived table is allowed (rather, + # the latest one wins. 
+ alias = expression.alias + sources[alias] = child_scope + + # append the final child_scope yielded + scopes.append(child_scope) + scope.table_scopes.append(child_scope) scope.sources.update(sources) @@ -624,9 +664,10 @@ def walk_in_scope(expression, bfs=True): if node is expression: continue - elif isinstance(node, exp.CTE): - prune = True - elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): - prune = True - elif isinstance(node, exp.Subqueryable): + if ( + isinstance(node, exp.CTE) + or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))) + or isinstance(node, exp.UDTF) + or isinstance(node, exp.Subqueryable) + ): prune = True diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 579c2ce..9bde696 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging import typing as t +from collections import defaultdict from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors @@ -157,7 +158,6 @@ class Parser(metaclass=_Parser): ID_VAR_TOKENS = { TokenType.VAR, - TokenType.ALWAYS, TokenType.ANTI, TokenType.APPLY, TokenType.AUTO_INCREMENT, @@ -186,8 +186,6 @@ class Parser(metaclass=_Parser): TokenType.FOLLOWING, TokenType.FORMAT, TokenType.FUNCTION, - TokenType.GENERATED, - TokenType.IDENTITY, TokenType.IF, TokenType.INDEX, TokenType.ISNULL, @@ -213,7 +211,6 @@ class Parser(metaclass=_Parser): TokenType.ROW, TokenType.ROWS, TokenType.SCHEMA, - TokenType.SCHEMA_COMMENT, TokenType.SEED, TokenType.SEMI, TokenType.SET, @@ -481,9 +478,7 @@ class Parser(metaclass=_Parser): PLACEHOLDER_PARSERS = { TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder), - TokenType.PARAMETER: lambda self: self.expression( - exp.Parameter, this=self._parse_var() or self._parse_primary() - ), + TokenType.PARAMETER: lambda self: self._parse_parameter(), TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text) if self._match_set((TokenType.NUMBER, TokenType.VAR)) else None, @@ -516,6 +511,9 @@ class Parser(metaclass=_Parser): PROPERTY_PARSERS = { "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty), "CHARACTER SET": lambda self: self._parse_character_set(), + "CLUSTER BY": lambda self: self.expression( + exp.Cluster, expressions=self._parse_csv(self._parse_ordered) + ), "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty), "PARTITION BY": lambda self: self._parse_partitioned_by(), "PARTITIONED BY": lambda self: self._parse_partitioned_by(), @@ -576,20 +574,54 @@ class Parser(metaclass=_Parser): "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), "DEFINER": lambda self: self._parse_definer(), + "LOCK": lambda self: self._parse_locking(), + "LOCKING": lambda self: self._parse_locking(), } CONSTRAINT_PARSERS = { - TokenType.CHECK: lambda self: self.expression( - exp.Check, this=self._parse_wrapped(self._parse_conjunction) + "AUTOINCREMENT": lambda self: self._parse_auto_increment(), + "AUTO_INCREMENT": lambda self: self._parse_auto_increment(), + "CASESPECIFIC": lambda self: self.expression(exp.CaseSpecificColumnConstraint, not_=False), + "CHARACTER SET": lambda self: self.expression( + exp.CharacterSetColumnConstraint, this=self._parse_var_or_string() + ), + "CHECK": lambda self: self.expression( + exp.CheckColumnConstraint, 
this=self._parse_wrapped(self._parse_conjunction) + ), + "COLLATE": lambda self: self.expression( + exp.CollateColumnConstraint, this=self._parse_var() + ), + "COMMENT": lambda self: self.expression( + exp.CommentColumnConstraint, this=self._parse_string() + ), + "DEFAULT": lambda self: self.expression( + exp.DefaultColumnConstraint, this=self._parse_bitwise() ), - TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(), - TokenType.UNIQUE: lambda self: self._parse_unique(), - TokenType.LIKE: lambda self: self._parse_create_like(), + "ENCODE": lambda self: self.expression(exp.EncodeColumnConstraint, this=self._parse_var()), + "FOREIGN KEY": lambda self: self._parse_foreign_key(), + "FORMAT": lambda self: self.expression( + exp.DateFormatColumnConstraint, this=self._parse_var_or_string() + ), + "GENERATED": lambda self: self._parse_generated_as_identity(), + "IDENTITY": lambda self: self._parse_auto_increment(), + "LIKE": lambda self: self._parse_create_like(), + "NOT": lambda self: self._parse_not_constraint(), + "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True), + "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()), + "PRIMARY KEY": lambda self: self._parse_primary_key(), + "TITLE": lambda self: self.expression( + exp.TitleColumnConstraint, this=self._parse_var_or_string() + ), + "UNIQUE": lambda self: self._parse_unique(), + "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint), } + SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"} + NO_PAREN_FUNCTION_PARSERS = { TokenType.CASE: lambda self: self._parse_case(), TokenType.IF: lambda self: self._parse_if(), + TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()), } FUNCTION_PARSERS: t.Dict[str, t.Callable] = { @@ -637,6 +669,8 @@ class Parser(metaclass=_Parser): TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} + WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} @@ -940,7 +974,9 @@ class Parser(metaclass=_Parser): def _parse_create(self) -> t.Optional[exp.Expression]: start = self._prev - replace = self._match_pair(TokenType.OR, TokenType.REPLACE) + replace = self._prev.text.upper() == "REPLACE" or self._match_pair( + TokenType.OR, TokenType.REPLACE + ) set_ = self._match(TokenType.SET) # Teradata multiset = self._match_text_seq("MULTISET") # Teradata global_temporary = self._match_text_seq("GLOBAL", "TEMPORARY") # Teradata @@ -958,7 +994,7 @@ class Parser(metaclass=_Parser): create_token = self._match_set(self.CREATABLES) and self._prev if not create_token: - properties = self._parse_properties() + properties = self._parse_properties() # exp.Properties.Location.POST_CREATE create_token = self._match_set(self.CREATABLES) and self._prev if not properties or not create_token: @@ -994,15 +1030,37 @@ class Parser(metaclass=_Parser): ): table_parts = self._parse_table_parts(schema=True) - if self._match(TokenType.COMMA): # comma-separated properties before schema definition - properties = self._parse_properties(before=True) + # exp.Properties.Location.POST_NAME + if self._match(TokenType.COMMA): + temp_properties = self._parse_properties(before=True) + if properties and temp_properties: + properties.expressions.append(temp_properties.expressions) + elif temp_properties: + properties = temp_properties this = 
self._parse_schema(this=table_parts) - if not properties: # properties after schema definition - properties = self._parse_properties() + # exp.Properties.Location.POST_SCHEMA and POST_WITH + temp_properties = self._parse_properties() + if properties and temp_properties: + properties.expressions.append(temp_properties.expressions) + elif temp_properties: + properties = temp_properties self._match(TokenType.ALIAS) + + # exp.Properties.Location.POST_ALIAS + if not ( + self._match(TokenType.SELECT, advance=False) + or self._match(TokenType.WITH, advance=False) + or self._match(TokenType.L_PAREN, advance=False) + ): + temp_properties = self._parse_properties() + if properties and temp_properties: + properties.expressions.append(temp_properties.expressions) + elif temp_properties: + properties = temp_properties + expression = self._parse_ddl_select() if create_token.token_type == TokenType.TABLE: @@ -1022,12 +1080,13 @@ class Parser(metaclass=_Parser): while True: index = self._parse_create_table_index() - # post index PARTITION BY property + # exp.Properties.Location.POST_INDEX if self._match(TokenType.PARTITION_BY, advance=False): - if properties: - properties.expressions.append(self._parse_property()) - else: - properties = self._parse_properties() + temp_properties = self._parse_properties() + if properties and temp_properties: + properties.expressions.append(temp_properties.expressions) + elif temp_properties: + properties = temp_properties if not index: break @@ -1080,7 +1139,7 @@ class Parser(metaclass=_Parser): return self.PROPERTY_PARSERS[self._prev.text.upper()](self) if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): - return self._parse_character_set(True) + return self._parse_character_set(default=True) if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY): return self._parse_sortkey(compound=True) @@ -1240,7 +1299,7 @@ class Parser(metaclass=_Parser): def _parse_blockcompression(self) -> exp.Expression: self._match_text_seq("BLOCKCOMPRESSION") self._match(TokenType.EQ) - always = self._match(TokenType.ALWAYS) + always = self._match_text_seq("ALWAYS") manual = self._match_text_seq("MANUAL") never = self._match_text_seq("NEVER") default = self._match_text_seq("DEFAULT") @@ -1274,6 +1333,56 @@ class Parser(metaclass=_Parser): for_none=for_none, ) + def _parse_locking(self) -> exp.Expression: + if self._match(TokenType.TABLE): + kind = "TABLE" + elif self._match(TokenType.VIEW): + kind = "VIEW" + elif self._match(TokenType.ROW): + kind = "ROW" + elif self._match_text_seq("DATABASE"): + kind = "DATABASE" + else: + kind = None + + if kind in ("DATABASE", "TABLE", "VIEW"): + this = self._parse_table_parts() + else: + this = None + + if self._match(TokenType.FOR): + for_or_in = "FOR" + elif self._match(TokenType.IN): + for_or_in = "IN" + else: + for_or_in = None + + if self._match_text_seq("ACCESS"): + lock_type = "ACCESS" + elif self._match_texts(("EXCL", "EXCLUSIVE")): + lock_type = "EXCLUSIVE" + elif self._match_text_seq("SHARE"): + lock_type = "SHARE" + elif self._match_text_seq("READ"): + lock_type = "READ" + elif self._match_text_seq("WRITE"): + lock_type = "WRITE" + elif self._match_text_seq("CHECKSUM"): + lock_type = "CHECKSUM" + else: + lock_type = None + + override = self._match_text_seq("OVERRIDE") + + return self.expression( + exp.LockingProperty, + this=this, + kind=kind, + for_or_in=for_or_in, + lock_type=lock_type, + override=override, + ) + def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]: if self._match(TokenType.PARTITION_BY): return 
self._parse_csv(self._parse_conjunction) @@ -1351,6 +1460,7 @@ class Parser(metaclass=_Parser): this: t.Optional[exp.Expression] + alternative = None if self._match_text_seq("DIRECTORY"): this = self.expression( exp.Directory, @@ -1359,6 +1469,9 @@ class Parser(metaclass=_Parser): row_format=self._parse_row_format(match_row=True), ) else: + if self._match(TokenType.OR): + alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text + self._match(TokenType.INTO) self._match(TokenType.TABLE) this = self._parse_table(schema=True) @@ -1370,6 +1483,7 @@ class Parser(metaclass=_Parser): partition=self._parse_partition(), expression=self._parse_ddl_select(), overwrite=overwrite, + alternative=alternative, ) def _parse_row(self) -> t.Optional[exp.Expression]: @@ -1607,7 +1721,7 @@ class Parser(metaclass=_Parser): index = self._index if self._match(TokenType.L_PAREN): - columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var())) + columns = self._parse_csv(self._parse_function_parameter) self._match_r_paren() if columns else self._retreat(index) else: columns = None @@ -2080,27 +2194,33 @@ class Parser(metaclass=_Parser): if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None - expressions = self._parse_csv(self._parse_conjunction) - grouping_sets = self._parse_grouping_sets() + elements = defaultdict(list) - self._match(TokenType.COMMA) - with_ = self._match(TokenType.WITH) - cube = self._match(TokenType.CUBE) and ( - with_ or self._parse_wrapped_csv(self._parse_column) - ) + while True: + expressions = self._parse_csv(self._parse_conjunction) + if expressions: + elements["expressions"].extend(expressions) - self._match(TokenType.COMMA) - rollup = self._match(TokenType.ROLLUP) and ( - with_ or self._parse_wrapped_csv(self._parse_column) - ) + grouping_sets = self._parse_grouping_sets() + if grouping_sets: + elements["grouping_sets"].extend(grouping_sets) - return self.expression( - exp.Group, - expressions=expressions, - grouping_sets=grouping_sets, - cube=cube, - rollup=rollup, - ) + rollup = None + cube = None + + with_ = self._match(TokenType.WITH) + if self._match(TokenType.ROLLUP): + rollup = with_ or self._parse_wrapped_csv(self._parse_column) + elements["rollup"].extend(ensure_list(rollup)) + + if self._match(TokenType.CUBE): + cube = with_ or self._parse_wrapped_csv(self._parse_column) + elements["cube"].extend(ensure_list(cube)) + + if not (expressions or grouping_sets or rollup or cube): + break + + return self.expression(exp.Group, **elements) # type: ignore def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.GROUPING_SETS): @@ -2357,6 +2477,8 @@ class Parser(metaclass=_Parser): def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: index = self._index + prefix = self._match_text_seq("SYSUDTLIB", ".") + if not self._match_set(self.TYPE_TOKENS): return None @@ -2458,6 +2580,7 @@ class Parser(metaclass=_Parser): expressions=expressions, nested=nested, values=values, + prefix=prefix, ) def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]: @@ -2512,8 +2635,14 @@ class Parser(metaclass=_Parser): if op: this = op(self, this, field) - elif isinstance(this, exp.Column) and not this.table: - this = self.expression(exp.Column, this=field, table=this.this) + elif isinstance(this, exp.Column) and not this.args.get("catalog"): + this = self.expression( + exp.Column, + this=field, + table=this.this, + db=this.args.get("table"), + 
catalog=this.args.get("db"), + ) else: this = self.expression(exp.Dot, this=this, expression=field) this = self._parse_bracket(this) @@ -2632,6 +2761,9 @@ class Parser(metaclass=_Parser): self._match_r_paren(this) return self._parse_window(this) + def _parse_function_parameter(self) -> t.Optional[exp.Expression]: + return self._parse_column_def(self._parse_id_var()) + def _parse_user_defined_function( self, kind: t.Optional[TokenType] = None ) -> t.Optional[exp.Expression]: @@ -2643,7 +2775,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_PAREN): return this - expressions = self._parse_csv(self._parse_udf_kwarg) + expressions = self._parse_csv(self._parse_function_parameter) self._match_r_paren() return self.expression( exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True @@ -2669,15 +2801,6 @@ class Parser(metaclass=_Parser): return self.expression(exp.SessionParameter, this=this, kind=kind) - def _parse_udf_kwarg(self) -> t.Optional[exp.Expression]: - this = self._parse_id_var() - kind = self._parse_types() - - if not kind: - return this - - return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind) - def _parse_lambda(self) -> t.Optional[exp.Expression]: index = self._index @@ -2726,6 +2849,9 @@ class Parser(metaclass=_Parser): def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: kind = self._parse_types() + if self._match_text_seq("FOR", "ORDINALITY"): + return self.expression(exp.ColumnDef, this=this, ordinality=True) + constraints = [] while True: constraint = self._parse_column_constraint() @@ -2738,79 +2864,78 @@ class Parser(metaclass=_Parser): return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) - def _parse_column_constraint(self) -> t.Optional[exp.Expression]: - this = self._parse_references() + def _parse_auto_increment(self) -> exp.Expression: + start = None + increment = None - if this: - return this + if self._match(TokenType.L_PAREN, advance=False): + args = self._parse_wrapped_csv(self._parse_bitwise) + start = seq_get(args, 0) + increment = seq_get(args, 1) + elif self._match_text_seq("START"): + start = self._parse_bitwise() + self._match_text_seq("INCREMENT") + increment = self._parse_bitwise() + + if start and increment: + return exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment) + + return exp.AutoIncrementColumnConstraint() + + def _parse_generated_as_identity(self) -> exp.Expression: + if self._match(TokenType.BY_DEFAULT): + this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False) + else: + self._match_text_seq("ALWAYS") + this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) - if self._match(TokenType.CONSTRAINT): - this = self._parse_id_var() + self._match_text_seq("AS", "IDENTITY") + if self._match(TokenType.L_PAREN): + if self._match_text_seq("START", "WITH"): + this.set("start", self._parse_bitwise()) + if self._match_text_seq("INCREMENT", "BY"): + this.set("increment", self._parse_bitwise()) + if self._match_text_seq("MINVALUE"): + this.set("minvalue", self._parse_bitwise()) + if self._match_text_seq("MAXVALUE"): + this.set("maxvalue", self._parse_bitwise()) + + if self._match_text_seq("CYCLE"): + this.set("cycle", True) + elif self._match_text_seq("NO", "CYCLE"): + this.set("cycle", False) - kind: exp.Expression + self._match_r_paren() - if self._match_set((TokenType.AUTO_INCREMENT, TokenType.IDENTITY)): - start = None - increment = None + return this - if 
self._match(TokenType.L_PAREN, advance=False): - args = self._parse_wrapped_csv(self._parse_bitwise) - start = seq_get(args, 0) - increment = seq_get(args, 1) - elif self._match_text_seq("START"): - start = self._parse_bitwise() - self._match_text_seq("INCREMENT") - increment = self._parse_bitwise() + def _parse_not_constraint(self) -> t.Optional[exp.Expression]: + if self._match_text_seq("NULL"): + return self.expression(exp.NotNullColumnConstraint) + if self._match_text_seq("CASESPECIFIC"): + return self.expression(exp.CaseSpecificColumnConstraint, not_=True) + return None - if start and increment: - kind = exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment) - else: - kind = exp.AutoIncrementColumnConstraint() - elif self._match(TokenType.CHECK): - constraint = self._parse_wrapped(self._parse_conjunction) - kind = self.expression(exp.CheckColumnConstraint, this=constraint) - elif self._match(TokenType.COLLATE): - kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) - elif self._match(TokenType.ENCODE): - kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var()) - elif self._match(TokenType.DEFAULT): - kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise()) - elif self._match_pair(TokenType.NOT, TokenType.NULL): - kind = exp.NotNullColumnConstraint() - elif self._match(TokenType.NULL): - kind = exp.NotNullColumnConstraint(allow_null=True) - elif self._match(TokenType.SCHEMA_COMMENT): - kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string()) - elif self._match(TokenType.PRIMARY_KEY): - desc = None - if self._match(TokenType.ASC) or self._match(TokenType.DESC): - desc = self._prev.token_type == TokenType.DESC - kind = exp.PrimaryKeyColumnConstraint(desc=desc) - elif self._match(TokenType.UNIQUE): - kind = exp.UniqueColumnConstraint() - elif self._match(TokenType.GENERATED): - if self._match(TokenType.BY_DEFAULT): - kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False) - else: - self._match(TokenType.ALWAYS) - kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) - self._match_pair(TokenType.ALIAS, TokenType.IDENTITY) + def _parse_column_constraint(self) -> t.Optional[exp.Expression]: + this = self._parse_references() + if this: + return this - if self._match(TokenType.L_PAREN): - if self._match_text_seq("START", "WITH"): - kind.set("start", self._parse_bitwise()) - if self._match_text_seq("INCREMENT", "BY"): - kind.set("increment", self._parse_bitwise()) + if self._match(TokenType.CONSTRAINT): + this = self._parse_id_var() - self._match_r_paren() - else: - return this + if self._match_texts(self.CONSTRAINT_PARSERS): + return self.expression( + exp.ColumnConstraint, + this=this, + kind=self.CONSTRAINT_PARSERS[self._prev.text.upper()](self), + ) - return self.expression(exp.ColumnConstraint, this=this, kind=kind) + return this def _parse_constraint(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.CONSTRAINT): - return self._parse_unnamed_constraint() + return self._parse_unnamed_constraint(constraints=self.SCHEMA_UNNAMED_CONSTRAINTS) this = self._parse_id_var() expressions = [] @@ -2823,12 +2948,21 @@ class Parser(metaclass=_Parser): return self.expression(exp.Constraint, this=this, expressions=expressions) - def _parse_unnamed_constraint(self) -> t.Optional[exp.Expression]: - if not self._match_set(self.CONSTRAINT_PARSERS): + def _parse_unnamed_constraint( + self, constraints: t.Optional[t.Collection[str]] = None + ) -> 
t.Optional[exp.Expression]: + if not self._match_texts(constraints or self.CONSTRAINT_PARSERS): return None - return self.CONSTRAINT_PARSERS[self._prev.token_type](self) + + constraint = self._prev.text.upper() + if constraint not in self.CONSTRAINT_PARSERS: + self.raise_error(f"No parser found for schema constraint {constraint}.") + + return self.CONSTRAINT_PARSERS[constraint](self) def _parse_unique(self) -> exp.Expression: + if not self._match(TokenType.L_PAREN, advance=False): + return self.expression(exp.UniqueColumnConstraint) return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) def _parse_key_constraint_options(self) -> t.List[str]: @@ -2908,6 +3042,14 @@ class Parser(metaclass=_Parser): ) def _parse_primary_key(self) -> exp.Expression: + desc = ( + self._match_set((TokenType.ASC, TokenType.DESC)) + and self._prev.token_type == TokenType.DESC + ) + + if not self._match(TokenType.L_PAREN, advance=False): + return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc) + expressions = self._parse_wrapped_id_vars() options = self._parse_key_constraint_options() return self.expression(exp.PrimaryKey, expressions=expressions, options=options) @@ -3306,6 +3448,12 @@ class Parser(metaclass=_Parser): return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) return None + def _parse_parameter(self) -> exp.Expression: + wrapped = self._match(TokenType.L_BRACE) + this = self._parse_var() or self._parse_primary() + self._match(TokenType.R_BRACE) + return self.expression(exp.Parameter, this=this, wrapped=wrapped) + def _parse_placeholder(self) -> t.Optional[exp.Expression]: if self._match_set(self.PLACEHOLDER_PARSERS): placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self) @@ -3449,7 +3597,7 @@ class Parser(metaclass=_Parser): if kind == TokenType.CONSTRAINT: this = self._parse_id_var() - if self._match(TokenType.CHECK): + if self._match_text_seq("CHECK"): expression = self._parse_wrapped(self._parse_conjunction) enforced = self._match_text_seq("ENFORCED") diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 8cf17a7..9b29c12 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -138,7 +138,6 @@ class TokenType(AutoName): CASCADE = auto() CASE = auto() CHARACTER_SET = auto() - CHECK = auto() CLUSTER_BY = auto() COLLATE = auto() COMMAND = auto() @@ -164,7 +163,6 @@ class TokenType(AutoName): DIV = auto() DROP = auto() ELSE = auto() - ENCODE = auto() END = auto() ESCAPE = auto() EXCEPT = auto() @@ -182,17 +180,16 @@ class TokenType(AutoName): FROM = auto() FULL = auto() FUNCTION = auto() - GENERATED = auto() GLOB = auto() GLOBAL = auto() GROUP_BY = auto() GROUPING_SETS = auto() HAVING = auto() HINT = auto() - IDENTITY = auto() IF = auto() IGNORE_NULLS = auto() ILIKE = auto() + ILIKE_ANY = auto() IN = auto() INDEX = auto() INNER = auto() @@ -211,6 +208,7 @@ class TokenType(AutoName): LEADING = auto() LEFT = auto() LIKE = auto() + LIKE_ANY = auto() LIMIT = auto() LOAD_DATA = auto() LOCAL = auto() @@ -253,6 +251,7 @@ class TokenType(AutoName): RECURSIVE = auto() REPLACE = auto() RESPECT_NULLS = auto() + RETURNING = auto() REFERENCES = auto() RIGHT = auto() RLIKE = auto() @@ -260,7 +259,6 @@ class TokenType(AutoName): ROLLUP = auto() ROW = auto() ROWS = auto() - SCHEMA_COMMENT = auto() SEED = auto() SELECT = auto() SEMI = auto() @@ -441,7 +439,7 @@ class Tokenizer(metaclass=_Tokenizer): KEYWORDS = { **{ f"{key}{postfix}": TokenType.BLOCK_START - for key in ("{{", "{%", "{#") + for key in ("{%", "{#") for postfix in ("", "+", "-") }, **{ @@ -449,6 
+447,8 @@ class Tokenizer(metaclass=_Tokenizer): for key in ("%}", "#}") for prefix in ("", "+", "-") }, + "{{+": TokenType.BLOCK_START, + "{{-": TokenType.BLOCK_START, "+}}": TokenType.BLOCK_END, "-}}": TokenType.BLOCK_END, "/*+": TokenType.HINT, @@ -486,11 +486,9 @@ class Tokenizer(metaclass=_Tokenizer): "CASE": TokenType.CASE, "CASCADE": TokenType.CASCADE, "CHARACTER SET": TokenType.CHARACTER_SET, - "CHECK": TokenType.CHECK, "CLUSTER BY": TokenType.CLUSTER_BY, "COLLATE": TokenType.COLLATE, "COLUMN": TokenType.COLUMN, - "COMMENT": TokenType.SCHEMA_COMMENT, "COMMIT": TokenType.COMMIT, "COMPOUND": TokenType.COMPOUND, "CONSTRAINT": TokenType.CONSTRAINT, @@ -526,12 +524,10 @@ class Tokenizer(metaclass=_Tokenizer): "FOREIGN KEY": TokenType.FOREIGN_KEY, "FORMAT": TokenType.FORMAT, "FROM": TokenType.FROM, - "GENERATED": TokenType.GENERATED, "GLOB": TokenType.GLOB, "GROUP BY": TokenType.GROUP_BY, "GROUPING SETS": TokenType.GROUPING_SETS, "HAVING": TokenType.HAVING, - "IDENTITY": TokenType.IDENTITY, "IF": TokenType.IF, "ILIKE": TokenType.ILIKE, "IGNORE NULLS": TokenType.IGNORE_NULLS, @@ -747,11 +743,9 @@ class Tokenizer(metaclass=_Tokenizer): "_prev_token_line", "_prev_token_comments", "_prev_token_type", - "_replace_backslash", ) def __init__(self) -> None: - self._replace_backslash = "\\" in self._STRING_ESCAPES self.reset() def reset(self) -> None: @@ -855,7 +849,7 @@ class Tokenizer(metaclass=_Tokenizer): def _scan_keywords(self) -> None: size = 0 word = None - chars = self._text + chars: t.Optional[str] = self._text char = chars prev_space = False skip = False @@ -887,7 +881,7 @@ class Tokenizer(metaclass=_Tokenizer): else: skip = True else: - chars = None # type: ignore + chars = None if not word: if self._char in self.SINGLE_TOKENS: @@ -1015,7 +1009,6 @@ class Tokenizer(metaclass=_Tokenizer): self._advance(len(quote)) text = self._extract_string(quote_end) text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore - text = text.replace("\\\\", "\\") if self._replace_backslash else text self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text) return True @@ -1091,13 +1084,18 @@ class Tokenizer(metaclass=_Tokenizer): delim_size = len(delimiter) while True: - if ( - self._char in self._STRING_ESCAPES - and self._peek - and (self._peek == delimiter or self._peek in self._STRING_ESCAPES) + if self._char in self._STRING_ESCAPES and ( + self._peek == delimiter or self._peek in self._STRING_ESCAPES ): - text += self._peek - self._advance(2) + if self._peek == delimiter: + text += self._peek # type: ignore + else: + text += self._char + self._peek # type: ignore + + if self._current + 1 < self.size: + self._advance(2) + else: + raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._current}") else: if self._chars(delim_size) == delimiter: if delim_size > 1: -- cgit v1.2.3
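
Illustrative sketches follow; they are editorial examples, not part of the patch. First, the scope.py hunks above drop the special-cased _set_udtf_scope and instead traverse FROMs, JOINs and LATERALs uniformly, so UDTFs get ordinary child scopes. A minimal sketch of how that looks from the outside, assuming the build_scope helper and the udtf_scopes/table_scopes attributes named in the diff; the Spark query is an illustrative assumption:

    from sqlglot import parse_one
    from sqlglot.optimizer.scope import build_scope

    # A LATERAL VIEW source is handled by the UDTF branch of the new
    # _traverse_tables, so it lands in udtf_scopes as a normal child scope.
    sql = "SELECT a, e FROM t LATERAL VIEW EXPLODE(t.xs) tbl AS e"
    root = build_scope(parse_one(sql, read="spark"))

    for child in root.udtf_scopes:
        print(child.expression.sql(dialect="spark"))

    # table_scopes collects derived-table and UDTF child scopes in source
    # order, which is what Scope.traverse() now walks.
    print(len(root.table_scopes))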
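
The reworked _parse_group loops until no GROUP BY element remains, accumulating plain expressions, GROUPING SETS, ROLLUP and CUBE into a single exp.Group via a defaultdict, so the element kinds can be freely mixed in one clause. A hedged sketch (the sample query is an assumption; the arg names come from the diff):

    from sqlglot import parse_one

    expr = parse_one("SELECT a, b, SUM(c) FROM t GROUP BY a, ROLLUP (b), CUBE (a, b)")
    group = expr.args["group"]

    # Each element kind is collected under its own key on exp.Group.
    print(group.args["expressions"])  # plain GROUP BY expressions
    print(group.args["rollup"])       # ROLLUP column lists
    print(group.args["cube"])         # CUBE column lists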
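
CONSTRAINT_PARSERS is now keyed by strings rather than token types, and _parse_generated_as_identity captures the full option list (START WITH, INCREMENT BY, MINVALUE, MAXVALUE, CYCLE). A small sketch of the resulting AST, assuming the default dialect accepts this DDL:

    from sqlglot import exp, parse_one

    ddl = parse_one(
        "CREATE TABLE t (id INT GENERATED ALWAYS AS IDENTITY (START WITH 1 INCREMENT BY 2))"
    )

    ident = ddl.find(exp.GeneratedAsIdentityColumnConstraint)
    print(ident.this)                   # True for ALWAYS, False for BY DEFAULT
    print(ident.args.get("start"))      # literal 1
    print(ident.args.get("increment"))  # literal 2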
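
Finally, INSERT_ALTERNATIVES adds the SQLite-style conflict keywords that may follow INSERT OR; the parsed keyword is stored in the Insert node's new "alternative" arg. A sketch under those assumptions (the exact round-trip output is not guaranteed by the diff alone):

    from sqlglot import parse_one

    ins = parse_one("INSERT OR REPLACE INTO t VALUES (1, 2)", read="sqlite")
    print(ins.args.get("alternative"))  # 'REPLACE'
    print(ins.sql(dialect="sqlite"))    # expected: INSERT OR REPLACE INTO t VALUES (1, 2)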