summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-02-19 13:45:09 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-02-19 13:45:09 +0000
commit639a208fa57ea674d165c4837e96f3ae4d7e3e61 (patch)
treef4d66da146c396d407cecefb5b405e609af1109e /sqlglot
parentReleasing debian version 11.0.1-1. (diff)
downloadsqlglot-639a208fa57ea674d165c4837e96f3ae4d7e3e61.tar.xz
sqlglot-639a208fa57ea674d165c4837e96f3ae4d7e3e61.zip
Merging upstream version 11.1.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/dataframe/sql/column.py1
-rw-r--r--sqlglot/dataframe/sql/functions.py235
-rw-r--r--sqlglot/dialects/bigquery.py30
-rw-r--r--sqlglot/dialects/databricks.py9
-rw-r--r--sqlglot/dialects/dialect.py22
-rw-r--r--sqlglot/dialects/drill.py21
-rw-r--r--sqlglot/dialects/duckdb.py20
-rw-r--r--sqlglot/dialects/hive.py14
-rw-r--r--sqlglot/dialects/mysql.py12
-rw-r--r--sqlglot/dialects/oracle.py62
-rw-r--r--sqlglot/dialects/postgres.py17
-rw-r--r--sqlglot/dialects/presto.py7
-rw-r--r--sqlglot/dialects/redshift.py12
-rw-r--r--sqlglot/dialects/snowflake.py36
-rw-r--r--sqlglot/dialects/spark.py17
-rw-r--r--sqlglot/dialects/sqlite.py2
-rw-r--r--sqlglot/dialects/teradata.py45
-rw-r--r--sqlglot/dialects/tsql.py6
-rw-r--r--sqlglot/executor/__init__.py4
-rw-r--r--sqlglot/executor/python.py2
-rw-r--r--sqlglot/expressions.py233
-rw-r--r--sqlglot/generator.py264
-rw-r--r--sqlglot/optimizer/annotate_types.py3
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py6
-rw-r--r--sqlglot/optimizer/optimizer.py31
-rw-r--r--sqlglot/optimizer/pushdown_projections.py76
-rw-r--r--sqlglot/optimizer/qualify_columns.py48
-rw-r--r--sqlglot/optimizer/qualify_tables.py2
-rw-r--r--sqlglot/optimizer/scope.py169
-rw-r--r--sqlglot/parser.py390
-rw-r--r--sqlglot/tokens.py40
32 files changed, 1213 insertions, 625 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 7b07ae1..c17a703 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -40,7 +40,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
-__version__ = "11.0.1"
+__version__ = "11.1.3"
pretty = False
"""Whether to format generated SQL by default."""
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index f5b0974..609b2a4 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -143,6 +143,7 @@ class Column:
if is_iterable(v)
else Column.ensure_col(v).expression
for k, v in kwargs.items()
+ if v is not None
}
new_expression = (
callable_expression(**ensure_expression_values)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 47d5e7b..0262d54 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import expressions as glotexp
+from sqlglot import exp as expression
from sqlglot.dataframe.sql.column import Column
from sqlglot.helper import ensure_list
from sqlglot.helper import flatten as _flatten
@@ -18,25 +18,29 @@ def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
def lit(value: t.Optional[t.Any] = None) -> Column:
if isinstance(value, str):
- return Column(glotexp.Literal.string(str(value)))
+ return Column(expression.Literal.string(str(value)))
return Column(value)
def greatest(*cols: ColumnOrName) -> Column:
if len(cols) > 1:
- return Column.invoke_expression_over_column(cols[0], glotexp.Greatest, expressions=cols[1:])
- return Column.invoke_expression_over_column(cols[0], glotexp.Greatest)
+ return Column.invoke_expression_over_column(
+ cols[0], expression.Greatest, expressions=cols[1:]
+ )
+ return Column.invoke_expression_over_column(cols[0], expression.Greatest)
def least(*cols: ColumnOrName) -> Column:
if len(cols) > 1:
- return Column.invoke_expression_over_column(cols[0], glotexp.Least, expressions=cols[1:])
- return Column.invoke_expression_over_column(cols[0], glotexp.Least)
+ return Column.invoke_expression_over_column(cols[0], expression.Least, expressions=cols[1:])
+ return Column.invoke_expression_over_column(cols[0], expression.Least)
def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
columns = [Column.ensure_col(x) for x in [col] + list(cols)]
- return Column(glotexp.Count(this=glotexp.Distinct(expressions=[x.expression for x in columns])))
+ return Column(
+ expression.Count(this=expression.Distinct(expressions=[x.expression for x in columns]))
+ )
def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
@@ -46,8 +50,8 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
def when(condition: Column, value: t.Any) -> Column:
true_value = value if isinstance(value, Column) else lit(value)
return Column(
- glotexp.Case(
- ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+ expression.Case(
+ ifs=[expression.If(this=condition.column_expression, true=true_value.column_expression)]
)
)
@@ -65,19 +69,19 @@ def broadcast(df: DataFrame) -> DataFrame:
def sqrt(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Sqrt)
+ return Column.invoke_expression_over_column(col, expression.Sqrt)
def abs(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Abs)
+ return Column.invoke_expression_over_column(col, expression.Abs)
def max(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Max)
+ return Column.invoke_expression_over_column(col, expression.Max)
def min(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Min)
+ return Column.invoke_expression_over_column(col, expression.Min)
def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
@@ -89,15 +93,15 @@ def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
def count(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Count)
+ return Column.invoke_expression_over_column(col, expression.Count)
def sum(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Sum)
+ return Column.invoke_expression_over_column(col, expression.Sum)
def avg(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Avg)
+ return Column.invoke_expression_over_column(col, expression.Avg)
def mean(col: ColumnOrName) -> Column:
@@ -149,7 +153,7 @@ def cbrt(col: ColumnOrName) -> Column:
def ceil(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Ceil)
+ return Column.invoke_expression_over_column(col, expression.Ceil)
def cos(col: ColumnOrName) -> Column:
@@ -169,7 +173,7 @@ def csc(col: ColumnOrName) -> Column:
def exp(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Exp)
+ return Column.invoke_expression_over_column(col, expression.Exp)
def expm1(col: ColumnOrName) -> Column:
@@ -177,11 +181,11 @@ def expm1(col: ColumnOrName) -> Column:
def floor(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Floor)
+ return Column.invoke_expression_over_column(col, expression.Floor)
def log10(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Log10)
+ return Column.invoke_expression_over_column(col, expression.Log10)
def log1p(col: ColumnOrName) -> Column:
@@ -189,13 +193,13 @@ def log1p(col: ColumnOrName) -> Column:
def log2(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Log2)
+ return Column.invoke_expression_over_column(col, expression.Log2)
def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
if arg2 is None:
- return Column.invoke_expression_over_column(arg1, glotexp.Ln)
- return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=arg2)
+ return Column.invoke_expression_over_column(arg1, expression.Ln)
+ return Column.invoke_expression_over_column(arg1, expression.Log, expression=arg2)
def rint(col: ColumnOrName) -> Column:
@@ -247,7 +251,7 @@ def bitwiseNOT(col: ColumnOrName) -> Column:
def bitwise_not(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.BitwiseNot)
+ return Column.invoke_expression_over_column(col, expression.BitwiseNot)
def asc_nulls_first(col: ColumnOrName) -> Column:
@@ -267,27 +271,27 @@ def desc_nulls_last(col: ColumnOrName) -> Column:
def stddev(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Stddev)
+ return Column.invoke_expression_over_column(col, expression.Stddev)
def stddev_samp(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.StddevSamp)
+ return Column.invoke_expression_over_column(col, expression.StddevSamp)
def stddev_pop(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.StddevPop)
+ return Column.invoke_expression_over_column(col, expression.StddevPop)
def variance(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Variance)
+ return Column.invoke_expression_over_column(col, expression.Variance)
def var_samp(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Variance)
+ return Column.invoke_expression_over_column(col, expression.Variance)
def var_pop(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.VariancePop)
+ return Column.invoke_expression_over_column(col, expression.VariancePop)
def skewness(col: ColumnOrName) -> Column:
@@ -299,11 +303,11 @@ def kurtosis(col: ColumnOrName) -> Column:
def collect_list(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.ArrayAgg)
+ return Column.invoke_expression_over_column(col, expression.ArrayAgg)
def collect_set(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.SetAgg)
+ return Column.invoke_expression_over_column(col, expression.SetAgg)
def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
@@ -311,27 +315,27 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]
def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
- return Column.invoke_expression_over_column(col1, glotexp.Pow, expression=col2)
+ return Column.invoke_expression_over_column(col1, expression.Pow, expression=col2)
def row_number() -> Column:
- return Column(glotexp.Anonymous(this="ROW_NUMBER"))
+ return Column(expression.Anonymous(this="ROW_NUMBER"))
def dense_rank() -> Column:
- return Column(glotexp.Anonymous(this="DENSE_RANK"))
+ return Column(expression.Anonymous(this="DENSE_RANK"))
def rank() -> Column:
- return Column(glotexp.Anonymous(this="RANK"))
+ return Column(expression.Anonymous(this="RANK"))
def cume_dist() -> Column:
- return Column(glotexp.Anonymous(this="CUME_DIST"))
+ return Column(expression.Anonymous(this="CUME_DIST"))
def percent_rank() -> Column:
- return Column(glotexp.Anonymous(this="PERCENT_RANK"))
+ return Column(expression.Anonymous(this="PERCENT_RANK"))
def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
@@ -340,14 +344,16 @@ def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Col
def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
if rsd is None:
- return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct)
- return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=rsd)
+ return Column.invoke_expression_over_column(col, expression.ApproxDistinct)
+ return Column.invoke_expression_over_column(col, expression.ApproxDistinct, accuracy=rsd)
def coalesce(*cols: ColumnOrName) -> Column:
if len(cols) > 1:
- return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce, expressions=cols[1:])
- return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce)
+ return Column.invoke_expression_over_column(
+ cols[0], expression.Coalesce, expressions=cols[1:]
+ )
+ return Column.invoke_expression_over_column(cols[0], expression.Coalesce)
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
@@ -409,10 +415,10 @@ def percentile_approx(
) -> Column:
if accuracy:
return Column.invoke_expression_over_column(
- col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
+ col, expression.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
)
return Column.invoke_expression_over_column(
- col, glotexp.ApproxQuantile, quantile=lit(percentage)
+ col, expression.ApproxQuantile, quantile=lit(percentage)
)
@@ -426,8 +432,8 @@ def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
if scale is not None:
- return Column.invoke_expression_over_column(col, glotexp.Round, decimals=scale)
- return Column.invoke_expression_over_column(col, glotexp.Round)
+ return Column.invoke_expression_over_column(col, expression.Round, decimals=scale)
+ return Column.invoke_expression_over_column(col, expression.Round)
def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
@@ -437,7 +443,9 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
def shiftleft(col: ColumnOrName, numBits: int) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=numBits)
+ return Column.invoke_expression_over_column(
+ col, expression.BitwiseLeftShift, expression=numBits
+ )
def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
@@ -445,7 +453,9 @@ def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
def shiftright(col: ColumnOrName, numBits: int) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=numBits)
+ return Column.invoke_expression_over_column(
+ col, expression.BitwiseRightShift, expression=numBits
+ )
def shiftRight(col: ColumnOrName, numBits: int) -> Column:
@@ -466,7 +476,7 @@ def expr(str: str) -> Column:
def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
columns = ensure_list(col) + list(cols)
- return Column.invoke_expression_over_column(None, glotexp.Struct, expressions=columns)
+ return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns)
def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
@@ -512,19 +522,19 @@ def ntile(n: int) -> Column:
def current_date() -> Column:
- return Column.invoke_expression_over_column(None, glotexp.CurrentDate)
+ return Column.invoke_expression_over_column(None, expression.CurrentDate)
def current_timestamp() -> Column:
- return Column.invoke_expression_over_column(None, glotexp.CurrentTimestamp)
+ return Column.invoke_expression_over_column(None, expression.CurrentTimestamp)
def date_format(col: ColumnOrName, format: str) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.TimeToStr, format=lit(format))
+ return Column.invoke_expression_over_column(col, expression.TimeToStr, format=lit(format))
def year(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Year)
+ return Column.invoke_expression_over_column(col, expression.Year)
def quarter(col: ColumnOrName) -> Column:
@@ -532,19 +542,19 @@ def quarter(col: ColumnOrName) -> Column:
def month(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Month)
+ return Column.invoke_expression_over_column(col, expression.Month)
def dayofweek(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DayOfWeek)
+ return Column.invoke_expression_over_column(col, expression.DayOfWeek)
def dayofmonth(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DayOfMonth)
+ return Column.invoke_expression_over_column(col, expression.DayOfMonth)
def dayofyear(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DayOfYear)
+ return Column.invoke_expression_over_column(col, expression.DayOfYear)
def hour(col: ColumnOrName) -> Column:
@@ -560,7 +570,7 @@ def second(col: ColumnOrName) -> Column:
def weekofyear(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.WeekOfYear)
+ return Column.invoke_expression_over_column(col, expression.WeekOfYear)
def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
@@ -568,15 +578,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days)
+ return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days)
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days)
+ return Column.invoke_expression_over_column(col, expression.DateSub, expression=days)
def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start)
+ return Column.invoke_expression_over_column(end, expression.DateDiff, expression=start)
def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
@@ -593,8 +603,10 @@ def months_between(
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None:
- return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate, format=lit(format))
- return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate)
+ return Column.invoke_expression_over_column(
+ col, expression.TsOrDsToDate, format=lit(format)
+ )
+ return Column.invoke_expression_over_column(col, expression.TsOrDsToDate)
def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
@@ -604,11 +616,13 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
def trunc(col: ColumnOrName, format: str) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format))
+ return Column.invoke_expression_over_column(col, expression.DateTrunc, unit=lit(format))
def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format))
+ return Column.invoke_expression_over_column(
+ timestamp, expression.TimestampTrunc, unit=lit(format)
+ )
def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
@@ -621,8 +635,8 @@ def last_day(col: ColumnOrName) -> Column:
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None:
- return Column.invoke_expression_over_column(col, glotexp.UnixToStr, format=lit(format))
- return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
+ return Column.invoke_expression_over_column(col, expression.UnixToStr, format=lit(format))
+ return Column.invoke_expression_over_column(col, expression.UnixToStr)
def unix_timestamp(
@@ -630,9 +644,9 @@ def unix_timestamp(
) -> Column:
if format is not None:
return Column.invoke_expression_over_column(
- timestamp, glotexp.StrToUnix, format=lit(format)
+ timestamp, expression.StrToUnix, format=lit(format)
)
- return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
+ return Column.invoke_expression_over_column(timestamp, expression.StrToUnix)
def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
@@ -719,11 +733,11 @@ def raise_error(errorMsg: ColumnOrName) -> Column:
def upper(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Upper)
+ return Column.invoke_expression_over_column(col, expression.Upper)
def lower(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Lower)
+ return Column.invoke_expression_over_column(col, expression.Lower)
def ascii(col: ColumnOrLiteral) -> Column:
@@ -747,24 +761,24 @@ def rtrim(col: ColumnOrName) -> Column:
def trim(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Trim)
+ return Column.invoke_expression_over_column(col, expression.Trim)
def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(
- None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+ None, expression.ConcatWs, expressions=[lit(sep)] + list(cols)
)
def decode(col: ColumnOrName, charset: str) -> Column:
return Column.invoke_expression_over_column(
- col, glotexp.Decode, charset=glotexp.Literal.string(charset)
+ col, expression.Decode, charset=expression.Literal.string(charset)
)
def encode(col: ColumnOrName, charset: str) -> Column:
return Column.invoke_expression_over_column(
- col, glotexp.Encode, charset=glotexp.Literal.string(charset)
+ col, expression.Encode, charset=expression.Literal.string(charset)
)
@@ -816,16 +830,16 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right)
+ return Column.invoke_expression_over_column(left, expression.Levenshtein, expression=right)
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
substr_col = lit(substr)
if pos is not None:
return Column.invoke_expression_over_column(
- str, glotexp.StrPosition, substr=substr_col, position=pos
+ str, expression.StrPosition, substr=substr_col, position=pos
)
- return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
+ return Column.invoke_expression_over_column(str, expression.StrPosition, substr=substr_col)
def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
@@ -837,21 +851,26 @@ def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
def repeat(col: ColumnOrName, n: int) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Repeat, times=lit(n))
+ return Column.invoke_expression_over_column(col, expression.Repeat, times=lit(n))
def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
if limit is not None:
return Column.invoke_expression_over_column(
- str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=limit
+ str, expression.RegexpSplit, expression=lit(pattern).expression, limit=limit
)
- return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern))
+ return Column.invoke_expression_over_column(
+ str, expression.RegexpSplit, expression=lit(pattern)
+ )
def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
- if idx is not None:
- return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern), idx)
- return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern))
+ return Column.invoke_expression_over_column(
+ str,
+ expression.RegexpExtract,
+ expression=lit(pattern),
+ group=idx,
+ )
def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
@@ -859,7 +878,7 @@ def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
def initcap(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Initcap)
+ return Column.invoke_expression_over_column(col, expression.Initcap)
def soundex(col: ColumnOrName) -> Column:
@@ -871,15 +890,15 @@ def bin(col: ColumnOrName) -> Column:
def hex(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Hex)
+ return Column.invoke_expression_over_column(col, expression.Hex)
def unhex(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Unhex)
+ return Column.invoke_expression_over_column(col, expression.Unhex)
def length(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Length)
+ return Column.invoke_expression_over_column(col, expression.Length)
def octet_length(col: ColumnOrName) -> Column:
@@ -896,27 +915,27 @@ def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols
- return Column.invoke_expression_over_column(None, glotexp.Array, expressions=columns)
+ return Column.invoke_expression_over_column(None, expression.Array, expressions=columns)
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
return Column.invoke_expression_over_column(
None,
- glotexp.VarMap,
+ expression.VarMap,
keys=array(*cols[::2]).expression,
values=array(*cols[1::2]).expression,
)
def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(None, glotexp.Map, keys=col1, values=col2)
+ return Column.invoke_expression_over_column(None, expression.Map, keys=col1, values=col2)
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
value_col = value if isinstance(value, Column) else lit(value)
return Column.invoke_expression_over_column(
- col, glotexp.ArrayContains, expression=value_col.expression
+ col, expression.ArrayContains, expression=value_col.expression
)
@@ -943,7 +962,7 @@ def array_join(
def concat(*cols: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols)
+ return Column.invoke_expression_over_column(None, expression.Concat, expressions=cols)
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@@ -978,11 +997,11 @@ def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def explode(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Explode)
+ return Column.invoke_expression_over_column(col, expression.Explode)
def posexplode(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Posexplode)
+ return Column.invoke_expression_over_column(col, expression.Posexplode)
def explode_outer(col: ColumnOrName) -> Column:
@@ -994,7 +1013,7 @@ def posexplode_outer(col: ColumnOrName) -> Column:
def get_json_object(col: ColumnOrName, path: str) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path))
+ return Column.invoke_expression_over_column(col, expression.JSONExtract, path=lit(path))
def json_tuple(col: ColumnOrName, *fields: str) -> Column:
@@ -1042,7 +1061,7 @@ def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> C
def size(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.ArraySize)
+ return Column.invoke_expression_over_column(col, expression.ArraySize)
def array_min(col: ColumnOrName) -> Column:
@@ -1055,8 +1074,8 @@ def array_max(col: ColumnOrName) -> Column:
def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
if asc is not None:
- return Column.invoke_expression_over_column(col, glotexp.SortArray, asc=asc)
- return Column.invoke_expression_over_column(col, glotexp.SortArray)
+ return Column.invoke_expression_over_column(col, expression.SortArray, asc=asc)
+ return Column.invoke_expression_over_column(col, expression.SortArray)
def array_sort(
@@ -1065,8 +1084,10 @@ def array_sort(
) -> Column:
if comparator is not None:
f_expression = _get_lambda_from_func(comparator)
- return Column.invoke_expression_over_column(col, glotexp.ArraySort, expression=f_expression)
- return Column.invoke_expression_over_column(col, glotexp.ArraySort)
+ return Column.invoke_expression_over_column(
+ col, expression.ArraySort, expression=f_expression
+ )
+ return Column.invoke_expression_over_column(col, expression.ArraySort)
def shuffle(col: ColumnOrName) -> Column:
@@ -1146,13 +1167,13 @@ def aggregate(
finish_exp = _get_lambda_from_func(finish)
return Column.invoke_expression_over_column(
col,
- glotexp.Reduce,
+ expression.Reduce,
initial=initialValue,
merge=Column(merge_exp),
finish=Column(finish_exp),
)
return Column.invoke_expression_over_column(
- col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp)
+ col, expression.Reduce, initial=initialValue, merge=Column(merge_exp)
)
@@ -1179,7 +1200,9 @@ def filter(
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
- return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
+ return Column.invoke_expression_over_column(
+ col, expression.ArrayFilter, expression=f_expression
+ )
def zip_with(
@@ -1219,10 +1242,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:
def _get_lambda_from_func(lambda_expression: t.Callable):
variables = [
- glotexp.to_identifier(x, quoted=_lambda_quoted(x))
+ expression.to_identifier(x, quoted=_lambda_quoted(x))
for x in lambda_expression.__code__.co_varnames
]
- return glotexp.Lambda(
+ return expression.Lambda(
this=lambda_expression(*[Column(x) for x in variables]).expression,
expressions=variables,
)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 6a19b46..7fd9e35 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+import re
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
@@ -31,13 +32,6 @@ def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
return func
-def _date_trunc(args: t.Sequence) -> exp.Expression:
- unit = seq_get(args, 1)
- if isinstance(unit, exp.Column):
- unit = exp.Var(this=unit.name)
- return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
-
-
def _date_add_sql(
data_type: str, kind: str
) -> t.Callable[[generator.Generator, exp.Expression], str]:
@@ -158,11 +152,23 @@ class BigQuery(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
- "DATE_TRUNC": _date_trunc,
+ "DATE_TRUNC": lambda args: exp.DateTrunc(
+ unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
+ this=seq_get(args, 0),
+ ),
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
+ "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ position=seq_get(args, 2),
+ occurrence=seq_get(args, 3),
+ group=exp.Literal.number(1)
+ if re.compile(str(seq_get(args, 1))).groups == 1
+ else None,
+ ),
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
@@ -214,6 +220,7 @@ class BigQuery(Dialect):
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateStrToDate: datestrtodate_sql,
+ exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
@@ -226,11 +233,12 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.VariancePop: rename_func("VAR_POP"),
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
- exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
+ exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
@@ -251,6 +259,10 @@ class BigQuery(Dialect):
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.NVARCHAR: "STRING",
}
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+ }
EXPLICIT_UNION = True
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 2498c62..2e058e8 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -4,6 +4,7 @@ from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
+from sqlglot.tokens import TokenType
class Databricks(Spark):
@@ -21,3 +22,11 @@ class Databricks(Spark):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
}
+
+ PARAMETER_TOKEN = "$"
+
+ class Tokenizer(Spark.Tokenizer):
+ SINGLE_TOKENS = {
+ **Spark.Tokenizer.SINGLE_TOKENS,
+ "$": TokenType.PARAMETER,
+ }
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 176a8ce..f4e8fd4 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -215,24 +215,19 @@ DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
- def _rename(self, expression):
- args = flatten(expression.args.values())
- return f"{self.normalize_func(name)}({self.format_args(*args)})"
-
- return _rename
+ return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
if expression.args.get("accuracy"):
self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
- return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
+ return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(self: Generator, expression: exp.If) -> str:
- expressions = self.format_args(
- expression.this, expression.args.get("true"), expression.args.get("false")
+ return self.func(
+ "IF", expression.this, expression.args.get("true"), expression.args.get("false")
)
- return f"IF({expressions})"
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
@@ -318,13 +313,13 @@ def var_map_sql(
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map.")
- return f"{map_func_name}({self.format_args(keys, values)})"
+ return self.func(map_func_name, keys, values)
args = []
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
- return f"{map_func_name}({self.format_args(*args)})"
+ return self.func(map_func_name, *args)
def format_time_lambda(
@@ -400,10 +395,9 @@ def locate_to_strposition(args: t.Sequence) -> exp.Expression:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
- args = self.format_args(
- expression.args.get("substr"), expression.this, expression.args.get("position")
+ return self.func(
+ "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
)
- return f"LOCATE({args})"
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 1730eaf..e9c42e1 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -39,23 +39,6 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
return func
-def if_sql(self: generator.Generator, expression: exp.If) -> str:
- """
- Drill requires backticks around certain SQL reserved words, IF being one of them, This function
- adds the backticks around the keyword IF.
- Args:
- self: The Drill dialect
- expression: The input IF expression
-
- Returns: The expression with IF in backticks.
-
- """
- expressions = self.format_args(
- expression.this, expression.args.get("true"), expression.args.get("false")
- )
- return f"`IF`({expressions})"
-
-
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
@@ -134,7 +117,7 @@ class Drill(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
}
TRANSFORMS = {
@@ -148,7 +131,7 @@ class Drill(Dialect):
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
- exp.If: if_sql,
+ exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 959e5e2..cfec9a4 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -73,11 +73,24 @@ def _datatype_sql(self, expression):
return self.datatype_sql(expression)
+def _regexp_extract_sql(self, expression):
+ bad_args = list(filter(expression.args.get, ("position", "occurrence")))
+ if bad_args:
+ self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")
+ return self.func(
+ "REGEXP_EXTRACT",
+ expression.args.get("this"),
+ expression.args.get("expression"),
+ expression.args.get("group"),
+ )
+
+
class DuckDB(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ,
+ "ATTACH": TokenType.COMMAND,
"CHARACTER VARYING": TokenType.VARCHAR,
}
@@ -117,7 +130,7 @@ class DuckDB(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: approx_count_distinct_sql,
- exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+ exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if isinstance(seq_get(e.expressions, 0), exp.Select)
else rename_func("LIST_VALUE")(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -125,7 +138,9 @@ class DuckDB(Dialect):
exp.ArraySum: rename_func("LIST_SUM"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
- exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
+ exp.DateDiff: lambda self, e: self.func(
+ "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
+ ),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
@@ -137,6 +152,7 @@ class DuckDB(Dialect):
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
+ exp.RegexpExtract: _regexp_extract_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index c558b70..ea1191e 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -43,7 +43,7 @@ def _add_date_sql(self, expression):
else expression.expression
)
modified_increment = exp.Literal.number(modified_increment)
- return f"{func}({self.format_args(expression.this, modified_increment.this)})"
+ return self.func(func, expression.this, modified_increment.this)
def _date_diff_sql(self, expression):
@@ -66,7 +66,7 @@ def _property_sql(self, expression):
def _str_to_unix(self, expression):
- return f"UNIX_TIMESTAMP({self.format_args(expression.this, _time_format(self, expression))})"
+ return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
def _str_to_date(self, expression):
@@ -312,7 +312,9 @@ class Hive(Dialect):
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsToDate: _to_date_sql,
exp.TryCast: no_trycast_sql,
- exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
+ exp.UnixToStr: lambda self, e: self.func(
+ "FROM_UNIXTIME", e.this, _time_format(self, e)
+ ),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
@@ -324,9 +326,9 @@ class Hive(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
}
def with_properties(self, properties):
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index c2c2c8c..235eb77 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
+ rename_func,
strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
@@ -22,9 +23,8 @@ def _show_parser(*args, **kwargs):
def _date_trunc_sql(self, expression):
- unit = expression.name.lower()
-
- expr = self.sql(expression.expression)
+ expr = self.sql(expression, "this")
+ unit = expression.text("unit")
if unit == "day":
return f"DATE({expr})"
@@ -42,7 +42,7 @@ def _date_trunc_sql(self, expression):
concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
- self.unsupported("Unexpected interval unit: {unit}")
+ self.unsupported(f"Unexpected interval unit: {unit}")
return f"DATE({expr})"
return f"STR_TO_DATE({concat}, '{date_format}')"
@@ -443,6 +443,10 @@ class MySQL(Dialect):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
+ exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.DayOfMonth: rename_func("DAYOFMONTH"),
+ exp.DayOfYear: rename_func("DAYOFYEAR"),
+ exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index fde845e..74baa8a 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -1,15 +1,49 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
from sqlglot.helper import csv
from sqlglot.tokens import TokenType
+PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
+ TokenType.COLUMN,
+ TokenType.RETURNING,
+}
+
def _limit_sql(self, expression):
return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
+def _parse_xml_table(self) -> exp.XMLTable:
+ this = self._parse_string()
+
+ passing = None
+ columns = None
+
+ if self._match_text_seq("PASSING"):
+ # The BY VALUE keywords are optional and are provided for semantic clarity
+ self._match_text_seq("BY", "VALUE")
+ passing = self._parse_csv(
+ lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS)
+ )
+
+ by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
+
+ if self._match_text_seq("COLUMNS"):
+ columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
+
+ return self.expression(
+ exp.XMLTable,
+ this=this,
+ passing=passing,
+ columns=columns,
+ by_ref=by_ref,
+ )
+
+
class Oracle(Dialect):
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
@@ -43,6 +77,11 @@ class Oracle(Dialect):
"DECODE": exp.Matches.from_arg_list,
}
+ FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
+ **parser.Parser.FUNCTION_PARSERS,
+ "XMLTABLE": _parse_xml_table,
+ }
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
@@ -74,7 +113,7 @@ class Oracle(Dialect):
exp.Substring: rename_func("SUBSTR"),
}
- def query_modifiers(self, expression, *sqls):
+ def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
@@ -97,19 +136,32 @@ class Oracle(Dialect):
sep="",
)
- def offset_sql(self, expression):
+ def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
- def table_sql(self, expression):
- return super().table_sql(expression, sep=" ")
+ def table_sql(self, expression: exp.Table, sep: str = " ") -> str:
+ return super().table_sql(expression, sep=sep)
+
+ def xmltable_sql(self, expression: exp.XMLTable) -> str:
+ this = self.sql(expression, "this")
+ passing = self.expressions(expression, "passing")
+ passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else ""
+ columns = self.expressions(expression, "columns")
+ columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else ""
+ by_ref = (
+ f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else ""
+ )
+ return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "COLUMNS": TokenType.COLUMN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
+ "NVARCHAR2": TokenType.NVARCHAR,
+ "RETURNING": TokenType.RETURNING,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
- "NVARCHAR2": TokenType.NVARCHAR,
}
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index c709665..7612330 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -58,17 +58,17 @@ def _date_diff_sql(self, expression):
age = f"AGE({end}, {start})"
if unit == "WEEK":
- extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
+ unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
elif unit == "MONTH":
- extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
+ unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
elif unit == "QUARTER":
- extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
+ unit = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
elif unit == "YEAR":
- extract = f"EXTRACT(year FROM {age})"
+ unit = f"EXTRACT(year FROM {age})"
else:
- self.unsupported(f"Unsupported DATEDIFF unit {unit}")
+ unit = age
- return f"CAST({extract} AS BIGINT)"
+ return f"CAST({unit} AS BIGINT)"
def _substring_sql(self, expression):
@@ -206,6 +206,8 @@ class Postgres(Dialect):
}
class Tokenizer(tokens.Tokenizer):
+ QUOTES = ["'", "$$"]
+
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
@@ -236,7 +238,7 @@ class Postgres(Dialect):
"UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
}
- QUOTES = ["'", "$$"]
+
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
@@ -265,6 +267,7 @@ class Postgres(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
+ PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 6c1a474..aef9de3 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -52,7 +52,7 @@ def _initcap_sql(self, expression):
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
- return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})"
+ return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
def _encode_sql(self, expression):
@@ -65,8 +65,7 @@ def _no_sort_array(self, expression):
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
comparator = None
- args = self.format_args(expression.this, comparator)
- return f"ARRAY_SORT({args})"
+ return self.func("ARRAY_SORT", expression.this, comparator)
def _schema_sql(self, expression):
@@ -125,7 +124,7 @@ def _sequence_sql(self, expression):
else:
start = exp.Cast(this=start, to=to)
- return f"SEQUENCE({self.format_args(start, end, step)})"
+ return self.func("SEQUENCE", start, end, step)
def _ensure_utf8(charset):
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 813ee5f..b4268e6 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
+from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -19,6 +20,11 @@ class Redshift(Postgres):
class Parser(Postgres.Parser):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS, # type: ignore
+ "DATEDIFF": lambda args: exp.DateDiff(
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
+ ),
"DECODE": exp.Matches.from_arg_list,
"NVL": exp.Coalesce.from_arg_list,
}
@@ -41,7 +47,6 @@ class Redshift(Postgres):
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
- "ENCODE": TokenType.ENCODE,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
@@ -62,12 +67,15 @@ class Redshift(Postgres):
PROPERTIES_LOCATION = {
**Postgres.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+ exp.LikeProperty: exp.Properties.Location.POST_WITH,
}
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
+ exp.DateDiff: lambda self, e: self.func(
+ "DATEDIFF", e.args.get("unit") or "day", e.expression, e.this
+ ),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 55a6bd3..bb46135 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -178,18 +178,25 @@ class Snowflake(Dialect):
),
}
+ RANGE_PARSERS = {
+ **parser.Parser.RANGE_PARSERS, # type: ignore
+ TokenType.LIKE_ANY: lambda self, this: self._parse_escape(
+ self.expression(exp.LikeAny, this=this, expression=self._parse_bitwise())
+ ),
+ TokenType.ILIKE_ANY: lambda self, this: self._parse_escape(
+ self.expression(exp.ILikeAny, this=this, expression=self._parse_bitwise())
+ ),
+ }
+
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
- SINGLE_TOKENS = {
- **tokens.Tokenizer.SINGLE_TOKENS,
- "$": TokenType.PARAMETER,
- }
-
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"EXCLUDE": TokenType.EXCEPT,
+ "ILIKE ANY": TokenType.ILIKE_ANY,
+ "LIKE ANY": TokenType.LIKE_ANY,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"PUT": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
@@ -201,8 +208,14 @@ class Snowflake(Dialect):
"SAMPLE": TokenType.TABLE_SAMPLE,
}
+ SINGLE_TOKENS = {
+ **tokens.Tokenizer.SINGLE_TOKENS,
+ "$": TokenType.PARAMETER,
+ }
+
class Generator(generator.Generator):
CREATE_TRANSIENT = True
+ PARAMETER_TOKEN = "$"
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@@ -214,14 +227,15 @@ class Snowflake(Dialect):
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
- exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
- exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
+ exp.StrPosition: lambda self, e: self.func(
+ "POSITION", e.args.get("substr"), e.this, e.args.get("position")
+ ),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
- exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
+ exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.UnixToTime: _unix_to_time_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
}
@@ -236,6 +250,12 @@ class Snowflake(Dialect):
"replace": "RENAME",
}
+ def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
+ return self.binary(expression, "ILIKE ANY")
+
+ def likeany_sql(self, expression: exp.LikeAny) -> str:
+ return self.binary(expression, "LIKE ANY")
+
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 03ec211..dd3e0c8 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -86,6 +86,11 @@ class Spark(Hive):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
+ "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+ this=seq_get(args, 1),
+ unit=exp.var(seq_get(args, 0)),
+ ),
+ "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
}
FUNCTION_PARSERS = {
@@ -133,7 +138,7 @@ class Spark(Hive):
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
- exp.DateTrunc: rename_func("TRUNC"),
+ exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -142,7 +147,9 @@ class Spark(Hive):
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
- exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
+ exp.TimestampTrunc: lambda self, e: self.func(
+ "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
+ ),
exp.Trim: trim_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
@@ -157,16 +164,16 @@ class Spark(Hive):
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
- CREATE_FUNCTION_AS = False
+ CREATE_FUNCTION_RETURN_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
exp.DataType.Type.JSON
):
schema = f"'{self.sql(expression, 'to')}'"
- return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
+ return self.func("FROM_JSON", expression.this.this, schema)
if expression.to.is_type(exp.DataType.Type.JSON):
- return f"TO_JSON({self.sql(expression, 'this')})"
+ return self.func("TO_JSON", expression.this)
return super(Spark.Generator, self).cast_sql(expression)
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index a428dd5..86603b5 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -39,7 +39,7 @@ def _date_add_sql(self, expression):
modifier = expression.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
- return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})"
+ return self.func("DATE", expression.this, modifier)
class SQLite(Dialect):
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 123da04..e3eec71 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -1,11 +1,33 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType
class Teradata(Dialect):
+ class Tokenizer(tokens.Tokenizer):
+ # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
+ KEYWORDS = {
+ **tokens.Tokenizer.KEYWORDS,
+ "BYTEINT": TokenType.SMALLINT,
+ "SEL": TokenType.SELECT,
+ "INS": TokenType.INSERT,
+ "MOD": TokenType.MOD,
+ "LT": TokenType.LT,
+ "LE": TokenType.LTE,
+ "GT": TokenType.GT,
+ "GE": TokenType.GTE,
+ "^=": TokenType.NEQ,
+ "NE": TokenType.NEQ,
+ "NOT=": TokenType.NEQ,
+ "ST_GEOMETRY": TokenType.GEOMETRY,
+ }
+
+        # Teradata does not support % for modulus
+ SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
+ SINGLE_TOKENS.pop("%")
+
class Parser(parser.Parser):
CHARSET_TRANSLATORS = {
"GRAPHIC_TO_KANJISJIS",
@@ -42,6 +64,14 @@ class Teradata(Dialect):
"UNICODE_TO_UNICODE_NFKD",
}
+ FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS}
+ FUNC_TOKENS.remove(TokenType.REPLACE)
+
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ TokenType.REPLACE: lambda self: self._parse_create(),
+ }
+
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
@@ -76,6 +106,11 @@ class Teradata(Dialect):
)
class Generator(generator.Generator):
+ TYPE_MAPPING = {
+ **generator.Generator.TYPE_MAPPING, # type: ignore
+ exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
+ }
+
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
@@ -93,3 +128,11 @@ class Teradata(Dialect):
where_sql = self.sql(expression, "where")
sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
+
+ def mod_sql(self, expression: exp.Mod) -> str:
+ return self.binary(expression, "MOD")
+
+ def datatype_sql(self, expression: exp.DataType) -> str:
+ type_sql = super().datatype_sql(expression)
+ prefix_sql = expression.args.get("prefix")
+ return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 05ba53a..b9f932b 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -92,7 +92,7 @@ def _parse_eomonth(args):
def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
- return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
+ return self.func(func, e.text("unit"), e.expression, e.this)
def _format_sql(self, e):
@@ -101,7 +101,7 @@ def _format_sql(self, e):
if isinstance(e, exp.NumberToStr)
else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
)
- return f"FORMAT({self.format_args(e.this, fmt)})"
+ return self.func("FORMAT", e.this, fmt)
def _string_agg_sql(self, e):
@@ -408,7 +408,7 @@ class TSQL(Dialect):
):
return this
- expressions = self._parse_csv(self._parse_udf_kwarg)
+ expressions = self._parse_csv(self._parse_function_parameter)
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
class Generator(generator.Generator):
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index c3d2701..a676e7d 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -62,10 +62,8 @@ def execute(
if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
raise ExecuteError("Tables must support the same table args as schema")
- expression = maybe_parse(sql, dialect=read)
-
now = time.time()
- expression = optimize(expression, schema, leave_tables_isolated=True)
+ expression = optimize(sql, schema, leave_tables_isolated=True, dialect=read)
logger.debug("Optimization finished: %f", time.time() - now)
logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index de570b0..d417328 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -376,7 +376,7 @@ def _rename(self, e):
this = self.sql(e, "this")
this = f"{this}, " if this else ""
return f"{e.key.upper()}({this}{self.expressions(e)})"
- return f"{e.key.upper()}({self.format_args(*e.args.values())})"
+ return self.func(e.key, *e.args.values())
except Exception as ex:
raise Exception(f"Could not rename {repr(e)}") from ex
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 6800cd5..42652a6 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -128,7 +128,7 @@ class Expression(metaclass=_Expression):
"""
return self.args.get("expressions") or []
- def text(self, key):
+ def text(self, key) -> str:
"""
Returns a textual representation of the argument corresponding to "key". This can only be used
for args that are strings or leaf Expression instances, such as identifiers and literals.
@@ -143,21 +143,21 @@ class Expression(metaclass=_Expression):
return ""
@property
- def is_string(self):
+ def is_string(self) -> bool:
"""
Checks whether a Literal expression is a string.
"""
return isinstance(self, Literal) and self.args["is_string"]
@property
- def is_number(self):
+ def is_number(self) -> bool:
"""
Checks whether a Literal expression is a number.
"""
return isinstance(self, Literal) and not self.args["is_string"]
@property
- def is_int(self):
+ def is_int(self) -> bool:
"""
Checks whether a Literal expression is an integer.
"""
@@ -170,7 +170,12 @@ class Expression(metaclass=_Expression):
return False
@property
- def alias(self):
+ def is_star(self) -> bool:
+ """Checks whether an expression is a star."""
+ return isinstance(self, Star) or (isinstance(self, Column) and isinstance(self.this, Star))
+
+ @property
+ def alias(self) -> str:
"""
Returns the alias of the expression, or an empty string if it's not aliased.
"""
@@ -825,10 +830,6 @@ class UserDefinedFunction(Expression):
arg_types = {"this": True, "expressions": False, "wrapped": False}
-class UserDefinedFunctionKwarg(Expression):
- arg_types = {"this": True, "kind": True, "default": False}
-
-
class CharacterSet(Expression):
arg_types = {"this": True, "default": False}
@@ -870,14 +871,22 @@ class ByteString(Condition):
class Column(Condition):
- arg_types = {"this": True, "table": False}
+ arg_types = {"this": True, "table": False, "db": False, "catalog": False}
@property
- def table(self):
+ def table(self) -> str:
return self.text("table")
@property
- def output_name(self):
+ def db(self) -> str:
+ return self.text("db")
+
+ @property
+ def catalog(self) -> str:
+ return self.text("catalog")
+
+ @property
+ def output_name(self) -> str:
return self.name
@@ -917,6 +926,14 @@ class AutoIncrementColumnConstraint(ColumnConstraintKind):
pass
+class CaseSpecificColumnConstraint(ColumnConstraintKind):
+ arg_types = {"not_": True}
+
+
+class CharacterSetColumnConstraint(ColumnConstraintKind):
+ arg_types = {"this": True}
+
+
class CheckColumnConstraint(ColumnConstraintKind):
pass
@@ -929,6 +946,10 @@ class CommentColumnConstraint(ColumnConstraintKind):
pass
+class DateFormatColumnConstraint(ColumnConstraintKind):
+ arg_types = {"this": True}
+
+
class DefaultColumnConstraint(ColumnConstraintKind):
pass
@@ -939,7 +960,14 @@ class EncodeColumnConstraint(ColumnConstraintKind):
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
- arg_types = {"this": False, "start": False, "increment": False}
+ arg_types = {
+ "this": False,
+ "start": False,
+ "increment": False,
+ "minvalue": False,
+ "maxvalue": False,
+ "cycle": False,
+ }
class NotNullColumnConstraint(ColumnConstraintKind):
@@ -950,7 +978,19 @@ class PrimaryKeyColumnConstraint(ColumnConstraintKind):
arg_types = {"desc": False}
+class TitleColumnConstraint(ColumnConstraintKind):
+ pass
+
+
class UniqueColumnConstraint(ColumnConstraintKind):
+ arg_types: t.Dict[str, t.Any] = {}
+
+
+class UppercaseColumnConstraint(ColumnConstraintKind):
+ arg_types: t.Dict[str, t.Any] = {}
+
+
+class PathColumnConstraint(ColumnConstraintKind):
pass
@@ -1063,6 +1103,7 @@ class Insert(Expression):
"overwrite": False,
"exists": False,
"partition": False,
+ "alternative": False,
}
@@ -1438,6 +1479,16 @@ class IsolatedLoadingProperty(Property):
}
+class LockingProperty(Property):
+ arg_types = {
+ "this": False,
+ "kind": True,
+ "for_or_in": True,
+ "lock_type": True,
+ "override": False,
+ }
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -1463,12 +1514,26 @@ class Properties(Expression):
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
+ # CREATE property locations
+ # Form: schema specified
+ # create [POST_CREATE]
+ # table a [POST_NAME]
+ # (b int) [POST_SCHEMA]
+ # with ([POST_WITH])
+ # index (b) [POST_INDEX]
+ #
+ # Form: alias selection
+ # create [POST_CREATE]
+ # table a [POST_NAME]
+ # as [POST_ALIAS] (select * from b)
+ # index (c) [POST_INDEX]
class Location(AutoName):
POST_CREATE = auto()
- PRE_SCHEMA = auto()
+ POST_NAME = auto()
+ POST_SCHEMA = auto()
+ POST_WITH = auto()
+ POST_ALIAS = auto()
POST_INDEX = auto()
- POST_SCHEMA_ROOT = auto()
- POST_SCHEMA_WITH = auto()
UNSUPPORTED = auto()
@classmethod
@@ -1633,6 +1698,14 @@ class Table(Expression):
"system_time": False,
}
+ @property
+ def db(self) -> str:
+ return self.text("db")
+
+ @property
+ def catalog(self) -> str:
+ return self.text("catalog")
+
# See the TSQL "Querying data in a system-versioned temporal table" page
class SystemTime(Expression):
@@ -1678,6 +1751,40 @@ class Union(Subqueryable):
.limit(expression, dialect=dialect, copy=False, **opts)
)
+ def select(
+ self,
+ *expressions: str | Expression,
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Union:
+ """Append to or set the SELECT of the union recursively.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> parse_one("select a from x union select a from y union select a from z").select("b").sql()
+ 'SELECT a, b FROM x UNION SELECT a, b FROM y UNION SELECT a, b FROM z'
+
+ Args:
+ *expressions: the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ append: if `True`, add to any existing expressions.
+ Otherwise, this resets the expressions.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+ Union: the modified expression.
+ """
+ this = self.copy() if copy else self
+ this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts)
+ this.expression.unnest().select(
+ *expressions, append=append, dialect=dialect, copy=False, **opts
+ )
+ return this
+
@property
def named_selects(self):
return self.this.unnest().named_selects
@@ -1985,7 +2092,14 @@ class Select(Subqueryable):
**opts,
)
- def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ def select(
+ self,
+ *expressions: str | Expression,
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Select:
"""
Append to or set the SELECT expressions.
@@ -1994,13 +2108,13 @@ class Select(Subqueryable):
'SELECT x, y'
Args:
- *expressions (str | Expression): the SQL code strings to parse.
+ *expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
- append (bool): if `True`, add to any existing expressions.
+ append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
- dialect (str): the dialect used to parse the input expressions.
- copy (bool): if `False`, modify this expression instance in-place.
- opts (kwargs): other options to use to parse the input expressions.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
Returns:
Select: the modified expression.
@@ -2399,7 +2513,7 @@ class Star(Expression):
class Parameter(Expression):
- pass
+ arg_types = {"this": True, "wrapped": False}
class SessionParameter(Expression):
@@ -2428,6 +2542,7 @@ class DataType(Expression):
"expressions": False,
"nested": False,
"values": False,
+ "prefix": False,
}
class Type(AutoName):
@@ -2693,6 +2808,10 @@ class ILike(Binary, Predicate):
pass
+class ILikeAny(Binary, Predicate):
+ pass
+
+
class IntDiv(Binary):
pass
@@ -2709,6 +2828,10 @@ class Like(Binary, Predicate):
pass
+class LikeAny(Binary, Predicate):
+ pass
+
+
class LT(Binary, Predicate):
pass
@@ -3042,7 +3165,7 @@ class DateDiff(Func, TimeUnit):
class DateTrunc(Func):
- arg_types = {"this": True, "expression": True, "zone": False}
+ arg_types = {"unit": True, "this": True, "zone": False}
class DatetimeAdd(Func, TimeUnit):
@@ -3330,6 +3453,16 @@ class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": False}
+class RegexpExtract(Func):
+ arg_types = {
+ "this": True,
+ "expression": True,
+ "position": False,
+ "occurrence": False,
+ "group": False,
+ }
+
+
class RegexpLike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
@@ -3519,6 +3652,10 @@ class Week(Func):
arg_types = {"this": True, "mode": False}
+class XMLTable(Func):
+ arg_types = {"this": True, "passing": False, "columns": False, "by_ref": False}
+
+
class Year(Func):
pass
@@ -3566,6 +3703,7 @@ def maybe_parse(
into: t.Optional[IntoType] = None,
dialect: DialectType = None,
prefix: t.Optional[str] = None,
+ copy: bool = False,
**opts,
) -> Expression:
"""Gracefully handle a possible string or expression.
@@ -3583,6 +3721,7 @@ def maybe_parse(
input expression is a SQL string).
prefix: a string to prefix the sql with before it gets parsed
(automatically includes a space)
+ copy: whether or not to copy the expression.
**opts: other options to use to parse the input expressions (again, in the case
that an input expression is a SQL string).
@@ -3590,6 +3729,8 @@ def maybe_parse(
Expression: the parsed or given expression.
"""
if isinstance(sql_or_expression, Expression):
+ if copy:
+ return sql_or_expression.copy()
return sql_or_expression
import sqlglot
@@ -3818,7 +3959,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
return Except(this=left, expression=right, distinct=distinct)
-def select(*expressions, dialect=None, **opts) -> Select:
+def select(*expressions: str | Expression, dialect: DialectType = None, **opts) -> Select:
"""
Initializes a syntax tree from one or multiple SELECT expressions.
@@ -3827,9 +3968,9 @@ def select(*expressions, dialect=None, **opts) -> Select:
'SELECT col1, col2 FROM tbl'
Args:
- *expressions (str | Expression): the SQL code string to parse as the expressions of a
+ *expressions: the SQL code string to parse as the expressions of a
SELECT statement. If an Expression instance is passed, this is used as-is.
- dialect (str): the dialect used to parse the input expressions (in the case that an
+ dialect: the dialect used to parse the input expressions (in the case that an
input expression is a SQL string).
**opts: other options to use to parse the input expressions (again, in the case
that an input expression is a SQL string).
@@ -4219,19 +4360,27 @@ def subquery(expression, alias=None, dialect=None, **opts):
return Select().from_(expression, dialect=dialect, **opts)
-def column(col, table=None, quoted=None) -> Column:
+def column(
+ col: str | Identifier,
+ table: t.Optional[str | Identifier] = None,
+ schema: t.Optional[str | Identifier] = None,
+ quoted: t.Optional[bool] = None,
+) -> Column:
"""
Build a Column.
Args:
- col (str | Expression): column name
- table (str | Expression): table name
+ col: column name
+ table: table name
+ schema: schema name
+ quoted: whether or not to force quote each part
Returns:
Column: column instance
"""
return Column(
this=to_identifier(col, quoted=quoted),
table=to_identifier(table, quoted=quoted),
+ db=to_identifier(schema, quoted=quoted),
)
@@ -4314,6 +4463,30 @@ def values(
)
+def var(name: t.Optional[str | Expression]) -> Var:
+ """Build a SQL variable.
+
+ Example:
+ >>> repr(var('x'))
+ '(VAR this: x)'
+
+ >>> repr(var(column('x', table='y')))
+ '(VAR this: x)'
+
+ Args:
+ name: The name of the var or an expression whose name will become the var.
+
+ Returns:
+ The new variable node.
+ """
+ if not name:
+ raise ValueError("Cannot convert empty name into var.")
+
+ if isinstance(name, Expression):
+ name = name.name
+ return Var(this=name)
+
+
def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
"""Build ALTER TABLE... RENAME... expression
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 0d72fe3..1479e28 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1,19 +1,16 @@
from __future__ import annotations
import logging
-import re
import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
-from sqlglot.helper import apply_index_offset, csv
+from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
-BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)")
-
class Generator:
"""
@@ -59,10 +56,14 @@ class Generator:
"""
TRANSFORMS = {
- exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
- exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
- exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
- exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
+ exp.DateAdd: lambda self, e: self.func(
+ "DATE_ADD", e.this, e.expression, e.args.get("unit")
+ ),
+ exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
+ exp.TsOrDsAdd: lambda self, e: self.func(
+ "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit")
+ ),
+ exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
@@ -72,6 +73,17 @@ class Generator:
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
+ exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
+ exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
+ exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
+ exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
+ exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
+ exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
+ exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
+ exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
+ exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
+ exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
+ exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
}
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@@ -89,8 +101,8 @@ class Generator:
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
- # Whether or not create function uses an AS before the def.
- CREATE_FUNCTION_AS = True
+ # Whether or not create function uses an AS before the RETURN
+ CREATE_FUNCTION_RETURN_AS = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
@@ -110,42 +122,46 @@ class Generator:
STRUCT_DELIMITER = ("<", ">")
+ PARAMETER_TOKEN = "@"
+
PROPERTIES_LOCATION = {
- exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.AfterJournalProperty: exp.Properties.Location.POST_NAME,
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
- exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA,
- exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA,
- exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
+ exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
+ exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.Cluster: exp.Properties.Location.POST_SCHEMA,
+ exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
- exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA,
- exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
- exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA,
- exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA,
- exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA,
- exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.LogProperty: exp.Properties.Location.PRE_SCHEMA,
- exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA,
- exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH,
- exp.Property: exp.Properties.Location.POST_SCHEMA_WITH,
- exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.FallbackProperty: exp.Properties.Location.POST_NAME,
+ exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
+ exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
+ exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
+ exp.JournalProperty: exp.Properties.Location.POST_NAME,
+ exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.LikeProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.LockingProperty: exp.Properties.Location.POST_ALIAS,
+ exp.LogProperty: exp.Properties.Location.POST_NAME,
+ exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
+ exp.Property: exp.Properties.Location.POST_WITH,
+ exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
+ exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
- exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
- exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
- exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
+ exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
@@ -173,7 +189,6 @@ class Generator:
"null_ordering",
"max_unsupported",
"_indent",
- "_replace_backslash",
"_escaped_quote_end",
"_escaped_identifier_end",
"_leading_comma",
@@ -230,7 +245,6 @@ class Generator:
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
self._indent = indent
- self._replace_backslash = self.string_escape == "\\"
self._escaped_quote_end = self.string_escape + self.quote_end
self._escaped_identifier_end = self.identifier_escape + self.identifier_end
self._leading_comma = leading_comma
@@ -403,12 +417,13 @@ class Generator:
def column_sql(self, expression: exp.Column) -> str:
return ".".join(
- part
- for part in [
- self.sql(expression, "db"),
- self.sql(expression, "table"),
- self.sql(expression, "this"),
- ]
+ self.sql(part)
+ for part in (
+ expression.args.get("catalog"),
+ expression.args.get("db"),
+ expression.args.get("table"),
+ expression.args.get("this"),
+ )
if part
)
@@ -430,26 +445,6 @@ class Generator:
def autoincrementcolumnconstraint_sql(self, _) -> str:
return self.token_sql(TokenType.AUTO_INCREMENT)
- def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str:
- this = self.sql(expression, "this")
- return f"CHECK ({this})"
-
- def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str:
- comment = self.sql(expression, "this")
- return f"COMMENT {comment}"
-
- def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str:
- collate = self.sql(expression, "this")
- return f"COLLATE {collate}"
-
- def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str:
- encode = self.sql(expression, "this")
- return f"ENCODE {encode}"
-
- def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str:
- default = self.sql(expression, "this")
- return f"DEFAULT {default}"
-
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
@@ -459,10 +454,19 @@ class Generator:
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
increment = expression.args.get("increment")
- increment = f"INCREMENT BY {increment}" if increment else ""
+ increment = f" INCREMENT BY {increment}" if increment else ""
+ minvalue = expression.args.get("minvalue")
+ minvalue = f" MINVALUE {minvalue}" if minvalue else ""
+ maxvalue = expression.args.get("maxvalue")
+ maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else ""
+ cycle = expression.args.get("cycle")
+ cycle_sql = ""
+ if cycle is not None:
+ cycle_sql = f"{' NO' if not cycle else ''} CYCLE"
+ cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql
sequence_opts = ""
- if start or increment:
- sequence_opts = f"{start} {increment}"
+ if start or increment or cycle_sql:
+ sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}"
sequence_opts = f" ({sequence_opts.strip()})"
return f"GENERATED{this}AS IDENTITY{sequence_opts}"
@@ -483,22 +487,22 @@ class Generator:
properties = expression.args.get("properties")
properties_exp = expression.copy()
properties_locs = self.locate_properties(properties) if properties else {}
- if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get(
- exp.Properties.Location.POST_SCHEMA_WITH
+ if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get(
+ exp.Properties.Location.POST_WITH
):
properties_exp.set(
"properties",
exp.Properties(
expressions=[
- *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT],
- *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH],
+ *properties_locs[exp.Properties.Location.POST_SCHEMA],
+ *properties_locs[exp.Properties.Location.POST_WITH],
]
),
)
- if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA):
+ if kind == "TABLE" and properties_locs.get(exp.Properties.Location.POST_NAME):
this_name = self.sql(expression.this, "this")
this_properties = self.properties(
- exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]),
+ exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_NAME]),
wrapped=False,
)
this_schema = f"({self.expressions(expression.this)})"
@@ -512,8 +516,17 @@ class Generator:
if expression_sql:
expression_sql = f"{begin}{self.sep()}{expression_sql}"
- if self.CREATE_FUNCTION_AS or kind != "FUNCTION":
- expression_sql = f" AS{expression_sql}"
+ if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return):
+ if properties_locs.get(exp.Properties.Location.POST_ALIAS):
+ postalias_props_sql = self.properties(
+ exp.Properties(
+ expressions=properties_locs[exp.Properties.Location.POST_ALIAS]
+ ),
+ wrapped=False,
+ )
+ expression_sql = f" AS {postalias_props_sql}{expression_sql}"
+ else:
+ expression_sql = f" AS{expression_sql}"
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = (
@@ -736,9 +749,9 @@ class Generator:
for p in expression.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
- if p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
+ if p_loc == exp.Properties.Location.POST_WITH:
with_properties.append(p)
- elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
+ elif p_loc == exp.Properties.Location.POST_SCHEMA:
root_properties.append(p)
return self.root_properties(
@@ -776,16 +789,18 @@ class Generator:
for p in properties.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
- if p_loc == exp.Properties.Location.PRE_SCHEMA:
- properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p)
+ if p_loc == exp.Properties.Location.POST_NAME:
+ properties_locs[exp.Properties.Location.POST_NAME].append(p)
elif p_loc == exp.Properties.Location.POST_INDEX:
properties_locs[exp.Properties.Location.POST_INDEX].append(p)
- elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
- properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p)
- elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
- properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p)
+ elif p_loc == exp.Properties.Location.POST_SCHEMA:
+ properties_locs[exp.Properties.Location.POST_SCHEMA].append(p)
+ elif p_loc == exp.Properties.Location.POST_WITH:
+ properties_locs[exp.Properties.Location.POST_WITH].append(p)
elif p_loc == exp.Properties.Location.POST_CREATE:
properties_locs[exp.Properties.Location.POST_CREATE].append(p)
+ elif p_loc == exp.Properties.Location.POST_ALIAS:
+ properties_locs[exp.Properties.Location.POST_ALIAS].append(p)
elif p_loc == exp.Properties.Location.UNSUPPORTED:
self.unsupported(f"Unsupported property {p.key}")
@@ -899,6 +914,14 @@ class Generator:
for_ = " FOR NONE"
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
+ def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
+ kind = expression.args.get("kind")
+ this = f" {self.sql(expression, 'this')}" if expression.this else ""
+ for_or_in = expression.args.get("for_or_in")
+ lock_type = expression.args.get("lock_type")
+ override = " OVERRIDE" if expression.args.get("override") else ""
+ return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}"
+
def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")
@@ -907,14 +930,17 @@ class Generator:
else:
this = "OVERWRITE TABLE " if overwrite else "INTO "
+ alternative = expression.args.get("alternative")
+ alternative = f" OR {alternative} " if alternative else " "
this = f"{this}{self.sql(expression, 'this')}"
+
exists = " IF EXISTS " if expression.args.get("exists") else " "
partition_sql = (
self.sql(expression, "partition") if expression.args.get("partition") else ""
)
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
- sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
+ sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
@@ -1046,21 +1072,26 @@ class Generator:
f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
)
- cube = expression.args.get("cube")
- if cube is True:
- cube = self.seg("WITH CUBE")
+ cube = expression.args.get("cube", [])
+ if seq_get(cube, 0) is True:
+ return f"{group_by}{self.seg('WITH CUBE')}"
else:
- cube = self.expressions(expression, key="cube", indent=False)
- cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
+ cube_sql = self.expressions(expression, key="cube", indent=False)
+ cube_sql = f"{self.seg('CUBE')} {self.wrap(cube_sql)}" if cube_sql else ""
- rollup = expression.args.get("rollup")
- if rollup is True:
- rollup = self.seg("WITH ROLLUP")
+ rollup = expression.args.get("rollup", [])
+ if seq_get(rollup, 0) is True:
+ return f"{group_by}{self.seg('WITH ROLLUP')}"
else:
- rollup = self.expressions(expression, key="rollup", indent=False)
- rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
+ rollup_sql = self.expressions(expression, key="rollup", indent=False)
+ rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else ""
+
+ groupings = csv(grouping_sets, cube_sql, rollup_sql, sep=",")
- return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}"
+ if expression.args.get("expressions") and groupings:
+ group_by = f"{group_by},"
+
+ return f"{group_by}{groupings}"
def having_sql(self, expression: exp.Having) -> str:
this = self.indent(self.sql(expression, "this"))
@@ -1139,8 +1170,6 @@ class Generator:
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
- if self._replace_backslash:
- text = BACKSLASH_RE.sub(r"\\\\", text)
text = text.replace(self.quote_end, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
@@ -1291,7 +1320,9 @@ class Generator:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def parameter_sql(self, expression: exp.Parameter) -> str:
- return f"@{self.sql(expression, 'this')}"
+ this = self.sql(expression, "this")
+ this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}"
+ return f"{self.PARAMETER_TOKEN}{this}"
def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
this = self.sql(expression, "this")
@@ -1405,7 +1436,10 @@ class Generator:
return f"ALL {self.wrap(expression)}"
def any_sql(self, expression: exp.Any) -> str:
- return f"ANY {self.wrap(expression)}"
+ this = self.sql(expression, "this")
+ if isinstance(expression.this, exp.Subqueryable):
+ this = self.wrap(this)
+ return f"ANY {this}"
def exists_sql(self, expression: exp.Exists) -> str:
return f"EXISTS{self.wrap(expression)}"
@@ -1444,11 +1478,11 @@ class Generator:
trim_type = self.sql(expression, "position")
if trim_type == "LEADING":
- return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})"
+ return self.func("LTRIM", expression.this)
elif trim_type == "TRAILING":
- return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})"
+ return self.func("RTRIM", expression.this)
else:
- return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})"
+ return self.func("TRIM", expression.this, expression.expression)
def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
@@ -1530,8 +1564,7 @@ class Generator:
return f"REFERENCES {this}{expressions}{options}"
def anonymous_sql(self, expression: exp.Anonymous) -> str:
- args = self.format_args(*expression.expressions)
- return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
+ return self.func(expression.name, *expression.expressions)
def paren_sql(self, expression: exp.Paren) -> str:
if isinstance(expression.unnest(), exp.Select):
@@ -1792,7 +1825,10 @@ class Generator:
else:
args.append(arg_value)
- return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
+ return self.func(expression.sql_name(), *args)
+
+ def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str:
+ return f"{self.normalize_func(name)}({self.format_args(*args)})"
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
@@ -1848,6 +1884,7 @@ class Generator:
return self.indent(result_sql, skip_first=False) if indent else result_sql
def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str:
+ flat = flat or isinstance(expression.parent, exp.Properties)
expressions_sql = self.expressions(expression, flat=flat)
if flat:
return f"{op} {expressions_sql}"
@@ -1880,11 +1917,6 @@ class Generator:
)
return f"{this}{expressions}"
- def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str:
- this = self.sql(expression, "this")
- kind = self.sql(expression, "kind")
- return f"{this} {kind}"
-
def joinhint_sql(self, expression: exp.JoinHint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 66f97a9..be65ab9 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -280,6 +280,9 @@ class TypeAnnotator:
}
# First annotate the current scope's column references
for col in scope.columns:
+ if not col.table:
+ continue
+
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index c6bea5a..6f9db82 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -81,9 +81,7 @@ def eliminate_subqueries(expression):
new_ctes.append(cte_scope.expression.parent)
# Now append the rest
- for scope in itertools.chain(
- root.union_scopes, root.subquery_scopes, root.derived_table_scopes
- ):
+ for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
for child_scope in scope.traverse():
new_cte = _eliminate(child_scope, existing_ctes, taken)
if new_cte:
@@ -99,7 +97,7 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_union:
return _eliminate_union(scope, existing_ctes, taken)
- if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
+ if scope.is_derived_table:
return _eliminate_derived_table(scope, existing_ctes, taken)
if scope.is_cte:
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 96fd56b..d9d04be 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -1,4 +1,10 @@
+from __future__ import annotations
+
+import typing as t
+
import sqlglot
+from sqlglot import Schema, exp
+from sqlglot.dialects.dialect import DialectType
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
@@ -24,8 +30,8 @@ RULES = (
isolate_table_selects,
qualify_columns,
expand_laterals,
- validate_qualify_columns,
pushdown_projections,
+ validate_qualify_columns,
normalize,
unnest_subqueries,
expand_multi_table_selects,
@@ -40,22 +46,31 @@ RULES = (
)
-def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
+def optimize(
+ expression: str | exp.Expression,
+ schema: t.Optional[dict | Schema] = None,
+ db: t.Optional[str] = None,
+ catalog: t.Optional[str] = None,
+ dialect: DialectType = None,
+ rules: t.Sequence[t.Callable] = RULES,
+ **kwargs,
+):
"""
Rewrite a sqlglot AST into an optimized form.
Args:
- expression (sqlglot.Expression): expression to optimize
- schema (dict|sqlglot.optimizer.Schema): database schema.
+ expression: expression to optimize
+ schema: database schema.
This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
If no schema is provided then the default schema defined at `sqlgot.schema` will be used
- db (str): specify the default database, as might be set by a `USE DATABASE db` statement
- catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
- rules (sequence): sequence of optimizer rules to use.
+ db: specify the default database, as might be set by a `USE DATABASE db` statement
+ catalog: specify the default catalog, as might be set by a `USE CATALOG c` statement
+ dialect: The dialect to parse the sql string.
+ rules: sequence of optimizer rules to use.
Many of the rules require tables and columns to be qualified.
Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
what you're doing!
@@ -65,7 +80,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
"""
schema = ensure_schema(schema or sqlglot.schema)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
- expression = expression.copy()
+ expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 54c5021..3f360f9 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -1,7 +1,10 @@
from collections import defaultdict
from sqlglot import alias, exp
+from sqlglot.helper import flatten
+from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import ensure_schema
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
@@ -10,7 +13,7 @@ SELECT_ALL = object()
DEFAULT_SELECTION = lambda: alias("1", "_")
-def pushdown_projections(expression):
+def pushdown_projections(expression, schema=None):
"""
Rewrite sqlglot AST to remove unused columns projections.
@@ -27,9 +30,9 @@ def pushdown_projections(expression):
sqlglot.Expression: optimized expression
"""
# Map of Scope to all columns being selected by outer queries.
+ schema = ensure_schema(schema)
referenced_columns = defaultdict(set)
- left_union = None
- right_union = None
+
# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This should ensure that the set of selected
# columns for a particular scope are completely build by the time we get to it.
@@ -41,16 +44,20 @@ def pushdown_projections(expression):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
- left_union, right_union = scope.union_scopes
- referenced_columns[left_union] = parent_selections
- referenced_columns[right_union] = parent_selections
+ left, right = scope.union_scopes
+ referenced_columns[left] = parent_selections
+
+ if any(select.is_star for select in right.selects):
+ referenced_columns[right] = parent_selections
+ elif not any(select.is_star for select in left.selects):
+ referenced_columns[right] = [
+ right.selects[i].alias_or_name
+ for i, select in enumerate(left.selects)
+ if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
+ ]
- if isinstance(scope.expression, exp.Select) and scope != right_union:
- removed_indexes = _remove_unused_selections(scope, parent_selections)
- # The left union is used for column names to select and if we remove columns from the left
- # we need to also remove those same columns in the right that were at the same position
- if scope is left_union:
- _remove_indexed_selections(right_union, removed_indexes)
+ if isinstance(scope.expression, exp.Select):
+ _remove_unused_selections(scope, parent_selections, schema)
# Group columns by source name
selects = defaultdict(set)
@@ -68,8 +75,7 @@ def pushdown_projections(expression):
return expression
-def _remove_unused_selections(scope, parent_selections):
- removed_indexes = []
+def _remove_unused_selections(scope, parent_selections, schema):
order = scope.expression.args.get("order")
if order:
@@ -78,33 +84,33 @@ def _remove_unused_selections(scope, parent_selections):
else:
order_refs = set()
- new_selections = []
+ new_selections = defaultdict(list)
removed = False
- for i, selection in enumerate(scope.selects):
- if (
- SELECT_ALL in parent_selections
- or selection.alias_or_name in parent_selections
- or selection.alias_or_name in order_refs
- ):
- new_selections.append(selection)
+ star = False
+ for selection in scope.selects:
+ name = selection.alias_or_name
+
+ if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
+ new_selections[name].append(selection)
else:
- removed_indexes.append(i)
+ if selection.is_star:
+ star = True
removed = True
+ if star:
+ resolver = Resolver(scope, schema)
+
+ for name in sorted(parent_selections):
+ if name not in new_selections:
+ new_selections[name].append(
+ alias(exp.column(name, table=resolver.get_table(name)), name)
+ )
+
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(DEFAULT_SELECTION())
+ new_selections[""].append(DEFAULT_SELECTION())
+
+ scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
- scope.expression.set("expressions", new_selections)
if removed:
scope.clear_cache()
- return removed_indexes
-
-
-def _remove_indexed_selections(scope, indexes_to_remove):
- new_selections = [
- selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
- ]
- if not new_selections:
- new_selections.append(DEFAULT_SELECTION())
- scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index ab13d01..a7bd9b5 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -27,17 +27,16 @@ def qualify_columns(expression, schema):
schema = ensure_schema(schema)
for scope in traverse_scope(expression):
- resolver = _Resolver(scope, schema)
+ resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
_expand_using(scope, resolver)
- _expand_group_by(scope, resolver)
_qualify_columns(scope, resolver)
- _expand_order_by(scope)
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
-
+ _expand_group_by(scope, resolver)
+ _expand_order_by(scope)
return expression
@@ -48,7 +47,8 @@ def validate_qualify_columns(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
if scope.external_columns and not scope.is_correlated_subquery:
- raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
+ column = scope.external_columns[0]
+ raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
if unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
@@ -62,8 +62,6 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
- if isinstance(derived_table.unnest(), exp.UDTF):
- continue
table_alias = derived_table.args.get("alias")
if table_alias:
table_alias.args.pop("columns", None)
@@ -206,7 +204,7 @@ def _qualify_columns(scope, resolver):
if column_table and column_table in scope.sources:
source_columns = resolver.get_source_columns(column_table)
- if source_columns and column_name not in source_columns:
+ if source_columns and column_name not in source_columns and "*" not in source_columns:
raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
@@ -256,7 +254,7 @@ def _expand_stars(scope, resolver):
tables = list(scope.selected_sources)
_add_except_columns(expression, tables, except_columns)
_add_replace_columns(expression, tables, replace_columns)
- elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
+ elif expression.is_star:
tables = [expression.table]
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
@@ -268,17 +266,16 @@ def _expand_stars(scope, resolver):
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
- if not columns:
- raise OptimizeError(
- f"Table has no schema/columns. Cannot expand star for table: {table}."
- )
- table_id = id(table)
- for name in columns:
- if name not in except_columns.get(table_id, set()):
- alias_ = replace_columns.get(table_id, {}).get(name, name)
- column = exp.column(name, table)
- new_selections.append(alias(column, alias_) if alias_ != name else column)
+ if columns and "*" not in columns:
+ table_id = id(table)
+ for name in columns:
+ if name not in except_columns.get(table_id, set()):
+ alias_ = replace_columns.get(table_id, {}).get(name, name)
+ column = exp.column(name, table)
+ new_selections.append(alias(column, alias_) if alias_ != name else column)
+ else:
+ return
scope.expression.set("expressions", new_selections)
@@ -316,7 +313,7 @@ def _qualify_outputs(scope):
if isinstance(selection, exp.Subquery):
if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
- elif not isinstance(selection, exp.Alias):
+ elif not isinstance(selection, exp.Alias) and not selection.is_star:
alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
alias_.set("this", selection)
selection = alias_
@@ -329,7 +326,7 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)
-class _Resolver:
+class Resolver:
"""
Helper for resolving columns.
@@ -361,7 +358,9 @@ class _Resolver:
if not table:
sources_without_schema = tuple(
- source for source, columns in self._get_all_source_columns().items() if not columns
+ source
+ for source, columns in self._get_all_source_columns().items()
+ if not columns or "*" in columns
)
if len(sources_without_schema) == 1:
return sources_without_schema[0]
@@ -397,7 +396,8 @@ class _Resolver:
def _get_all_source_columns(self):
if self._source_columns is None:
self._source_columns = {
- k: self.get_source_columns(k) for k in self.scope.selected_sources
+ k: self.get_source_columns(k)
+ for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
}
return self._source_columns
@@ -436,7 +436,7 @@ class _Resolver:
Find the unique columns in a list of columns.
Example:
- >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
+ >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
['a', 'c']
This is necessary because duplicate column names are ambiguous.
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 65593bd..6e50182 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -28,7 +28,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
next_name = lambda: f"_q_{next(sequence)}"
for scope in traverse_scope(expression):
- for derived_table in scope.ctes + scope.derived_tables:
+ for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 8565c64..335ff3e 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -26,6 +26,10 @@ class Scope:
SELECT * FROM x {"x": Table(this="x")}
SELECT * FROM x AS y {"y": Table(this="x")}
SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
+ lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
+ For example:
+ SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
+ The LATERAL VIEW EXPLODE gets x as a source.
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list of it's alias of this scope, this is that list of columns.
For example:
@@ -34,8 +38,10 @@ class Scope:
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to it's parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes = (list[Scope]) List of all child scopes for CTEs
- derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables
+ cte_scopes (list[Scope]): List of all child scopes for CTEs
+ derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
+ udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
+ table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
a list of the left and right child scopes.
"""
@@ -47,22 +53,28 @@ class Scope:
outer_column_list=None,
parent=None,
scope_type=ScopeType.ROOT,
+ lateral_sources=None,
):
self.expression = expression
self.sources = sources or {}
+ self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
+ self.sources.update(self.lateral_sources)
self.outer_column_list = outer_column_list or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
self.derived_table_scopes = []
+ self.table_scopes = []
self.cte_scopes = []
self.union_scopes = []
+ self.udtf_scopes = []
self.clear_cache()
def clear_cache(self):
self._collected = False
self._raw_columns = None
self._derived_tables = None
+ self._udtfs = None
self._tables = None
self._ctes = None
self._subqueries = None
@@ -86,6 +98,7 @@ class Scope:
self._ctes = []
self._subqueries = []
self._derived_tables = []
+ self._udtfs = []
self._raw_columns = []
self._join_hints = []
@@ -99,7 +112,7 @@ class Scope:
elif isinstance(node, exp.JoinHint):
self._join_hints.append(node)
elif isinstance(node, exp.UDTF):
- self._derived_tables.append(node)
+ self._udtfs.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
@@ -200,6 +213,17 @@ class Scope:
return self._derived_tables
@property
+ def udtfs(self):
+ """
+ List of "User Defined Tabular Functions" in this scope.
+
+ Returns:
+ list[exp.UDTF]: UDTFs
+ """
+ self._ensure_collected()
+ return self._udtfs
+
+ @property
def subqueries(self):
"""
List of subqueries in this scope.
@@ -227,7 +251,9 @@ class Scope:
columns = self._raw_columns
external_columns = [
- column for scope in self.subquery_scopes for column in scope.external_columns
+ column
+ for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
+ for column in scope.external_columns
]
named_selects = set(self.expression.named_selects)
@@ -262,9 +288,8 @@ class Scope:
for table in self.tables:
referenced_names.append((table.alias_or_name, table))
- for derived_table in self.derived_tables:
- referenced_names.append((derived_table.alias, derived_table.unnest()))
-
+ for expression in itertools.chain(self.derived_tables, self.udtfs):
+ referenced_names.append((expression.alias, expression.unnest()))
result = {}
for name, node in referenced_names:
@@ -414,7 +439,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
- self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
+ self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@@ -480,24 +505,23 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
- elif isinstance(scope.expression, exp.UDTF):
- _set_udtf_scope(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
+ elif isinstance(scope.expression, exp.UDTF):
+ pass
else:
raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
yield scope
def _traverse_select(scope):
- yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
- yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
+ yield from _traverse_ctes(scope)
+ yield from _traverse_tables(scope)
yield from _traverse_subqueries(scope)
- _add_table_sources(scope)
def _traverse_union(scope):
- yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
+ yield from _traverse_ctes(scope)
# The last scope to be yield should be the top most scope
left = None
@@ -511,82 +535,98 @@ def _traverse_union(scope):
scope.union_scopes = [left, right]
-def _set_udtf_scope(scope):
- parent = scope.expression.parent
- from_ = parent.args.get("from")
-
- if not from_:
- return
-
- for table in from_.expressions:
- if isinstance(table, exp.Table):
- scope.tables.append(table)
- elif isinstance(table, exp.Subquery):
- scope.subqueries.append(table)
- _add_table_sources(scope)
- _traverse_subqueries(scope)
-
-
-def _traverse_derived_tables(derived_tables, scope, scope_type):
+def _traverse_ctes(scope):
sources = {}
- is_cte = scope_type == ScopeType.CTE
- for derived_table in derived_tables:
+ for cte in scope.ctes:
recursive_scope = None
# if the scope is a recursive cte, it must be in the form of
# base_case UNION recursive. thus the recursive scope is the first
# section of the union.
- if is_cte and scope.expression.args["with"].recursive:
- union = derived_table.this
+ if scope.expression.args["with"].recursive:
+ union = cte.this
if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
for child_scope in _traverse_scope(
scope.branch(
- derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
- chain_sources=sources if scope_type == ScopeType.CTE else None,
- outer_column_list=derived_table.alias_column_names,
- scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
+ cte.this,
+ chain_sources=sources,
+ outer_column_list=cte.alias_column_names,
+ scope_type=ScopeType.CTE,
)
):
yield child_scope
- # Tables without aliases will be set as ""
- # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
- # Until then, this means that only a single, unaliased derived table is allowed (rather,
- # the latest one wins.
- alias = derived_table.alias
+ alias = cte.alias
sources[alias] = child_scope
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
# append the final child_scope yielded
- if is_cte:
- scope.cte_scopes.append(child_scope)
- else:
- scope.derived_table_scopes.append(child_scope)
+ scope.cte_scopes.append(child_scope)
scope.sources.update(sources)
-def _add_table_sources(scope):
+def _traverse_tables(scope):
sources = {}
- for table in scope.tables:
- table_name = table.name
- if table.alias:
- source_name = table.alias
- else:
- source_name = table_name
+ # Traverse FROMs, JOINs, and LATERALs in the order they are defined
+ expressions = []
+ from_ = scope.expression.args.get("from")
+ if from_:
+ expressions.extend(from_.expressions)
- if table_name in scope.sources:
- # This is a reference to a parent source (e.g. a CTE), not an actual table.
- scope.sources[source_name] = scope.sources[table_name]
+ for join in scope.expression.args.get("joins") or []:
+ expressions.append(join.this)
+
+ expressions.extend(scope.expression.args.get("laterals") or [])
+
+ for expression in expressions:
+ if isinstance(expression, exp.Table):
+ table_name = expression.name
+ source_name = expression.alias_or_name
+
+ if table_name in scope.sources:
+ # This is a reference to a parent source (e.g. a CTE), not an actual table.
+ sources[source_name] = scope.sources[table_name]
+ else:
+ sources[source_name] = expression
+ continue
+
+ if isinstance(expression, exp.UDTF):
+ lateral_sources = sources
+ scope_type = ScopeType.UDTF
+ scopes = scope.udtf_scopes
else:
- sources[source_name] = table
+ lateral_sources = None
+ scope_type = ScopeType.DERIVED_TABLE
+ scopes = scope.derived_table_scopes
+
+ for child_scope in _traverse_scope(
+ scope.branch(
+ expression,
+ lateral_sources=lateral_sources,
+ outer_column_list=expression.alias_column_names,
+ scope_type=scope_type,
+ )
+ ):
+ yield child_scope
+
+ # Tables without aliases will be set as ""
+ # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
+ # Until then, this means that only a single, unaliased derived table is allowed (rather,
+ # the latest one wins.
+ alias = expression.alias
+ sources[alias] = child_scope
+
+ # append the final child_scope yielded
+ scopes.append(child_scope)
+ scope.table_scopes.append(child_scope)
scope.sources.update(sources)
@@ -624,9 +664,10 @@ def walk_in_scope(expression, bfs=True):
if node is expression:
continue
- elif isinstance(node, exp.CTE):
- prune = True
- elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
- prune = True
- elif isinstance(node, exp.Subqueryable):
+ if (
+ isinstance(node, exp.CTE)
+ or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
+ or isinstance(node, exp.UDTF)
+ or isinstance(node, exp.Subqueryable)
+ ):
prune = True
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 579c2ce..9bde696 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import typing as t
+from collections import defaultdict
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
@@ -157,7 +158,6 @@ class Parser(metaclass=_Parser):
ID_VAR_TOKENS = {
TokenType.VAR,
- TokenType.ALWAYS,
TokenType.ANTI,
TokenType.APPLY,
TokenType.AUTO_INCREMENT,
@@ -186,8 +186,6 @@ class Parser(metaclass=_Parser):
TokenType.FOLLOWING,
TokenType.FORMAT,
TokenType.FUNCTION,
- TokenType.GENERATED,
- TokenType.IDENTITY,
TokenType.IF,
TokenType.INDEX,
TokenType.ISNULL,
@@ -213,7 +211,6 @@ class Parser(metaclass=_Parser):
TokenType.ROW,
TokenType.ROWS,
TokenType.SCHEMA,
- TokenType.SCHEMA_COMMENT,
TokenType.SEED,
TokenType.SEMI,
TokenType.SET,
@@ -481,9 +478,7 @@ class Parser(metaclass=_Parser):
PLACEHOLDER_PARSERS = {
TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
- TokenType.PARAMETER: lambda self: self.expression(
- exp.Parameter, this=self._parse_var() or self._parse_primary()
- ),
+ TokenType.PARAMETER: lambda self: self._parse_parameter(),
TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
if self._match_set((TokenType.NUMBER, TokenType.VAR))
else None,
@@ -516,6 +511,9 @@ class Parser(metaclass=_Parser):
PROPERTY_PARSERS = {
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
"CHARACTER SET": lambda self: self._parse_character_set(),
+ "CLUSTER BY": lambda self: self.expression(
+ exp.Cluster, expressions=self._parse_csv(self._parse_ordered)
+ ),
"LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
@@ -576,20 +574,54 @@ class Parser(metaclass=_Parser):
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"DEFINER": lambda self: self._parse_definer(),
+ "LOCK": lambda self: self._parse_locking(),
+ "LOCKING": lambda self: self._parse_locking(),
}
CONSTRAINT_PARSERS = {
- TokenType.CHECK: lambda self: self.expression(
- exp.Check, this=self._parse_wrapped(self._parse_conjunction)
+ "AUTOINCREMENT": lambda self: self._parse_auto_increment(),
+ "AUTO_INCREMENT": lambda self: self._parse_auto_increment(),
+ "CASESPECIFIC": lambda self: self.expression(exp.CaseSpecificColumnConstraint, not_=False),
+ "CHARACTER SET": lambda self: self.expression(
+ exp.CharacterSetColumnConstraint, this=self._parse_var_or_string()
+ ),
+ "CHECK": lambda self: self.expression(
+ exp.CheckColumnConstraint, this=self._parse_wrapped(self._parse_conjunction)
+ ),
+ "COLLATE": lambda self: self.expression(
+ exp.CollateColumnConstraint, this=self._parse_var()
+ ),
+ "COMMENT": lambda self: self.expression(
+ exp.CommentColumnConstraint, this=self._parse_string()
+ ),
+ "DEFAULT": lambda self: self.expression(
+ exp.DefaultColumnConstraint, this=self._parse_bitwise()
),
- TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
- TokenType.UNIQUE: lambda self: self._parse_unique(),
- TokenType.LIKE: lambda self: self._parse_create_like(),
+ "ENCODE": lambda self: self.expression(exp.EncodeColumnConstraint, this=self._parse_var()),
+ "FOREIGN KEY": lambda self: self._parse_foreign_key(),
+ "FORMAT": lambda self: self.expression(
+ exp.DateFormatColumnConstraint, this=self._parse_var_or_string()
+ ),
+ "GENERATED": lambda self: self._parse_generated_as_identity(),
+ "IDENTITY": lambda self: self._parse_auto_increment(),
+ "LIKE": lambda self: self._parse_create_like(),
+ "NOT": lambda self: self._parse_not_constraint(),
+ "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
+ "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
+ "PRIMARY KEY": lambda self: self._parse_primary_key(),
+ "TITLE": lambda self: self.expression(
+ exp.TitleColumnConstraint, this=self._parse_var_or_string()
+ ),
+ "UNIQUE": lambda self: self._parse_unique(),
+ "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
}
+ SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
+
NO_PAREN_FUNCTION_PARSERS = {
TokenType.CASE: lambda self: self._parse_case(),
TokenType.IF: lambda self: self._parse_if(),
+ TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -637,6 +669,8 @@ class Parser(metaclass=_Parser):
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
+ INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
+
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
@@ -940,7 +974,9 @@ class Parser(metaclass=_Parser):
def _parse_create(self) -> t.Optional[exp.Expression]:
start = self._prev
- replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
+ replace = self._prev.text.upper() == "REPLACE" or self._match_pair(
+ TokenType.OR, TokenType.REPLACE
+ )
set_ = self._match(TokenType.SET) # Teradata
multiset = self._match_text_seq("MULTISET") # Teradata
global_temporary = self._match_text_seq("GLOBAL", "TEMPORARY") # Teradata
@@ -958,7 +994,7 @@ class Parser(metaclass=_Parser):
create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token:
- properties = self._parse_properties()
+ properties = self._parse_properties() # exp.Properties.Location.POST_CREATE
create_token = self._match_set(self.CREATABLES) and self._prev
if not properties or not create_token:
@@ -994,15 +1030,37 @@ class Parser(metaclass=_Parser):
):
table_parts = self._parse_table_parts(schema=True)
- if self._match(TokenType.COMMA): # comma-separated properties before schema definition
- properties = self._parse_properties(before=True)
+ # exp.Properties.Location.POST_NAME
+ if self._match(TokenType.COMMA):
+ temp_properties = self._parse_properties(before=True)
+ if properties and temp_properties:
+ properties.expressions.append(temp_properties.expressions)
+ elif temp_properties:
+ properties = temp_properties
this = self._parse_schema(this=table_parts)
- if not properties: # properties after schema definition
- properties = self._parse_properties()
+ # exp.Properties.Location.POST_SCHEMA and POST_WITH
+ temp_properties = self._parse_properties()
+ if properties and temp_properties:
+ properties.expressions.append(temp_properties.expressions)
+ elif temp_properties:
+ properties = temp_properties
self._match(TokenType.ALIAS)
+
+ # exp.Properties.Location.POST_ALIAS
+ if not (
+ self._match(TokenType.SELECT, advance=False)
+ or self._match(TokenType.WITH, advance=False)
+ or self._match(TokenType.L_PAREN, advance=False)
+ ):
+ temp_properties = self._parse_properties()
+ if properties and temp_properties:
+ properties.expressions.append(temp_properties.expressions)
+ elif temp_properties:
+ properties = temp_properties
+
expression = self._parse_ddl_select()
if create_token.token_type == TokenType.TABLE:
@@ -1022,12 +1080,13 @@ class Parser(metaclass=_Parser):
while True:
index = self._parse_create_table_index()
- # post index PARTITION BY property
+ # exp.Properties.Location.POST_INDEX
if self._match(TokenType.PARTITION_BY, advance=False):
- if properties:
- properties.expressions.append(self._parse_property())
- else:
- properties = self._parse_properties()
+ temp_properties = self._parse_properties()
+ if properties and temp_properties:
+ properties.expressions.append(temp_properties.expressions)
+ elif temp_properties:
+ properties = temp_properties
if not index:
break
@@ -1080,7 +1139,7 @@ class Parser(metaclass=_Parser):
return self.PROPERTY_PARSERS[self._prev.text.upper()](self)
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
- return self._parse_character_set(True)
+ return self._parse_character_set(default=True)
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
return self._parse_sortkey(compound=True)
@@ -1240,7 +1299,7 @@ class Parser(metaclass=_Parser):
def _parse_blockcompression(self) -> exp.Expression:
self._match_text_seq("BLOCKCOMPRESSION")
self._match(TokenType.EQ)
- always = self._match(TokenType.ALWAYS)
+ always = self._match_text_seq("ALWAYS")
manual = self._match_text_seq("MANUAL")
never = self._match_text_seq("NEVER")
default = self._match_text_seq("DEFAULT")
@@ -1274,6 +1333,56 @@ class Parser(metaclass=_Parser):
for_none=for_none,
)
+ def _parse_locking(self) -> exp.Expression:
+ if self._match(TokenType.TABLE):
+ kind = "TABLE"
+ elif self._match(TokenType.VIEW):
+ kind = "VIEW"
+ elif self._match(TokenType.ROW):
+ kind = "ROW"
+ elif self._match_text_seq("DATABASE"):
+ kind = "DATABASE"
+ else:
+ kind = None
+
+ if kind in ("DATABASE", "TABLE", "VIEW"):
+ this = self._parse_table_parts()
+ else:
+ this = None
+
+ if self._match(TokenType.FOR):
+ for_or_in = "FOR"
+ elif self._match(TokenType.IN):
+ for_or_in = "IN"
+ else:
+ for_or_in = None
+
+ if self._match_text_seq("ACCESS"):
+ lock_type = "ACCESS"
+ elif self._match_texts(("EXCL", "EXCLUSIVE")):
+ lock_type = "EXCLUSIVE"
+ elif self._match_text_seq("SHARE"):
+ lock_type = "SHARE"
+ elif self._match_text_seq("READ"):
+ lock_type = "READ"
+ elif self._match_text_seq("WRITE"):
+ lock_type = "WRITE"
+ elif self._match_text_seq("CHECKSUM"):
+ lock_type = "CHECKSUM"
+ else:
+ lock_type = None
+
+ override = self._match_text_seq("OVERRIDE")
+
+ return self.expression(
+ exp.LockingProperty,
+ this=this,
+ kind=kind,
+ for_or_in=for_or_in,
+ lock_type=lock_type,
+ override=override,
+ )
+
def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]:
if self._match(TokenType.PARTITION_BY):
return self._parse_csv(self._parse_conjunction)
@@ -1351,6 +1460,7 @@ class Parser(metaclass=_Parser):
this: t.Optional[exp.Expression]
+ alternative = None
if self._match_text_seq("DIRECTORY"):
this = self.expression(
exp.Directory,
@@ -1359,6 +1469,9 @@ class Parser(metaclass=_Parser):
row_format=self._parse_row_format(match_row=True),
)
else:
+ if self._match(TokenType.OR):
+ alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text
+
self._match(TokenType.INTO)
self._match(TokenType.TABLE)
this = self._parse_table(schema=True)
@@ -1370,6 +1483,7 @@ class Parser(metaclass=_Parser):
partition=self._parse_partition(),
expression=self._parse_ddl_select(),
overwrite=overwrite,
+ alternative=alternative,
)
def _parse_row(self) -> t.Optional[exp.Expression]:
@@ -1607,7 +1721,7 @@ class Parser(metaclass=_Parser):
index = self._index
if self._match(TokenType.L_PAREN):
- columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var()))
+ columns = self._parse_csv(self._parse_function_parameter)
self._match_r_paren() if columns else self._retreat(index)
else:
columns = None
@@ -2080,27 +2194,33 @@ class Parser(metaclass=_Parser):
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
return None
- expressions = self._parse_csv(self._parse_conjunction)
- grouping_sets = self._parse_grouping_sets()
+ elements = defaultdict(list)
- self._match(TokenType.COMMA)
- with_ = self._match(TokenType.WITH)
- cube = self._match(TokenType.CUBE) and (
- with_ or self._parse_wrapped_csv(self._parse_column)
- )
+ while True:
+ expressions = self._parse_csv(self._parse_conjunction)
+ if expressions:
+ elements["expressions"].extend(expressions)
- self._match(TokenType.COMMA)
- rollup = self._match(TokenType.ROLLUP) and (
- with_ or self._parse_wrapped_csv(self._parse_column)
- )
+ grouping_sets = self._parse_grouping_sets()
+ if grouping_sets:
+ elements["grouping_sets"].extend(grouping_sets)
- return self.expression(
- exp.Group,
- expressions=expressions,
- grouping_sets=grouping_sets,
- cube=cube,
- rollup=rollup,
- )
+ rollup = None
+ cube = None
+
+ with_ = self._match(TokenType.WITH)
+ if self._match(TokenType.ROLLUP):
+ rollup = with_ or self._parse_wrapped_csv(self._parse_column)
+ elements["rollup"].extend(ensure_list(rollup))
+
+ if self._match(TokenType.CUBE):
+ cube = with_ or self._parse_wrapped_csv(self._parse_column)
+ elements["cube"].extend(ensure_list(cube))
+
+ if not (expressions or grouping_sets or rollup or cube):
+ break
+
+ return self.expression(exp.Group, **elements) # type: ignore
def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.GROUPING_SETS):
@@ -2357,6 +2477,8 @@ class Parser(metaclass=_Parser):
def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
index = self._index
+ prefix = self._match_text_seq("SYSUDTLIB", ".")
+
if not self._match_set(self.TYPE_TOKENS):
return None
@@ -2458,6 +2580,7 @@ class Parser(metaclass=_Parser):
expressions=expressions,
nested=nested,
values=values,
+ prefix=prefix,
)
def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
@@ -2512,8 +2635,14 @@ class Parser(metaclass=_Parser):
if op:
this = op(self, this, field)
- elif isinstance(this, exp.Column) and not this.table:
- this = self.expression(exp.Column, this=field, table=this.this)
+ elif isinstance(this, exp.Column) and not this.args.get("catalog"):
+ this = self.expression(
+ exp.Column,
+ this=field,
+ table=this.this,
+ db=this.args.get("table"),
+ catalog=this.args.get("db"),
+ )
else:
this = self.expression(exp.Dot, this=this, expression=field)
this = self._parse_bracket(this)
@@ -2632,6 +2761,9 @@ class Parser(metaclass=_Parser):
self._match_r_paren(this)
return self._parse_window(this)
+ def _parse_function_parameter(self) -> t.Optional[exp.Expression]:
+ return self._parse_column_def(self._parse_id_var())
+
def _parse_user_defined_function(
self, kind: t.Optional[TokenType] = None
) -> t.Optional[exp.Expression]:
@@ -2643,7 +2775,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN):
return this
- expressions = self._parse_csv(self._parse_udf_kwarg)
+ expressions = self._parse_csv(self._parse_function_parameter)
self._match_r_paren()
return self.expression(
exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True
@@ -2669,15 +2801,6 @@ class Parser(metaclass=_Parser):
return self.expression(exp.SessionParameter, this=this, kind=kind)
- def _parse_udf_kwarg(self) -> t.Optional[exp.Expression]:
- this = self._parse_id_var()
- kind = self._parse_types()
-
- if not kind:
- return this
-
- return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind)
-
def _parse_lambda(self) -> t.Optional[exp.Expression]:
index = self._index
@@ -2726,6 +2849,9 @@ class Parser(metaclass=_Parser):
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
kind = self._parse_types()
+ if self._match_text_seq("FOR", "ORDINALITY"):
+ return self.expression(exp.ColumnDef, this=this, ordinality=True)
+
constraints = []
while True:
constraint = self._parse_column_constraint()
@@ -2738,79 +2864,78 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
- def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
- this = self._parse_references()
+ def _parse_auto_increment(self) -> exp.Expression:
+ start = None
+ increment = None
- if this:
- return this
+ if self._match(TokenType.L_PAREN, advance=False):
+ args = self._parse_wrapped_csv(self._parse_bitwise)
+ start = seq_get(args, 0)
+ increment = seq_get(args, 1)
+ elif self._match_text_seq("START"):
+ start = self._parse_bitwise()
+ self._match_text_seq("INCREMENT")
+ increment = self._parse_bitwise()
+
+ if start and increment:
+ return exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment)
+
+ return exp.AutoIncrementColumnConstraint()
+
+ def _parse_generated_as_identity(self) -> exp.Expression:
+ if self._match(TokenType.BY_DEFAULT):
+ this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
+ else:
+ self._match_text_seq("ALWAYS")
+ this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
- if self._match(TokenType.CONSTRAINT):
- this = self._parse_id_var()
+ self._match_text_seq("AS", "IDENTITY")
+ if self._match(TokenType.L_PAREN):
+ if self._match_text_seq("START", "WITH"):
+ this.set("start", self._parse_bitwise())
+ if self._match_text_seq("INCREMENT", "BY"):
+ this.set("increment", self._parse_bitwise())
+ if self._match_text_seq("MINVALUE"):
+ this.set("minvalue", self._parse_bitwise())
+ if self._match_text_seq("MAXVALUE"):
+ this.set("maxvalue", self._parse_bitwise())
+
+ if self._match_text_seq("CYCLE"):
+ this.set("cycle", True)
+ elif self._match_text_seq("NO", "CYCLE"):
+ this.set("cycle", False)
- kind: exp.Expression
+ self._match_r_paren()
- if self._match_set((TokenType.AUTO_INCREMENT, TokenType.IDENTITY)):
- start = None
- increment = None
+ return this
- if self._match(TokenType.L_PAREN, advance=False):
- args = self._parse_wrapped_csv(self._parse_bitwise)
- start = seq_get(args, 0)
- increment = seq_get(args, 1)
- elif self._match_text_seq("START"):
- start = self._parse_bitwise()
- self._match_text_seq("INCREMENT")
- increment = self._parse_bitwise()
+ def _parse_not_constraint(self) -> t.Optional[exp.Expression]:
+ if self._match_text_seq("NULL"):
+ return self.expression(exp.NotNullColumnConstraint)
+ if self._match_text_seq("CASESPECIFIC"):
+ return self.expression(exp.CaseSpecificColumnConstraint, not_=True)
+ return None
- if start and increment:
- kind = exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment)
- else:
- kind = exp.AutoIncrementColumnConstraint()
- elif self._match(TokenType.CHECK):
- constraint = self._parse_wrapped(self._parse_conjunction)
- kind = self.expression(exp.CheckColumnConstraint, this=constraint)
- elif self._match(TokenType.COLLATE):
- kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
- elif self._match(TokenType.ENCODE):
- kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
- elif self._match(TokenType.DEFAULT):
- kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise())
- elif self._match_pair(TokenType.NOT, TokenType.NULL):
- kind = exp.NotNullColumnConstraint()
- elif self._match(TokenType.NULL):
- kind = exp.NotNullColumnConstraint(allow_null=True)
- elif self._match(TokenType.SCHEMA_COMMENT):
- kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
- elif self._match(TokenType.PRIMARY_KEY):
- desc = None
- if self._match(TokenType.ASC) or self._match(TokenType.DESC):
- desc = self._prev.token_type == TokenType.DESC
- kind = exp.PrimaryKeyColumnConstraint(desc=desc)
- elif self._match(TokenType.UNIQUE):
- kind = exp.UniqueColumnConstraint()
- elif self._match(TokenType.GENERATED):
- if self._match(TokenType.BY_DEFAULT):
- kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
- else:
- self._match(TokenType.ALWAYS)
- kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
- self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
+ def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
+ this = self._parse_references()
+ if this:
+ return this
- if self._match(TokenType.L_PAREN):
- if self._match_text_seq("START", "WITH"):
- kind.set("start", self._parse_bitwise())
- if self._match_text_seq("INCREMENT", "BY"):
- kind.set("increment", self._parse_bitwise())
+ if self._match(TokenType.CONSTRAINT):
+ this = self._parse_id_var()
- self._match_r_paren()
- else:
- return this
+ if self._match_texts(self.CONSTRAINT_PARSERS):
+ return self.expression(
+ exp.ColumnConstraint,
+ this=this,
+ kind=self.CONSTRAINT_PARSERS[self._prev.text.upper()](self),
+ )
- return self.expression(exp.ColumnConstraint, this=this, kind=kind)
+ return this
def _parse_constraint(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.CONSTRAINT):
- return self._parse_unnamed_constraint()
+ return self._parse_unnamed_constraint(constraints=self.SCHEMA_UNNAMED_CONSTRAINTS)
this = self._parse_id_var()
expressions = []
@@ -2823,12 +2948,21 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Constraint, this=this, expressions=expressions)
- def _parse_unnamed_constraint(self) -> t.Optional[exp.Expression]:
- if not self._match_set(self.CONSTRAINT_PARSERS):
+ def _parse_unnamed_constraint(
+ self, constraints: t.Optional[t.Collection[str]] = None
+ ) -> t.Optional[exp.Expression]:
+ if not self._match_texts(constraints or self.CONSTRAINT_PARSERS):
return None
- return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
+
+ constraint = self._prev.text.upper()
+ if constraint not in self.CONSTRAINT_PARSERS:
+ self.raise_error(f"No parser found for schema constraint {constraint}.")
+
+ return self.CONSTRAINT_PARSERS[constraint](self)
def _parse_unique(self) -> exp.Expression:
+ if not self._match(TokenType.L_PAREN, advance=False):
+ return self.expression(exp.UniqueColumnConstraint)
return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
def _parse_key_constraint_options(self) -> t.List[str]:
@@ -2908,6 +3042,14 @@ class Parser(metaclass=_Parser):
)
def _parse_primary_key(self) -> exp.Expression:
+ desc = (
+ self._match_set((TokenType.ASC, TokenType.DESC))
+ and self._prev.token_type == TokenType.DESC
+ )
+
+ if not self._match(TokenType.L_PAREN, advance=False):
+ return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc)
+
expressions = self._parse_wrapped_id_vars()
options = self._parse_key_constraint_options()
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
@@ -3306,6 +3448,12 @@ class Parser(metaclass=_Parser):
return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
return None
+ def _parse_parameter(self) -> exp.Expression:
+ wrapped = self._match(TokenType.L_BRACE)
+ this = self._parse_var() or self._parse_primary()
+ self._match(TokenType.R_BRACE)
+ return self.expression(exp.Parameter, this=this, wrapped=wrapped)
+
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PLACEHOLDER_PARSERS):
placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self)
@@ -3449,7 +3597,7 @@ class Parser(metaclass=_Parser):
if kind == TokenType.CONSTRAINT:
this = self._parse_id_var()
- if self._match(TokenType.CHECK):
+ if self._match_text_seq("CHECK"):
expression = self._parse_wrapped(self._parse_conjunction)
enforced = self._match_text_seq("ENFORCED")
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 8cf17a7..9b29c12 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -138,7 +138,6 @@ class TokenType(AutoName):
CASCADE = auto()
CASE = auto()
CHARACTER_SET = auto()
- CHECK = auto()
CLUSTER_BY = auto()
COLLATE = auto()
COMMAND = auto()
@@ -164,7 +163,6 @@ class TokenType(AutoName):
DIV = auto()
DROP = auto()
ELSE = auto()
- ENCODE = auto()
END = auto()
ESCAPE = auto()
EXCEPT = auto()
@@ -182,17 +180,16 @@ class TokenType(AutoName):
FROM = auto()
FULL = auto()
FUNCTION = auto()
- GENERATED = auto()
GLOB = auto()
GLOBAL = auto()
GROUP_BY = auto()
GROUPING_SETS = auto()
HAVING = auto()
HINT = auto()
- IDENTITY = auto()
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
+ ILIKE_ANY = auto()
IN = auto()
INDEX = auto()
INNER = auto()
@@ -211,6 +208,7 @@ class TokenType(AutoName):
LEADING = auto()
LEFT = auto()
LIKE = auto()
+ LIKE_ANY = auto()
LIMIT = auto()
LOAD_DATA = auto()
LOCAL = auto()
@@ -253,6 +251,7 @@ class TokenType(AutoName):
RECURSIVE = auto()
REPLACE = auto()
RESPECT_NULLS = auto()
+ RETURNING = auto()
REFERENCES = auto()
RIGHT = auto()
RLIKE = auto()
@@ -260,7 +259,6 @@ class TokenType(AutoName):
ROLLUP = auto()
ROW = auto()
ROWS = auto()
- SCHEMA_COMMENT = auto()
SEED = auto()
SELECT = auto()
SEMI = auto()
@@ -441,7 +439,7 @@ class Tokenizer(metaclass=_Tokenizer):
KEYWORDS = {
**{
f"{key}{postfix}": TokenType.BLOCK_START
- for key in ("{{", "{%", "{#")
+ for key in ("{%", "{#")
for postfix in ("", "+", "-")
},
**{
@@ -449,6 +447,8 @@ class Tokenizer(metaclass=_Tokenizer):
for key in ("%}", "#}")
for prefix in ("", "+", "-")
},
+ "{{+": TokenType.BLOCK_START,
+ "{{-": TokenType.BLOCK_START,
"+}}": TokenType.BLOCK_END,
"-}}": TokenType.BLOCK_END,
"/*+": TokenType.HINT,
@@ -486,11 +486,9 @@ class Tokenizer(metaclass=_Tokenizer):
"CASE": TokenType.CASE,
"CASCADE": TokenType.CASCADE,
"CHARACTER SET": TokenType.CHARACTER_SET,
- "CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
- "COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND,
"CONSTRAINT": TokenType.CONSTRAINT,
@@ -526,12 +524,10 @@ class Tokenizer(metaclass=_Tokenizer):
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
- "GENERATED": TokenType.GENERATED,
"GLOB": TokenType.GLOB,
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
- "IDENTITY": TokenType.IDENTITY,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IGNORE NULLS": TokenType.IGNORE_NULLS,
@@ -747,11 +743,9 @@ class Tokenizer(metaclass=_Tokenizer):
"_prev_token_line",
"_prev_token_comments",
"_prev_token_type",
- "_replace_backslash",
)
def __init__(self) -> None:
- self._replace_backslash = "\\" in self._STRING_ESCAPES
self.reset()
def reset(self) -> None:
@@ -855,7 +849,7 @@ class Tokenizer(metaclass=_Tokenizer):
def _scan_keywords(self) -> None:
size = 0
word = None
- chars = self._text
+ chars: t.Optional[str] = self._text
char = chars
prev_space = False
skip = False
@@ -887,7 +881,7 @@ class Tokenizer(metaclass=_Tokenizer):
else:
skip = True
else:
- chars = None # type: ignore
+ chars = None
if not word:
if self._char in self.SINGLE_TOKENS:
@@ -1015,7 +1009,6 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance(len(quote))
text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
- text = text.replace("\\\\", "\\") if self._replace_backslash else text
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
@@ -1091,13 +1084,18 @@ class Tokenizer(metaclass=_Tokenizer):
delim_size = len(delimiter)
while True:
- if (
- self._char in self._STRING_ESCAPES
- and self._peek
- and (self._peek == delimiter or self._peek in self._STRING_ESCAPES)
+ if self._char in self._STRING_ESCAPES and (
+ self._peek == delimiter or self._peek in self._STRING_ESCAPES
):
- text += self._peek
- self._advance(2)
+ if self._peek == delimiter:
+ text += self._peek # type: ignore
+ else:
+ text += self._char + self._peek # type: ignore
+
+ if self._current + 1 < self.size:
+ self._advance(2)
+ else:
+ raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._current}")
else:
if self._chars(delim_size) == delimiter:
if delim_size > 1: