summary | refs | log | tree | commit | diff | stats
path: root/sqlglot
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-02-08 04:14:34 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-02-08 04:14:34 +0000
commit: 8bec55350caa5c760d8b7e7e2d0ba6c77a32bc71 (patch)
tree: d6259e0351c7b4a50d528122513d533bb582eb2b /sqlglot
parent: Releasing debian version 10.6.0-1. (diff)
download: sqlglot-8bec55350caa5c760d8b7e7e2d0ba6c77a32bc71.tar.xz
sqlglot-8bec55350caa5c760d8b7e7e2d0ba6c77a32bc71.zip
Merging upstream version 10.6.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- sqlglot/__init__.py 66
-rw-r--r-- sqlglot/dataframe/sql/column.py 6
-rw-r--r-- sqlglot/dataframe/sql/functions.py 20
-rw-r--r-- sqlglot/dialects/bigquery.py 8
-rw-r--r-- sqlglot/dialects/dialect.py 10
-rw-r--r-- sqlglot/dialects/drill.py 5
-rw-r--r-- sqlglot/dialects/duckdb.py 12
-rw-r--r-- sqlglot/dialects/hive.py 16
-rw-r--r-- sqlglot/dialects/mysql.py 17
-rw-r--r-- sqlglot/dialects/oracle.py 4
-rw-r--r-- sqlglot/dialects/postgres.py 45
-rw-r--r-- sqlglot/dialects/presto.py 32
-rw-r--r-- sqlglot/dialects/redshift.py 11
-rw-r--r-- sqlglot/dialects/snowflake.py 9
-rw-r--r-- sqlglot/dialects/spark.py 37
-rw-r--r-- sqlglot/dialects/tableau.py 1
-rw-r--r-- sqlglot/dialects/teradata.py 8
-rw-r--r-- sqlglot/dialects/tsql.py 2
-rw-r--r-- sqlglot/diff.py 9
-rw-r--r-- sqlglot/executor/__init__.py 61
-rw-r--r-- sqlglot/executor/env.py 1
-rw-r--r-- sqlglot/executor/table.py 7
-rw-r--r-- sqlglot/expressions.py 158
-rw-r--r-- sqlglot/generator.py 187
-rw-r--r-- sqlglot/lineage.py 7
-rw-r--r-- sqlglot/optimizer/eliminate_subqueries.py 2
-rw-r--r-- sqlglot/optimizer/scope.py 2
-rw-r--r-- sqlglot/optimizer/simplify.py 6
-rw-r--r-- sqlglot/parser.py 122
-rw-r--r-- sqlglot/schema.py 3
-rw-r--r-- sqlglot/tokens.py 1
31 files changed, 647 insertions, 228 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index bfcabb3..714897f 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -33,7 +33,13 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema, Schema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.6.0"
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import DialectType
+
+ T = t.TypeVar("T", bound=Expression)
+
+
+__version__ = "10.6.3"
pretty = False
"""Whether to format generated SQL by default."""
@@ -42,9 +48,7 @@ schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""
-def parse(
- sql: str, read: t.Optional[str | Dialect] = None, **opts
-) -> t.List[t.Optional[Expression]]:
+def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
"""
Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
@@ -60,9 +64,57 @@ def parse(
return dialect.parse(sql, **opts)
+@t.overload
+def parse_one(
+ sql: str,
+ read: None = None,
+ into: t.Type[T] = ...,
+ **opts,
+) -> T:
+ ...
+
+
+@t.overload
+def parse_one(
+ sql: str,
+ read: DialectType,
+ into: t.Type[T],
+ **opts,
+) -> T:
+ ...
+
+
+@t.overload
+def parse_one(
+ sql: str,
+ read: None = None,
+ into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ...,
+ **opts,
+) -> Expression:
+ ...
+
+
+@t.overload
+def parse_one(
+ sql: str,
+ read: DialectType,
+ into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]],
+ **opts,
+) -> Expression:
+ ...
+
+
+@t.overload
+def parse_one(
+ sql: str,
+ **opts,
+) -> Expression:
+ ...
+
+
def parse_one(
sql: str,
- read: t.Optional[str | Dialect] = None,
+ read: DialectType = None,
into: t.Optional[exp.IntoType] = None,
**opts,
) -> Expression:
@@ -96,8 +148,8 @@ def parse_one(
def transpile(
sql: str,
- read: t.Optional[str | Dialect] = None,
- write: t.Optional[str | Dialect] = None,
+ read: DialectType = None,
+ write: DialectType = None,
identity: bool = True,
error_level: t.Optional[ErrorLevel] = None,
**opts,
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index 40ffe3e..f5b0974 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -260,11 +260,7 @@ class Column:
"""
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
- new_expression = exp.Cast(
- this=self.column_expression,
- to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore
- )
- return Column(new_expression)
+ return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
def startswith(self, value: t.Union[str, Column]) -> Column:
value = self._lit(value) if not isinstance(value, Column) else value
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index a141fe4..47d5e7b 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -536,15 +536,15 @@ def month(col: ColumnOrName) -> Column:
def dayofweek(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "DAYOFWEEK")
+ return Column.invoke_expression_over_column(col, glotexp.DayOfWeek)
def dayofmonth(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "DAYOFMONTH")
+ return Column.invoke_expression_over_column(col, glotexp.DayOfMonth)
def dayofyear(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "DAYOFYEAR")
+ return Column.invoke_expression_over_column(col, glotexp.DayOfYear)
def hour(col: ColumnOrName) -> Column:
@@ -560,7 +560,7 @@ def second(col: ColumnOrName) -> Column:
def weekofyear(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "WEEKOFYEAR")
+ return Column.invoke_expression_over_column(col, glotexp.WeekOfYear)
def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
@@ -1144,10 +1144,16 @@ def aggregate(
merge_exp = _get_lambda_from_func(merge)
if finish is not None:
finish_exp = _get_lambda_from_func(finish)
- return Column.invoke_anonymous_function(
- col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
+ return Column.invoke_expression_over_column(
+ col,
+ glotexp.Reduce,
+ initial=initialValue,
+ merge=Column(merge_exp),
+ finish=Column(finish_exp),
)
- return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
+ return Column.invoke_expression_over_column(
+ col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp)
+ )
def transform(
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 27dca48..90ae229 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -222,14 +222,6 @@ class BigQuery(Dialect):
exp.DataType.Type.NVARCHAR: "STRING",
}
- ROOT_PROPERTIES = {
- exp.LanguageProperty,
- exp.ReturnsProperty,
- exp.VolatilityProperty,
- }
-
- WITH_PROPERTIES = {exp.Property}
-
EXPLICIT_UNION = True
def array_sql(self, expression: exp.Array) -> str:
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 0c2beba..1b20e0a 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -122,9 +122,15 @@ class Dialect(metaclass=_Dialect):
def get_or_raise(cls, dialect):
if not dialect:
return cls
+ if isinstance(dialect, _Dialect):
+ return dialect
+ if isinstance(dialect, Dialect):
+ return dialect.__class__
+
result = cls.get(dialect)
if not result:
raise ValueError(f"Unknown dialect '{dialect}'")
+
return result
@classmethod
@@ -196,6 +202,10 @@ class Dialect(metaclass=_Dialect):
)
+if t.TYPE_CHECKING:
+ DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
+
+
def rename_func(name):
def _rename(self, expression):
args = flatten(expression.args.values())
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 4e3c0e1..d0a0251 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -137,7 +137,10 @@ class Drill(Dialect):
exp.DataType.Type.DATETIME: "TIMESTAMP",
}
- ROOT_PROPERTIES = {exp.PartitionedByProperty}
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ }
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 4646eb4..95ff95c 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -20,10 +20,6 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _unix_to_time(self, expression):
- return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))"
-
-
def _str_to_time_sql(self, expression):
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
@@ -113,7 +109,7 @@ class DuckDB(Dialect):
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRUCT_PACK": exp.Struct.from_arg_list,
- "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
+ "TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
}
@@ -162,9 +158,9 @@ class DuckDB(Dialect):
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
- exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})",
- exp.UnixToTime: _unix_to_time,
- exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)",
+ exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
+ exp.UnixToTime: rename_func("TO_TIMESTAMP"),
+ exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
}
TYPE_MAPPING = {
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 4bbec70..f2b6eaa 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -322,17 +322,11 @@ class Hive(Dialect):
exp.LastDateOfMonth: rename_func("LAST_DAY"),
}
- WITH_PROPERTIES = {exp.Property}
-
- ROOT_PROPERTIES = {
- exp.PartitionedByProperty,
- exp.FileFormatProperty,
- exp.SchemaCommentProperty,
- exp.LocationProperty,
- exp.TableFormatProperty,
- exp.RowFormatDelimitedProperty,
- exp.RowFormatSerdeProperty,
- exp.SerdeProperties,
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
}
def with_properties(self, properties):
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index cd8c30c..a5bd86b 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -1,7 +1,5 @@
from __future__ import annotations
-import typing as t
-
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
@@ -98,6 +96,8 @@ def _date_add_sql(kind):
class MySQL(Dialect):
+ time_format = "'%Y-%m-%d %T'"
+
# https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
time_mapping = {
"%M": "%B",
@@ -110,6 +110,7 @@ class MySQL(Dialect):
"%u": "%W",
"%k": "%-H",
"%l": "%-I",
+ "%T": "%H:%M:%S",
}
class Tokenizer(tokens.Tokenizer):
@@ -428,6 +429,7 @@ class MySQL(Dialect):
)
class Generator(generator.Generator):
+ LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
@@ -449,23 +451,12 @@ class MySQL(Dialect):
exp.StrPosition: strposition_to_locate_sql,
}
- ROOT_PROPERTIES = {
- exp.EngineProperty,
- exp.AutoIncrementProperty,
- exp.CharacterSetProperty,
- exp.CollateProperty,
- exp.SchemaCommentProperty,
- exp.LikeProperty,
- }
-
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
- WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
-
def show_sql(self, expression):
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 67d791d..fde845e 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -44,6 +44,8 @@ class Oracle(Dialect):
}
class Generator(generator.Generator):
+ LOCKING_READS_SUPPORTED = True
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "NUMBER",
@@ -69,6 +71,7 @@ class Oracle(Dialect):
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+ exp.Substring: rename_func("SUBSTR"),
}
def query_modifiers(self, expression, *sqls):
@@ -90,6 +93,7 @@ class Oracle(Dialect):
self.sql(expression, "order"),
self.sql(expression, "offset"), # offset before limit in oracle
self.sql(expression, "limit"),
+ self.sql(expression, "lock"),
sep="",
)
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 0d74b3a..6418032 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -148,6 +148,22 @@ def _serial_to_generated(expression):
return expression
+def _generate_series(args):
+ # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
+ step = seq_get(args, 2)
+
+ if step is None:
+ # Postgres allows calls with just two arguments -- the "step" argument defaults to 1
+ return exp.GenerateSeries.from_arg_list(args)
+
+ if step.is_string:
+ args[2] = exp.to_interval(step.this)
+ elif isinstance(step, exp.Interval) and not step.args.get("unit"):
+ args[2] = exp.to_interval(step.this.this)
+
+ return exp.GenerateSeries.from_arg_list(args)
+
+
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1:
@@ -195,29 +211,6 @@ class Postgres(Dialect):
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
- CREATABLES = (
- "AGGREGATE",
- "CAST",
- "CONVERSION",
- "COLLATION",
- "DEFAULT CONVERSION",
- "CONSTRAINT",
- "DOMAIN",
- "EXTENSION",
- "FOREIGN",
- "FUNCTION",
- "OPERATOR",
- "POLICY",
- "ROLE",
- "RULE",
- "SEQUENCE",
- "TEXT",
- "TRIGGER",
- "TYPE",
- "UNLOGGED",
- "USER",
- )
-
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~~": TokenType.LIKE,
@@ -243,8 +236,6 @@ class Postgres(Dialect):
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
- **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
- **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
@@ -257,8 +248,10 @@ class Postgres(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
+ "NOW": exp.CurrentTimestamp.from_arg_list,
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
+ "GENERATE_SERIES": _generate_series,
}
BITWISE = {
@@ -272,6 +265,8 @@ class Postgres(Dialect):
}
class Generator(generator.Generator):
+ LOCKING_READS_SUPPORTED = True
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "SMALLINT",
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 8175d6f..6c1a474 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -105,6 +105,29 @@ def _ts_or_ds_add_sql(self, expression):
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
+def _sequence_sql(self, expression):
+ start = expression.args["start"]
+ end = expression.args["end"]
+ step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
+
+ target_type = None
+
+ if isinstance(start, exp.Cast):
+ target_type = start.to
+ elif isinstance(end, exp.Cast):
+ target_type = end.to
+
+ if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
+ to = target_type.copy()
+
+ if target_type is start.to:
+ end = exp.Cast(this=end, to=to)
+ else:
+ start = exp.Cast(this=start, to=to)
+
+ return f"SEQUENCE({self.format_args(start, end, step)})"
+
+
def _ensure_utf8(charset):
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
@@ -145,7 +168,7 @@ def _from_unixtime(args):
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
- time_format = "'%Y-%m-%d %H:%i:%S'"
+ time_format = MySQL.time_format # type: ignore
time_mapping = MySQL.time_mapping # type: ignore
class Tokenizer(tokens.Tokenizer):
@@ -197,7 +220,10 @@ class Presto(Dialect):
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
- ROOT_PROPERTIES = {exp.SchemaCommentProperty}
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
+ }
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -223,6 +249,7 @@ class Presto(Dialect):
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@@ -231,6 +258,7 @@ class Presto(Dialect):
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
+ exp.GenerateSeries: _sequence_sql,
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 7da881f..c3c99eb 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -61,14 +61,9 @@ class Redshift(Postgres):
exp.DataType.Type.INT: "INTEGER",
}
- ROOT_PROPERTIES = {
- exp.DistKeyProperty,
- exp.SortKeyProperty,
- exp.DistStyleProperty,
- }
-
- WITH_PROPERTIES = {
- exp.LikeProperty,
+ PROPERTIES_LOCATION = {
+ **Postgres.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH,
}
TRANSFORMS = {
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index db72a34..3b83b02 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -234,15 +234,6 @@ class Snowflake(Dialect):
"replace": "RENAME",
}
- ROOT_PROPERTIES = {
- exp.PartitionedByProperty,
- exp.ReturnsProperty,
- exp.LanguageProperty,
- exp.SchemaCommentProperty,
- exp.ExecuteAsProperty,
- exp.VolatilityProperty,
- }
-
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index fc711ab..8ef4a87 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -73,6 +73,19 @@ class Spark(Hive):
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
+ "AGGREGATE": exp.Reduce.from_arg_list,
+ "DAYOFWEEK": lambda args: exp.DayOfWeek(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "DAYOFMONTH": lambda args: exp.DayOfMonth(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "DAYOFYEAR": lambda args: exp.DayOfYear(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "WEEKOFYEAR": lambda args: exp.WeekOfYear(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
}
FUNCTION_PARSERS = {
@@ -105,6 +118,14 @@ class Spark(Hive):
exp.DataType.Type.BIGINT: "LONG",
}
+ PROPERTIES_LOCATION = {
+ **Hive.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
+ }
+
TRANSFORMS = {
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
@@ -126,11 +147,27 @@ class Spark(Hive):
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.DayOfMonth: rename_func("DAYOFMONTH"),
+ exp.DayOfYear: rename_func("DAYOFYEAR"),
+ exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+ exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
}
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
+ def cast_sql(self, expression: exp.Cast) -> str:
+ if isinstance(expression.this, exp.Cast) and expression.this.is_type(
+ exp.DataType.Type.JSON
+ ):
+ schema = f"'{self.sql(expression, 'to')}'"
+ return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
+ if expression.to.is_type(exp.DataType.Type.JSON):
+ return f"TO_JSON({self.sql(expression, 'this')})"
+
+ return super(Spark.Generator, self).cast_sql(expression)
+
class Tokenizer(Hive.Tokenizer):
HEX_STRINGS = [("X'", "'")]
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 36c085f..31b1c8d 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -31,6 +31,5 @@ class Tableau(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
- "IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 4340820..123da04 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -76,6 +76,14 @@ class Teradata(Dialect):
)
class Generator(generator.Generator):
+ PROPERTIES_LOCATION = {
+ **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
+ }
+
+ def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
+ return f"PARTITION BY {self.sql(expression, 'this')}"
+
# FROM before SET in Teradata UPDATE syntax
# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
def update_sql(self, expression: exp.Update) -> str:
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 9f9099e..05ba53a 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -412,6 +412,8 @@ class TSQL(Dialect):
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
class Generator(generator.Generator):
+ LOCKING_READS_SUPPORTED = True
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "BIT",
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index a5373b0..7d5ec21 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -14,10 +14,6 @@ from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_collection
-if t.TYPE_CHECKING:
- T = t.TypeVar("T")
- Edit = t.Union[Insert, Remove, Move, Update, Keep]
-
@dataclass(frozen=True)
class Insert:
@@ -56,6 +52,11 @@ class Keep:
target: exp.Expression
+if t.TYPE_CHECKING:
+ T = t.TypeVar("T")
+ Edit = t.Union[Insert, Remove, Move, Update, Keep]
+
+
def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
"""
Returns the list of changes between the source and the target expressions.
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index 04621b5..67b4b00 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -1,5 +1,13 @@
+"""
+.. include:: ../../posts/python_sql_engine.md
+----
+"""
+
+from __future__ import annotations
+
import logging
import time
+import typing as t
from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
@@ -11,42 +19,63 @@ from sqlglot.schema import ensure_schema
logger = logging.getLogger("sqlglot")
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import DialectType
+ from sqlglot.executor.table import Tables
+ from sqlglot.expressions import Expression
+ from sqlglot.schema import Schema
-def execute(sql, schema=None, read=None, tables=None):
+
+def execute(
+ sql: str | Expression,
+ schema: t.Optional[t.Dict | Schema] = None,
+ read: DialectType = None,
+ tables: t.Optional[t.Dict] = None,
+) -> Table:
"""
Run a sql query against data.
Args:
- sql (str|sqlglot.Expression): a sql statement
- schema (dict|sqlglot.optimizer.Schema): database schema.
- This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
- the following forms:
- 1. {table: {col: type}}
- 2. {db: {table: {col: type}}}
- 3. {catalog: {db: {table: {col: type}}}}
- read (str): the SQL dialect to apply during parsing
- (eg. "spark", "hive", "presto", "mysql").
- tables (dict): additional tables to register.
+ sql: a sql statement.
+ schema: database schema.
+ This can either be an instance of `Schema` or a mapping in one of the following forms:
+ 1. {table: {col: type}}
+ 2. {db: {table: {col: type}}}
+ 3. {catalog: {db: {table: {col: type}}}}
+ read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+ tables: additional tables to register.
+
Returns:
- sqlglot.executor.Table: Simple columnar data structure.
+ Simple columnar data structure.
"""
- tables = ensure_tables(tables)
+ tables_ = ensure_tables(tables)
+
if not schema:
schema = {
name: {column: type(table[0][column]).__name__ for column in table.columns}
- for name, table in tables.mapping.items()
+ for name, table in tables_.mapping.items()
}
+
schema = ensure_schema(schema)
- if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
+
+ if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
raise ExecuteError("Tables must support the same table args as schema")
+
expression = maybe_parse(sql, dialect=read)
+
now = time.time()
expression = optimize(expression, schema, leave_tables_isolated=True)
+
logger.debug("Optimization finished: %f", time.time() - now)
logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
+
plan = Plan(expression)
+
logger.debug("Logical Plan: %s", plan)
+
now = time.time()
- result = PythonExecutor(tables=tables).execute(plan)
+ result = PythonExecutor(tables=tables_).execute(plan)
+
logger.debug("Query finished: %f", time.time() - now)
+
return result
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 04dc938..ba9cbbd 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -171,5 +171,6 @@ ENV = {
"STRPOSITION": str_position,
"SUB": null_if_any(lambda e, this: e - this),
"SUBSTRING": substring,
+ "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
"UPPER": null_if_any(lambda arg: arg.upper()),
}
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index f1b5b54..27e3e5e 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot.helper import dict_depth
from sqlglot.schema import AbstractMappingSchema
@@ -106,11 +108,11 @@ class Tables(AbstractMappingSchema[Table]):
pass
-def ensure_tables(d: dict | None) -> Tables:
+def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
return Tables(_ensure_tables(d))
-def _ensure_tables(d: dict | None) -> dict:
+def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
if not d:
return {}
@@ -127,4 +129,5 @@ def _ensure_tables(d: dict | None) -> dict:
columns = tuple(table[0]) if table else ()
rows = [tuple(row[c] for c in columns) for row in table]
result[name] = Table(columns=columns, rows=rows)
+
return result
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 7c1a116..6bb083a 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -32,13 +32,7 @@ from sqlglot.helper import (
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
- from sqlglot.dialects.dialect import Dialect
-
- IntoType = t.Union[
- str,
- t.Type[Expression],
- t.Collection[t.Union[str, t.Type[Expression]]],
- ]
+ from sqlglot.dialects.dialect import DialectType
class _Expression(type):
@@ -427,7 +421,7 @@ class Expression(metaclass=_Expression):
def __repr__(self):
return self._to_s()
- def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
+ def sql(self, dialect: DialectType = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
@@ -595,6 +589,14 @@ class Expression(metaclass=_Expression):
return load(obj)
+if t.TYPE_CHECKING:
+ IntoType = t.Union[
+ str,
+ t.Type[Expression],
+ t.Collection[t.Union[str, t.Type[Expression]]],
+ ]
+
+
class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
"""
@@ -1285,6 +1287,18 @@ class Property(Expression):
arg_types = {"this": True, "value": True}
+class AlgorithmProperty(Property):
+ arg_types = {"this": True}
+
+
+class DefinerProperty(Property):
+ arg_types = {"this": True}
+
+
+class SqlSecurityProperty(Property):
+ arg_types = {"definer": True}
+
+
class TableFormatProperty(Property):
arg_types = {"this": True}
@@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property):
class Properties(Expression):
- arg_types = {"expressions": True, "before": False}
+ arg_types = {"expressions": True}
NAME_TO_PROPERTY = {
+ "ALGORITHM": AlgorithmProperty,
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER SET": CharacterSetProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
+ "DEFINER": DefinerProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"ENGINE": EngineProperty,
@@ -1447,6 +1463,14 @@ class Properties(Expression):
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
+ class Location(AutoName):
+ POST_CREATE = auto()
+ PRE_SCHEMA = auto()
+ POST_INDEX = auto()
+ POST_SCHEMA_ROOT = auto()
+ POST_SCHEMA_WITH = auto()
+ UNSUPPORTED = auto()
+
@classmethod
def from_dict(cls, properties_dict) -> Properties:
expressions = []
@@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = {
"order": False,
"limit": False,
"offset": False,
+ "lock": False,
}
@@ -1713,6 +1738,12 @@ class Schema(Expression):
arg_types = {"this": False, "expressions": False}
+# Used to represent the FOR UPDATE and FOR SHARE locking read types.
+# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html
+class Lock(Expression):
+ arg_types = {"update": True}
+
+
class Select(Subqueryable):
arg_types = {
"with": False,
@@ -2243,6 +2274,30 @@ class Select(Subqueryable):
properties=properties_expression,
)
+ def lock(self, update: bool = True, copy: bool = True) -> Select:
+ """
+ Set the locking read mode for this expression.
+
+ Examples:
+ >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql")
+ "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE"
+
+ >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql")
+ "SELECT x FROM tbl WHERE x = 'a' FOR SHARE"
+
+ Args:
+ update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`.
+ copy: if `False`, modify this expression instance in-place.
+
+ Returns:
+ The modified expression.
+ """
+
+ inst = _maybe_copy(self, copy)
+ inst.set("lock", Lock(update=update))
+
+ return inst
+
@property
def named_selects(self) -> t.List[str]:
return [e.output_name for e in self.expressions if e.alias_or_name]
@@ -2456,24 +2511,28 @@ class DataType(Expression):
@classmethod
def build(
- cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
+ cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
) -> DataType:
from sqlglot import parse_one
if isinstance(dtype, str):
- data_type_exp: t.Optional[Expression]
if dtype.upper() in cls.Type.__members__:
- data_type_exp = DataType(this=DataType.Type[dtype.upper()])
+ data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
if data_type_exp is None:
raise ValueError(f"Unparsable data type value: {dtype}")
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
+ elif isinstance(dtype, DataType):
+ return dtype
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
return DataType(**{**data_type_exp.args, **kwargs})
+ def is_type(self, dtype: DataType.Type) -> bool:
+ return self.this == dtype
+
# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(Expression):
@@ -2840,6 +2899,10 @@ class Array(Func):
is_var_len_args = True
+class GenerateSeries(Func):
+ arg_types = {"start": True, "end": True, "step": False}
+
+
class ArrayAgg(AggFunc):
pass
@@ -2909,6 +2972,9 @@ class Cast(Func):
def output_name(self):
return self.name
+ def is_type(self, dtype: DataType.Type) -> bool:
+ return self.to.is_type(dtype)
+
class Collate(Binary):
pass
@@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
+class DayOfWeek(Func):
+ _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"]
+
+
+class DayOfMonth(Func):
+ _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"]
+
+
+class DayOfYear(Func):
+ _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]
+
+
+class WeekOfYear(Func):
+ _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
+
+
class LastDateOfMonth(Func):
pass
@@ -3239,7 +3321,7 @@ class ReadCSV(Func):
class Reduce(Func):
- arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
+ arg_types = {"this": True, "initial": True, "merge": True, "finish": False}
class RegexpLike(Func):
@@ -3476,7 +3558,7 @@ def maybe_parse(
sql_or_expression: str | Expression,
*,
into: t.Optional[IntoType] = None,
- dialect: t.Optional[str] = None,
+ dialect: DialectType = None,
prefix: t.Optional[str] = None,
**opts,
) -> Expression:
@@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
return identifier
+INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")
+
+
+def to_interval(interval: str | Literal) -> Interval:
+ """Builds an interval expression from a string like '1 day' or '5 months'."""
+ if isinstance(interval, Literal):
+ if not interval.is_string:
+ raise ValueError("Invalid interval string.")
+
+ interval = interval.this
+
+ interval_parts = INTERVAL_STRING_RE.match(interval) # type: ignore
+
+ if not interval_parts:
+ raise ValueError("Invalid interval string.")
+
+ return Interval(
+ this=Literal.string(interval_parts.group(1)),
+ unit=Var(this=interval_parts.group(2)),
+ )
+
+
@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table:
...
@@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
def subquery(expression, alias=None, dialect=None, **opts):
"""
Build a subquery expression.
- Expample:
+
+ Example:
>>> subquery('select x from tbl', 'bar').select('x').sql()
'SELECT x FROM (SELECT x FROM tbl) AS bar'
@@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
def column(col, table=None, quoted=None) -> Column:
"""
Build a Column.
+
Args:
col (str | Expression): column name
table (str | Expression): table name
@@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column:
)
+def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
+ """Cast an expression to a data type.
+
+ Example:
+ >>> cast('x + 1', 'int').sql()
+ 'CAST(x + 1 AS INT)'
+
+ Args:
+ expression: The expression to cast.
+ to: The datatype to cast to.
+
+ Returns:
+ A cast node.
+ """
+ expression = maybe_parse(expression, **opts)
+ return Cast(this=expression, to=DataType.build(to, **opts))
+
+
def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.
@@ -4137,7 +4261,7 @@ def values(
types = list(columns.values())
expressions[0].set(
"expressions",
- [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
+ [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
)
return Values(
expressions=expressions,
@@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
return expression.transform(_expand, copy=copy)
-def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
+def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
"""
Returns a Func expression.
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 3f3365a..b95e9bc 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -67,6 +67,7 @@ class Generator:
exp.VolatilityProperty: lambda self, e: e.name,
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
+ exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
}
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@@ -75,6 +76,9 @@ class Generator:
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
+ # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
+ LOCKING_READS_SUPPORTED = False
+
# Always do union distinct or union all
EXPLICIT_UNION = False
@@ -99,34 +103,42 @@ class Generator:
STRUCT_DELIMITER = ("<", ">")
- BEFORE_PROPERTIES = {
- exp.FallbackProperty,
- exp.WithJournalTableProperty,
- exp.LogProperty,
- exp.JournalProperty,
- exp.AfterJournalProperty,
- exp.ChecksumProperty,
- exp.FreespaceProperty,
- exp.MergeBlockRatioProperty,
- exp.DataBlocksizeProperty,
- exp.BlockCompressionProperty,
- exp.IsolatedLoadingProperty,
- }
-
- ROOT_PROPERTIES = {
- exp.ReturnsProperty,
- exp.LanguageProperty,
- exp.DistStyleProperty,
- exp.DistKeyProperty,
- exp.SortKeyProperty,
- exp.LikeProperty,
- }
-
- WITH_PROPERTIES = {
- exp.Property,
- exp.FileFormatProperty,
- exp.PartitionedByProperty,
- exp.TableFormatProperty,
+ PROPERTIES_LOCATION = {
+ exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
+ exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
+ exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+ exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.LogProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA,
+ exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+ exp.Property: exp.Properties.Location.POST_SCHEMA_WITH,
+ exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
+ exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+ exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+ exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA,
}
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
@@ -284,10 +296,10 @@ class Generator:
)
return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"
- def no_identify(self, func: t.Callable[[], str]) -> str:
+ def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str:
original = self.identify
self.identify = False
- result = func()
+ result = func(*args, **kwargs)
self.identify = original
return result
@@ -455,19 +467,33 @@ class Generator:
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
- has_before_properties = expression.args.get("properties")
- has_before_properties = (
- has_before_properties.args.get("before") if has_before_properties else None
- )
- if kind == "TABLE" and has_before_properties:
+ properties = expression.args.get("properties")
+ properties_exp = expression.copy()
+ properties_locs = self.locate_properties(properties) if properties else {}
+ if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get(
+ exp.Properties.Location.POST_SCHEMA_WITH
+ ):
+ properties_exp.set(
+ "properties",
+ exp.Properties(
+ expressions=[
+ *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT],
+ *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH],
+ ]
+ ),
+ )
+ if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA):
this_name = self.sql(expression.this, "this")
- this_properties = self.sql(expression, "properties")
+ this_properties = self.properties(
+ exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]),
+ wrapped=False,
+ )
this_schema = f"({self.expressions(expression.this)})"
this = f"{this_name}, {this_properties} {this_schema}"
- properties = ""
+ properties_sql = ""
else:
this = self.sql(expression, "this")
- properties = self.sql(expression, "properties")
+ properties_sql = self.sql(properties_exp, "properties")
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
@@ -514,11 +540,31 @@ class Generator:
if index.args.get("columns")
else ""
)
+ if index.args.get("primary") and properties_locs.get(
+ exp.Properties.Location.POST_INDEX
+ ):
+ postindex_props_sql = self.properties(
+ exp.Properties(
+ expressions=properties_locs[exp.Properties.Location.POST_INDEX]
+ ),
+ wrapped=False,
+ )
+ ind_columns = f"{ind_columns} {postindex_props_sql}"
+
indexes_sql.append(
f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
)
index_sql = "".join(indexes_sql)
+ postcreate_props_sql = ""
+ if properties_locs.get(exp.Properties.Location.POST_CREATE):
+ postcreate_props_sql = self.properties(
+ exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]),
+ sep=" ",
+ prefix=" ",
+ wrapped=False,
+ )
+
modifiers = "".join(
(
replace,
@@ -531,6 +577,7 @@ class Generator:
multiset,
global_temporary,
volatile,
+ postcreate_props_sql,
)
)
no_schema_binding = (
@@ -539,7 +586,7 @@ class Generator:
post_expression_modifiers = "".join((data, statistics, no_primary_index))
- expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
+ expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression: exp.Describe) -> str:
@@ -665,24 +712,19 @@ class Generator:
return f"PARTITION({self.expressions(expression)})"
def properties_sql(self, expression: exp.Properties) -> str:
- before_properties = []
root_properties = []
with_properties = []
for p in expression.expressions:
- p_class = p.__class__
- if p_class in self.BEFORE_PROPERTIES:
- before_properties.append(p)
- elif p_class in self.WITH_PROPERTIES:
+ p_loc = self.PROPERTIES_LOCATION[p.__class__]
+ if p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
with_properties.append(p)
- elif p_class in self.ROOT_PROPERTIES:
+ elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
root_properties.append(p)
- return (
- self.properties(exp.Properties(expressions=before_properties), before=True)
- + self.root_properties(exp.Properties(expressions=root_properties))
- + self.with_properties(exp.Properties(expressions=with_properties))
- )
+ return self.root_properties(
+ exp.Properties(expressions=root_properties)
+ ) + self.with_properties(exp.Properties(expressions=with_properties))
def root_properties(self, properties: exp.Properties) -> str:
if properties.expressions:
@@ -695,17 +737,41 @@ class Generator:
prefix: str = "",
sep: str = ", ",
suffix: str = "",
- before: bool = False,
+ wrapped: bool = True,
) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
- expressions = expressions if before else self.wrap(expressions)
+ expressions = self.wrap(expressions) if wrapped else expressions
return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
return ""
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("WITH"))
+ def locate_properties(
+ self, properties: exp.Properties
+ ) -> t.Dict[exp.Properties.Location, list[exp.Property]]:
+ properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = {
+ key: [] for key in exp.Properties.Location
+ }
+
+ for p in properties.expressions:
+ p_loc = self.PROPERTIES_LOCATION[p.__class__]
+ if p_loc == exp.Properties.Location.PRE_SCHEMA:
+ properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p)
+ elif p_loc == exp.Properties.Location.POST_INDEX:
+ properties_locs[exp.Properties.Location.POST_INDEX].append(p)
+ elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
+ properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p)
+ elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
+ properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p)
+ elif p_loc == exp.Properties.Location.POST_CREATE:
+ properties_locs[exp.Properties.Location.POST_CREATE].append(p)
+ elif p_loc == exp.Properties.Location.UNSUPPORTED:
+ self.unsupported(f"Unsupported property {p.key}")
+
+ return properties_locs
+
def property_sql(self, expression: exp.Property) -> str:
property_cls = expression.__class__
if property_cls == exp.Property:
@@ -713,7 +779,7 @@ class Generator:
property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
if not property_name:
- self.unsupported(f"Unsupported property {property_name}")
+ self.unsupported(f"Unsupported property {expression.key}")
return f"{property_name}={self.sql(expression, 'this')}"
@@ -975,7 +1041,7 @@ class Generator:
rollup = self.expressions(expression, key="rollup", indent=False)
rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
- return f"{group_by}{grouping_sets}{cube}{rollup}"
+ return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}"
def having_sql(self, expression: exp.Having) -> str:
this = self.indent(self.sql(expression, "this"))
@@ -1015,7 +1081,7 @@ class Generator:
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
- return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
+ return f"{args} {arrow_sep} {self.sql(expression, 'this')}"
def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
@@ -1043,6 +1109,14 @@ class Generator:
this = self.sql(expression, "this")
return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
+ def lock_sql(self, expression: exp.Lock) -> str:
+ if self.LOCKING_READS_SUPPORTED:
+ lock_type = "UPDATE" if expression.args["update"] else "SHARE"
+ return self.seg(f"FOR {lock_type}")
+
+ self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
+ return ""
+
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
@@ -1163,6 +1237,7 @@ class Generator:
self.sql(expression, "order"),
self.sql(expression, "limit"),
self.sql(expression, "offset"),
+ self.sql(expression, "lock"),
sep="",
)
@@ -1773,7 +1848,7 @@ class Generator:
def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this")
- expressions = self.no_identify(lambda: self.expressions(expression))
+ expressions = self.no_identify(self.expressions, expression)
expressions = (
self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
)
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 4e7eab8..a39ad8c 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -9,6 +9,9 @@ from sqlglot.optimizer import Scope, build_scope, optimize
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import DialectType
+
@dataclass(frozen=True)
class Node:
@@ -36,7 +39,7 @@ def lineage(
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
- dialect: t.Optional[str] = None,
+ dialect: DialectType = None,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
@@ -126,7 +129,7 @@ class LineageHTML:
def __init__(
self,
node: Node,
- dialect: t.Optional[str] = None,
+ dialect: DialectType = None,
imports: bool = True,
**opts: t.Any,
):
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 2245cc2..c6bea5a 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -114,7 +114,7 @@ def _eliminate_union(scope, existing_ctes, taken):
taken[alias] = scope
# Try to maintain the selections
- expressions = scope.expression.args.get("expressions")
+ expressions = scope.selects
selects = [
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
for e in expressions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 5a3ed5a..badbb87 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -300,7 +300,7 @@ class Scope:
list[exp.Expression]: expressions
"""
if isinstance(self.expression, exp.Union):
- return []
+ return self.expression.unnest().selects
return self.expression.selects
@property
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index f560760..f80484d 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -456,8 +456,10 @@ def extract_interval(interval):
def date_literal(date):
- expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
- return exp.Cast(this=exp.Literal.string(date), to=expr_type)
+ return exp.cast(
+ exp.Literal.string(date),
+ "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
+ )
def boolean_literal(condition):
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 6229105..e2b2c54 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -80,6 +80,7 @@ class Parser(metaclass=_Parser):
length=exp.Literal.number(10),
),
"VAR_MAP": parse_var_map,
+ "IFNULL": exp.Coalesce.from_arg_list,
}
NO_PAREN_FUNCTIONS = {
@@ -567,6 +568,8 @@ class Parser(metaclass=_Parser):
default=self._prev.text.upper() == "DEFAULT"
),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
+ "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
+ "DEFINER": lambda self: self._parse_definer(),
}
CONSTRAINT_PARSERS = {
@@ -608,6 +611,7 @@ class Parser(metaclass=_Parser):
"order": lambda self: self._parse_order(),
"limit": lambda self: self._parse_limit(),
"offset": lambda self: self._parse_offset(),
+ "lock": lambda self: self._parse_lock(),
}
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
@@ -850,7 +854,7 @@ class Parser(metaclass=_Parser):
self.raise_error(error_message)
def _find_sql(self, start: Token, end: Token) -> str:
- return self.sql[self._find_token(start) : self._find_token(end)]
+ return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]
def _find_token(self, token: Token) -> int:
line = 1
@@ -901,6 +905,7 @@ class Parser(metaclass=_Parser):
return expression
def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
+ start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
@@ -908,8 +913,7 @@ class Parser(metaclass=_Parser):
if default_kind:
kind = default_kind
else:
- self.raise_error(f"Expected {self.CREATABLES}")
- return None
+ return self._parse_as_command(start)
return self.expression(
exp.Drop,
@@ -929,6 +933,7 @@ class Parser(metaclass=_Parser):
)
def _parse_create(self) -> t.Optional[exp.Expression]:
+ start = self._prev
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
set_ = self._match(TokenType.SET) # Teradata
multiset = self._match_text_seq("MULTISET") # Teradata
@@ -943,16 +948,19 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
+ properties = None
create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token:
- self.raise_error(f"Expected {self.CREATABLES}")
- return None
+ properties = self._parse_properties()
+ create_token = self._match_set(self.CREATABLES) and self._prev
+
+ if not properties or not create_token:
+ return self._parse_as_command(start)
exists = self._parse_exists(not_=True)
this = None
expression = None
- properties = None
data = None
statistics = None
no_primary_index = None
@@ -1006,6 +1014,14 @@ class Parser(metaclass=_Parser):
indexes = []
while True:
index = self._parse_create_table_index()
+
+ # post index PARTITION BY property
+ if self._match(TokenType.PARTITION_BY, advance=False):
+ if properties:
+ properties.expressions.append(self._parse_property())
+ else:
+ properties = self._parse_properties()
+
if not index:
break
else:
@@ -1040,6 +1056,9 @@ class Parser(metaclass=_Parser):
)
def _parse_property_before(self) -> t.Optional[exp.Expression]:
+ self._match(TokenType.COMMA)
+
+ # parsers look to _prev for no/dual/default, so need to consume first
self._match_text_seq("NO")
self._match_text_seq("DUAL")
self._match_text_seq("DEFAULT")
@@ -1059,6 +1078,9 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
return self._parse_sortkey(compound=True)
+ if self._match_text_seq("SQL", "SECURITY"):
+ return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER"))
+
assignment = self._match_pair(
TokenType.VAR, TokenType.EQ, advance=False
) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
@@ -1083,7 +1105,6 @@ class Parser(metaclass=_Parser):
while True:
if before:
- self._match(TokenType.COMMA)
identified_property = self._parse_property_before()
else:
identified_property = self._parse_property()
@@ -1094,7 +1115,7 @@ class Parser(metaclass=_Parser):
properties.append(p)
if properties:
- return self.expression(exp.Properties, expressions=properties, before=before)
+ return self.expression(exp.Properties, expressions=properties)
return None
@@ -1118,6 +1139,19 @@ class Parser(metaclass=_Parser):
return self._parse_withisolatedloading()
+ # https://dev.mysql.com/doc/refman/8.0/en/create-view.html
+ def _parse_definer(self) -> t.Optional[exp.Expression]:
+ self._match(TokenType.EQ)
+
+ user = self._parse_id_var()
+ self._match(TokenType.PARAMETER)
+ host = self._parse_id_var() or (self._match(TokenType.MOD) and self._prev.text)
+
+ if not user or not host:
+ return None
+
+ return exp.DefinerProperty(this=f"{user}@{host}")
+
def _parse_withjournaltable(self) -> exp.Expression:
self._match_text_seq("WITH", "JOURNAL", "TABLE")
self._match(TokenType.EQ)
@@ -1695,12 +1729,10 @@ class Parser(metaclass=_Parser):
paren += 1
if self._curr.token_type == TokenType.R_PAREN:
paren -= 1
+ end = self._prev
self._advance()
if paren > 0:
self.raise_error("Expecting )", self._curr)
- if not self._curr:
- self.raise_error("Expecting pattern", self._curr)
- end = self._prev
pattern = exp.Var(this=self._find_sql(start, end))
else:
pattern = None
@@ -2044,9 +2076,16 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(self._parse_conjunction)
grouping_sets = self._parse_grouping_sets()
+ self._match(TokenType.COMMA)
with_ = self._match(TokenType.WITH)
- cube = self._match(TokenType.CUBE) and (with_ or self._parse_wrapped_id_vars())
- rollup = self._match(TokenType.ROLLUP) and (with_ or self._parse_wrapped_id_vars())
+ cube = self._match(TokenType.CUBE) and (
+ with_ or self._parse_wrapped_csv(self._parse_column)
+ )
+
+ self._match(TokenType.COMMA)
+ rollup = self._match(TokenType.ROLLUP) and (
+ with_ or self._parse_wrapped_csv(self._parse_column)
+ )
return self.expression(
exp.Group,
@@ -2149,6 +2188,14 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
+ def _parse_lock(self) -> t.Optional[exp.Expression]:
+ if self._match_text_seq("FOR", "UPDATE"):
+ return self.expression(exp.Lock, update=True)
+ if self._match_text_seq("FOR", "SHARE"):
+ return self.expression(exp.Lock, update=False)
+
+ return None
+
def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set(self.SET_OPERATIONS):
return this
@@ -2330,12 +2377,21 @@ class Parser(metaclass=_Parser):
maybe_func = True
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
- return exp.DataType(
+ this = exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
nested=True,
)
+ while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+ this = exp.DataType(
+ this=exp.DataType.Type.ARRAY,
+ expressions=[this],
+ nested=True,
+ )
+
+ return this
+
if self._match(TokenType.L_BRACKET):
self._retreat(index)
return None
@@ -2430,7 +2486,12 @@ class Parser(metaclass=_Parser):
self.raise_error("Expected type")
elif op:
self._advance()
- field = exp.Literal.string(self._prev.text)
+ value = self._prev.text
+ field = (
+ exp.Literal.number(value)
+ if self._prev.token_type == TokenType.NUMBER
+ else exp.Literal.string(value)
+ )
else:
field = self._parse_star() or self._parse_function() or self._parse_id_var()
@@ -2752,7 +2813,23 @@ class Parser(metaclass=_Parser):
if not self._curr:
break
- if self._match_text_seq("NOT", "ENFORCED"):
+ if self._match(TokenType.ON):
+ action = None
+ on = self._advance_any() and self._prev.text
+
+ if self._match(TokenType.NO_ACTION):
+ action = "NO ACTION"
+ elif self._match(TokenType.CASCADE):
+ action = "CASCADE"
+ elif self._match_pair(TokenType.SET, TokenType.NULL):
+ action = "SET NULL"
+ elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
+ action = "SET DEFAULT"
+ else:
+ self.raise_error("Invalid key constraint")
+
+ options.append(f"ON {on} {action}")
+ elif self._match_text_seq("NOT", "ENFORCED"):
options.append("NOT ENFORCED")
elif self._match_text_seq("DEFERRABLE"):
options.append("DEFERRABLE")
@@ -2762,10 +2839,6 @@ class Parser(metaclass=_Parser):
options.append("NORELY")
elif self._match_text_seq("MATCH", "FULL"):
options.append("MATCH FULL")
- elif self._match_text_seq("ON", "UPDATE", "NO ACTION"):
- options.append("ON UPDATE NO ACTION")
- elif self._match_text_seq("ON", "DELETE", "NO ACTION"):
- options.append("ON DELETE NO ACTION")
else:
break
@@ -3158,7 +3231,9 @@ class Parser(metaclass=_Parser):
prefix += self._prev.text
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
- return exp.Identifier(this=prefix + self._prev.text, quoted=False)
+ quoted = self._prev.token_type == TokenType.STRING
+ return exp.Identifier(this=prefix + self._prev.text, quoted=quoted)
+
return None
def _parse_string(self) -> t.Optional[exp.Expression]:
@@ -3486,6 +3561,11 @@ class Parser(metaclass=_Parser):
def _parse_set(self) -> exp.Expression:
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
+ def _parse_as_command(self, start: Token) -> exp.Command:
+ while self._curr:
+ self._advance()
+ return exp.Command(this=self._find_sql(start, self._prev))
+
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]:
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f6f3883..f5d9f2b 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -11,6 +11,7 @@ from sqlglot.trie import in_trie, new_trie
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType
+ from sqlglot.dialects.dialect import DialectType
ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
@@ -153,7 +154,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
self,
schema: t.Optional[t.Dict] = None,
visible: t.Optional[t.Dict] = None,
- dialect: t.Optional[str] = None,
+ dialect: DialectType = None,
) -> None:
self.dialect = dialect
self.visible = visible or {}
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 8bdd338..e95057a 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -665,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer):
"STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT,
+ "LONGVARCHAR": TokenType.TEXT,
"BINARY": TokenType.BINARY,
"BLOB": TokenType.VARBINARY,
"BYTEA": TokenType.VARBINARY,