From a8b22b4c5bdf9139a187c92b7b9f81bdeaa84888 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 27 Feb 2023 11:46:36 +0100 Subject: Merging upstream version 11.2.3. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 55 +++++---- sqlglot/dataframe/sql/functions.py | 3 +- sqlglot/dialects/bigquery.py | 6 +- sqlglot/dialects/mysql.py | 5 + sqlglot/dialects/postgres.py | 5 + sqlglot/dialects/snowflake.py | 10 +- sqlglot/dialects/spark.py | 7 +- sqlglot/dialects/teradata.py | 1 + sqlglot/diff.py | 55 +++++++-- sqlglot/expressions.py | 227 +++++++++++++++++++++++-------------- sqlglot/generator.py | 102 ++++++++--------- sqlglot/parser.py | 212 +++++++++++++++++++--------------- sqlglot/serde.py | 13 ++- sqlglot/tokens.py | 53 +++++---- 14 files changed, 450 insertions(+), 304 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 7bcaa22..87b36b0 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -9,38 +9,45 @@ from __future__ import annotations import typing as t from sqlglot import expressions as exp -from sqlglot.dialects import Dialect, Dialects -from sqlglot.diff import diff -from sqlglot.errors import ErrorLevel, ParseError, TokenError, UnsupportedError -from sqlglot.expressions import Expression -from sqlglot.expressions import alias_ as alias +from sqlglot.dialects.dialect import Dialect as Dialect, Dialects as Dialects +from sqlglot.diff import diff as diff +from sqlglot.errors import ( + ErrorLevel as ErrorLevel, + ParseError as ParseError, + TokenError as TokenError, + UnsupportedError as UnsupportedError, +) from sqlglot.expressions import ( - and_, - column, - condition, - except_, - from_, - intersect, - maybe_parse, - not_, - or_, - select, - subquery, + Expression as Expression, + alias_ as alias, + and_ as and_, + column as column, + condition as condition, + except_ as except_, + from_ as from_, + intersect as intersect, + maybe_parse as maybe_parse, + not_ as not_, + or_ as or_, + select as select, + subquery as subquery, + table_ as table, + to_column as to_column, + to_table as to_table, + union as union, ) -from sqlglot.expressions import table_ as table -from sqlglot.expressions import to_column, to_table, union -from sqlglot.generator import Generator -from sqlglot.parser import Parser -from sqlglot.schema import MappingSchema, Schema -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.generator import Generator as Generator +from sqlglot.parser import Parser as Parser +from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema +from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType + from sqlglot.dialects.dialect import DialectType as DialectType T = t.TypeVar("T", bound=Expression) -__version__ = "11.2.0" +__version__ = "11.2.3" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 0262d54..8f24746 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -4,8 +4,7 @@ import typing as t from sqlglot import exp as expression from sqlglot.dataframe.sql.column import Column -from sqlglot.helper import ensure_list -from sqlglot.helper import flatten as _flatten +from sqlglot.helper import ensure_list, flatten as _flatten if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName diff --git a/sqlglot/dialects/bigquery.py 
b/sqlglot/dialects/bigquery.py index a75e802..32b5147 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -38,7 +38,10 @@ def _date_add_sql( ) -> t.Callable[[generator.Generator, exp.Expression], str]: def func(self, expression): this = self.sql(expression, "this") - return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})" + unit = expression.args.get("unit") + unit = exp.var(unit.name.upper() if unit else "DAY") + interval = exp.Interval(this=expression.expression, unit=unit) + return f"{data_type}_{kind}({this}, {self.sql(interval)})" return func @@ -235,6 +238,7 @@ class BigQuery(Dialect): exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), + exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.VariancePop: rename_func("VAR_POP"), exp.Values: _derived_table_values_to_unnest, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 235eb77..836bf3c 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -462,6 +462,11 @@ class MySQL(Dialect): TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB) TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB) + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, + } + def show_sql(self, expression): this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 7612330..3507cb5 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -318,3 +318,8 @@ class Postgres(Dialect): if isinstance(seq_get(e.expressions, 0), exp.Select) else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]", } + + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, + } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 9342865..5931364 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -150,6 +150,10 @@ class Snowflake(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, + "DATE_TRUNC": lambda args: exp.DateTrunc( + unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore + this=seq_get(args, 1), + ), "IFF": exp.If.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, @@ -215,7 +219,6 @@ class Snowflake(Dialect): } class Generator(generator.Generator): - CREATE_TRANSIENT = True PARAMETER_TOKEN = "$" TRANSFORMS = { @@ -252,6 +255,11 @@ class Snowflake(Dialect): "replace": "RENAME", } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.SetProperty: exp.Properties.Location.UNSUPPORTED, + } + def ilikeany_sql(self, expression: exp.ILikeAny) -> str: return self.binary(expression, "ILIKE ANY") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index dd3e0c8..05ee53f 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -8,9 +8,12 @@ from sqlglot.helper import seq_get def _create_sql(self, e): kind = e.args.get("kind") - temporary = e.args.get("temporary") + properties = e.args.get("properties") - if kind.upper() == "TABLE" and 
temporary is True: + if kind.upper() == "TABLE" and any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ): return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" return create_with_partitions_sql(self, e) diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index e3eec71..7953bc5 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -114,6 +114,7 @@ class Teradata(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, + exp.VolatilityProperty: exp.Properties.Location.POST_CREATE, } def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 7530613..dddb9ad 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -11,8 +11,7 @@ from collections import defaultdict from dataclasses import dataclass from heapq import heappop, heappush -from sqlglot import Dialect -from sqlglot import expressions as exp +from sqlglot import Dialect, expressions as exp from sqlglot.helper import ensure_collection @@ -58,7 +57,12 @@ if t.TYPE_CHECKING: Edit = t.Union[Insert, Remove, Move, Update, Keep] -def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]: +def diff( + source: exp.Expression, + target: exp.Expression, + matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, + **kwargs: t.Any, +) -> t.List[Edit]: """ Returns the list of changes between the source and the target expressions. @@ -80,13 +84,38 @@ def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]: Args: source: the source expression. target: the target expression against which the diff should be calculated. + matchings: the list of pre-matched node pairs which is used to help the algorithm's + heuristics produce better results for subtrees that are known by a caller to be matching. + Note: expression references in this list must refer to the same node objects that are + referenced in source / target trees. Returns: the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees. This list represents a sequence of steps needed to transform the source expression tree into the target one. 
""" - return ChangeDistiller().diff(source.copy(), target.copy()) + matchings = matchings or [] + matching_ids = {id(n) for pair in matchings for n in pair} + + def compute_node_mappings( + original: exp.Expression, copy: exp.Expression + ) -> t.Dict[int, exp.Expression]: + return { + id(old_node): new_node + for (old_node, _, _), (new_node, _, _) in zip(original.walk(), copy.walk()) + if id(old_node) in matching_ids + } + + source_copy = source.copy() + target_copy = target.copy() + + node_mappings = { + **compute_node_mappings(source, source_copy), + **compute_node_mappings(target, target_copy), + } + matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings] + + return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy) LEAF_EXPRESSION_TYPES = ( @@ -109,16 +138,26 @@ class ChangeDistiller: self.t = t self._sql_generator = Dialect().generator() - def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]: + def diff( + self, + source: exp.Expression, + target: exp.Expression, + matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, + ) -> t.List[Edit]: + matchings = matchings or [] + pre_matched_nodes = {id(s): id(t) for s, t in matchings} + if len({n for pair in pre_matched_nodes.items() for n in pair}) != 2 * len(matchings): + raise ValueError("Each node can be referenced at most once in the list of matchings") + self._source = source self._target = target self._source_index = {id(n[0]): n[0] for n in source.bfs()} self._target_index = {id(n[0]): n[0] for n in target.bfs()} - self._unmatched_source_nodes = set(self._source_index) - self._unmatched_target_nodes = set(self._target_index) + self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes) + self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values()) self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} - matching_set = self._compute_matching_set() + matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()} return self._generate_edit_script(matching_set) def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]: diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index a29aeb4..59881d6 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -82,7 +82,7 @@ class Expression(metaclass=_Expression): key = "expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "comments", "_type") + __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta") def __init__(self, **args: t.Any): self.args: t.Dict[str, t.Any] = args @@ -90,6 +90,7 @@ class Expression(metaclass=_Expression): self.arg_key: t.Optional[str] = None self.comments: t.Optional[t.List[str]] = None self._type: t.Optional[DataType] = None + self._meta: t.Optional[t.Dict[str, t.Any]] = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) @@ -219,10 +220,23 @@ class Expression(metaclass=_Expression): dtype = DataType.build(dtype) self._type = dtype # type: ignore + @property + def meta(self) -> t.Dict[str, t.Any]: + if self._meta is None: + self._meta = {} + return self._meta + def __deepcopy__(self, memo): copy = self.__class__(**deepcopy(self.args)) - copy.comments = self.comments - copy.type = self.type + if self.comments is not None: + copy.comments = deepcopy(self.comments) + + if self._type is not None: + copy._type = self._type.copy() + + if self._meta is not None: + 
copy._meta = deepcopy(self._meta) + return copy def copy(self): @@ -329,6 +343,15 @@ class Expression(metaclass=_Expression): """ return self.find_ancestor(Select) + def root(self) -> Expression: + """ + Returns the root expression of this tree. + """ + expression = self + while expression.parent: + expression = expression.parent + return expression + def walk(self, bfs=True, prune=None): """ Returns a generator object which visits all nodes in this tree. @@ -767,21 +790,10 @@ class Create(Expression): "this": True, "kind": True, "expression": False, - "set": False, - "multiset": False, - "global_temporary": False, - "volatile": False, "exists": False, "properties": False, - "temporary": False, - "transient": False, - "external": False, "replace": False, "unique": False, - "materialized": False, - "data": False, - "statistics": False, - "no_primary_index": False, "indexes": False, "no_schema_binding": False, "begin": False, @@ -1336,47 +1348,47 @@ class Property(Expression): arg_types = {"this": True, "value": True} -class AlgorithmProperty(Property): - arg_types = {"this": True} +class AfterJournalProperty(Property): + arg_types = {"no": True, "dual": False, "local": False} -class DefinerProperty(Property): +class AlgorithmProperty(Property): arg_types = {"this": True} -class SqlSecurityProperty(Property): - arg_types = {"definer": True} +class AutoIncrementProperty(Property): + arg_types = {"this": True} -class TableFormatProperty(Property): - arg_types = {"this": True} +class BlockCompressionProperty(Property): + arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True} -class PartitionedByProperty(Property): - arg_types = {"this": True} +class CharacterSetProperty(Property): + arg_types = {"this": True, "default": True} -class FileFormatProperty(Property): - arg_types = {"this": True} +class ChecksumProperty(Property): + arg_types = {"on": False, "default": False} -class DistKeyProperty(Property): +class CollateProperty(Property): arg_types = {"this": True} -class SortKeyProperty(Property): - arg_types = {"this": True, "compound": False} +class DataBlocksizeProperty(Property): + arg_types = {"size": False, "units": False, "min": False, "default": False} -class DistStyleProperty(Property): +class DefinerProperty(Property): arg_types = {"this": True} -class LikeProperty(Property): - arg_types = {"this": True, "expressions": False} +class DistKeyProperty(Property): + arg_types = {"this": True} -class LocationProperty(Property): +class DistStyleProperty(Property): arg_types = {"this": True} @@ -1384,38 +1396,90 @@ class EngineProperty(Property): arg_types = {"this": True} -class AutoIncrementProperty(Property): +class ExecuteAsProperty(Property): arg_types = {"this": True} -class CharacterSetProperty(Property): - arg_types = {"this": True, "default": True} +class ExternalProperty(Property): + arg_types = {"this": False} -class CollateProperty(Property): - arg_types = {"this": True} +class FallbackProperty(Property): + arg_types = {"no": True, "protection": False} -class SchemaCommentProperty(Property): +class FileFormatProperty(Property): arg_types = {"this": True} -class ReturnsProperty(Property): - arg_types = {"this": True, "is_table": False, "table": False} +class FreespaceProperty(Property): + arg_types = {"this": True, "percent": False} + + +class IsolatedLoadingProperty(Property): + arg_types = { + "no": True, + "concurrent": True, + "for_all": True, + "for_insert": True, + "for_none": True, + } + + +class JournalProperty(Property): + arg_types 
= {"no": True, "dual": False, "before": False} class LanguageProperty(Property): arg_types = {"this": True} -class ExecuteAsProperty(Property): +class LikeProperty(Property): + arg_types = {"this": True, "expressions": False} + + +class LocationProperty(Property): arg_types = {"this": True} -class VolatilityProperty(Property): +class LockingProperty(Property): + arg_types = { + "this": False, + "kind": True, + "for_or_in": True, + "lock_type": True, + "override": False, + } + + +class LogProperty(Property): + arg_types = {"no": True} + + +class MaterializedProperty(Property): + arg_types = {"this": False} + + +class MergeBlockRatioProperty(Property): + arg_types = {"this": False, "no": False, "default": False, "percent": False} + + +class NoPrimaryIndexProperty(Property): + arg_types = {"this": False} + + +class OnCommitProperty(Property): + arg_type = {"this": False} + + +class PartitionedByProperty(Property): arg_types = {"this": True} +class ReturnsProperty(Property): + arg_types = {"this": True, "is_table": False, "table": False} + + class RowFormatDelimitedProperty(Property): # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml arg_types = { @@ -1433,68 +1497,48 @@ class RowFormatSerdeProperty(Property): arg_types = {"this": True} -class SerdeProperties(Property): - arg_types = {"expressions": True} - - -class FallbackProperty(Property): - arg_types = {"no": True, "protection": False} - - -class WithJournalTableProperty(Property): +class SchemaCommentProperty(Property): arg_types = {"this": True} -class LogProperty(Property): - arg_types = {"no": True} +class SerdeProperties(Property): + arg_types = {"expressions": True} -class JournalProperty(Property): - arg_types = {"no": True, "dual": False, "before": False} +class SetProperty(Property): + arg_types = {"multi": True} -class AfterJournalProperty(Property): - arg_types = {"no": True, "dual": False, "local": False} +class SortKeyProperty(Property): + arg_types = {"this": True, "compound": False} -class ChecksumProperty(Property): - arg_types = {"on": False, "default": False} +class SqlSecurityProperty(Property): + arg_types = {"definer": True} -class FreespaceProperty(Property): - arg_types = {"this": True, "percent": False} +class TableFormatProperty(Property): + arg_types = {"this": True} -class MergeBlockRatioProperty(Property): - arg_types = {"this": False, "no": False, "default": False, "percent": False} +class TemporaryProperty(Property): + arg_types = {"global_": True} -class DataBlocksizeProperty(Property): - arg_types = {"size": False, "units": False, "min": False, "default": False} +class TransientProperty(Property): + arg_types = {"this": False} -class BlockCompressionProperty(Property): - arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True} +class VolatilityProperty(Property): + arg_types = {"this": True} -class IsolatedLoadingProperty(Property): - arg_types = { - "no": True, - "concurrent": True, - "for_all": True, - "for_insert": True, - "for_none": True, - } +class WithDataProperty(Property): + arg_types = {"no": True, "statistics": False} -class LockingProperty(Property): - arg_types = { - "this": False, - "kind": True, - "for_or_in": True, - "lock_type": True, - "override": False, - } +class WithJournalTableProperty(Property): + arg_types = {"this": True} class Properties(Expression): @@ -1533,7 +1577,7 @@ class Properties(Expression): # Form: alias selection # create [POST_CREATE] # table a [POST_NAME] - # as [POST_ALIAS] (select * from b) + # as 
[POST_ALIAS] (select * from b) [POST_EXPRESSION] # index (c) [POST_INDEX] class Location(AutoName): POST_CREATE = auto() @@ -1541,6 +1585,7 @@ class Properties(Expression): POST_SCHEMA = auto() POST_WITH = auto() POST_ALIAS = auto() + POST_EXPRESSION = auto() POST_INDEX = auto() UNSUPPORTED = auto() @@ -1797,6 +1842,10 @@ class Union(Subqueryable): def named_selects(self): return self.this.unnest().named_selects + @property + def is_star(self) -> bool: + return self.this.is_star or self.expression.is_star + @property def selects(self): return self.this.unnest().selects @@ -2424,6 +2473,10 @@ class Select(Subqueryable): def named_selects(self) -> t.List[str]: return [e.output_name for e in self.expressions if e.alias_or_name] + @property + def is_star(self) -> bool: + return any(expression.is_star for expression in self.expressions) + @property def selects(self) -> t.List[Expression]: return self.expressions @@ -2446,6 +2499,10 @@ class Subquery(DerivedTable, Unionable): expression = expression.this return expression + @property + def is_star(self) -> bool: + return self.this.is_star + @property def output_name(self): return self.alias @@ -2478,6 +2535,7 @@ class Tag(Expression): class Pivot(Expression): arg_types = { "this": False, + "alias": False, "expressions": True, "field": True, "unpivot": True, @@ -2603,6 +2661,7 @@ class DataType(Expression): IMAGE = auto() VARIANT = auto() OBJECT = auto() + INET = auto() NULL = auto() UNKNOWN = auto() # Sentinel value, useful for type annotation diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 18ae42a..0a7a81f 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -64,15 +64,22 @@ class Generator: "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit") ), exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), - exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}", + exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", + exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), + exp.ExternalProperty: lambda self, e: "EXTERNAL", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), + exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", + exp.MaterializedProperty: lambda self, e: "MATERIALIZED", + exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX", + exp.OnCommitProperty: lambda self, e: "ON COMMIT PRESERVE ROWS", exp.ReturnsProperty: lambda self, e: self.naked_property(e), - exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), + exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", + exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", + exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY", + exp.TransientProperty: lambda self, e: "TRANSIENT", exp.VolatilityProperty: lambda self, e: e.name, exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", - exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", - exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", exp.CharacterSetColumnConstraint: lambda self, e: 
f"CHARACTER SET {self.sql(e, 'this')}", exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", @@ -87,9 +94,6 @@ class Generator: exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", } - # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed - CREATE_TRANSIENT = False - # Whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True @@ -112,6 +116,7 @@ class Generator: exp.DataType.Type.LONGTEXT: "TEXT", exp.DataType.Type.MEDIUMBLOB: "BLOB", exp.DataType.Type.LONGBLOB: "BLOB", + exp.DataType.Type.INET: "INET", } STAR_MAPPING = { @@ -140,6 +145,7 @@ class Generator: exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA, exp.EngineProperty: exp.Properties.Location.POST_SCHEMA, exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA, + exp.ExternalProperty: exp.Properties.Location.POST_CREATE, exp.FallbackProperty: exp.Properties.Location.POST_NAME, exp.FileFormatProperty: exp.Properties.Location.POST_WITH, exp.FreespaceProperty: exp.Properties.Location.POST_NAME, @@ -150,7 +156,10 @@ class Generator: exp.LocationProperty: exp.Properties.Location.POST_SCHEMA, exp.LockingProperty: exp.Properties.Location.POST_ALIAS, exp.LogProperty: exp.Properties.Location.POST_NAME, + exp.MaterializedProperty: exp.Properties.Location.POST_CREATE, exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME, + exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION, + exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION, exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, exp.Property: exp.Properties.Location.POST_WITH, exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, @@ -158,10 +167,14 @@ class Generator: exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA, + exp.SetProperty: exp.Properties.Location.POST_CREATE, exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, exp.TableFormatProperty: exp.Properties.Location.POST_WITH, + exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, + exp.TransientProperty: exp.Properties.Location.POST_CREATE, exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA, + exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, } @@ -537,34 +550,9 @@ class Generator: else: expression_sql = f" AS{expression_sql}" - temporary = " TEMPORARY" if expression.args.get("temporary") else "" - transient = ( - " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" - ) - external = " EXTERNAL" if expression.args.get("external") else "" replace = " OR REPLACE" if expression.args.get("replace") else "" - exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" - materialized = " MATERIALIZED" if expression.args.get("materialized") else "" - set_ = " SET" if expression.args.get("set") else "" - multiset = " MULTISET" if expression.args.get("multiset") else "" - global_temporary = " GLOBAL TEMPORARY" if expression.args.get("global_temporary") else "" - volatile = " VOLATILE" if expression.args.get("volatile") else "" - data = expression.args.get("data") - if data is None: - data = "" - elif data: - data = " WITH DATA" - else: - data = " WITH NO DATA" - statistics = 
expression.args.get("statistics") - if statistics is None: - statistics = "" - elif statistics: - statistics = " AND STATISTICS" - else: - statistics = " AND NO STATISTICS" - no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else "" + exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" indexes = expression.args.get("indexes") index_sql = "" @@ -605,28 +593,24 @@ class Generator: wrapped=False, ) - modifiers = "".join( - ( - replace, - temporary, - transient, - external, - unique, - materialized, - set_, - multiset, - global_temporary, - volatile, - postcreate_props_sql, + modifiers = "".join((replace, unique, postcreate_props_sql)) + + postexpression_props_sql = "" + if properties_locs.get(exp.Properties.Location.POST_EXPRESSION): + postexpression_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[exp.Properties.Location.POST_EXPRESSION] + ), + sep=" ", + prefix=" ", + wrapped=False, ) - ) + no_schema_binding = ( " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else "" ) - post_expression_modifiers = "".join((data, statistics, no_primary_index)) - - expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}" + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}" return self.prepend_ctes(expression, expression_sql) def describe_sql(self, expression: exp.Describe) -> str: @@ -810,6 +794,8 @@ class Generator: properties_locs[exp.Properties.Location.POST_CREATE].append(p) elif p_loc == exp.Properties.Location.POST_ALIAS: properties_locs[exp.Properties.Location.POST_ALIAS].append(p) + elif p_loc == exp.Properties.Location.POST_EXPRESSION: + properties_locs[exp.Properties.Location.POST_EXPRESSION].append(p) elif p_loc == exp.Properties.Location.UNSUPPORTED: self.unsupported(f"Unsupported property {p.key}") @@ -931,6 +917,14 @@ class Generator: override = " OVERRIDE" if expression.args.get("override") else "" return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}" + def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str: + data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA" + statistics = expression.args.get("statistics") + statistics_sql = "" + if statistics is not None: + statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS" + return f"{data_sql}{statistics_sql}" + def insert_sql(self, expression: exp.Insert) -> str: overwrite = expression.args.get("overwrite") @@ -1003,10 +997,6 @@ class Generator: system_time = expression.args.get("system_time") system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" - if alias and pivots: - pivots = f"{pivots}{alias}" - alias = "" - return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}" def tablesample_sql(self, expression: exp.TableSample) -> str: @@ -1034,11 +1024,13 @@ class Generator: def pivot_sql(self, expression: exp.Pivot) -> str: this = self.sql(expression, "this") + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" unpivot = expression.args.get("unpivot") direction = "UNPIVOT" if unpivot else "PIVOT" expressions = self.expressions(expression, key="expressions") field = self.sql(expression, "field") - return f"{this} {direction}({expressions} FOR {field})" + return f"{this} {direction}({expressions} FOR {field}){alias}" def tuple_sql(self, 
expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f92f5ac..9f32765 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -144,6 +144,7 @@ class Parser(metaclass=_Parser): TokenType.IMAGE, TokenType.VARIANT, TokenType.OBJECT, + TokenType.INET, *NESTED_TYPE_TOKENS, } @@ -509,73 +510,82 @@ class Parser(metaclass=_Parser): } PROPERTY_PARSERS = { + "AFTER": lambda self: self._parse_afterjournal( + no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" + ), + "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty), + "BEFORE": lambda self: self._parse_journal( + no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" + ), + "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), "CHARACTER SET": lambda self: self._parse_character_set(), + "CHECKSUM": lambda self: self._parse_checksum(), "CLUSTER BY": lambda self: self.expression( exp.Cluster, expressions=self._parse_csv(self._parse_ordered) ), - "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty), - "PARTITION BY": lambda self: self._parse_partitioned_by(), - "PARTITIONED BY": lambda self: self._parse_partitioned_by(), - "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), - "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), - "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty), - "DISTKEY": lambda self: self._parse_distkey(), - "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty), - "SORTKEY": lambda self: self._parse_sortkey(), - "LIKE": lambda self: self._parse_create_like(), - "RETURNS": lambda self: self._parse_returns(), - "ROW": lambda self: self._parse_row(), "COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty), - "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty), - "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty), - "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty), - "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), - "EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty), + "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), + "DATABLOCKSIZE": lambda self: self._parse_datablocksize( + default=self._prev.text.upper() == "DEFAULT" + ), + "DEFINER": lambda self: self._parse_definer(), "DETERMINISTIC": lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") ), + "DISTKEY": lambda self: self._parse_distkey(), + "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty), + "EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty), + "EXTERNAL": lambda self: self.expression(exp.ExternalProperty), + "FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"), + "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty), + "FREESPACE": lambda self: self._parse_freespace(), + "GLOBAL": lambda self: self._parse_temporary(global_=True), "IMMUTABLE": lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") ), - "STABLE": lambda self: self.expression( - exp.VolatilityProperty, this=exp.Literal.string("STABLE") - ), - 
"VOLATILE": lambda self: self.expression( - exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") - ), - "WITH": lambda self: self._parse_with_property(), - "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), - "FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"), - "LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"), - "BEFORE": lambda self: self._parse_journal( - no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" - ), "JOURNAL": lambda self: self._parse_journal( no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" ), - "AFTER": lambda self: self._parse_afterjournal( - no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" - ), + "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), + "LIKE": lambda self: self._parse_create_like(), "LOCAL": lambda self: self._parse_afterjournal(no=False, dual=False, local=True), - "NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False), - "CHECKSUM": lambda self: self._parse_checksum(), - "FREESPACE": lambda self: self._parse_freespace(), + "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty), + "LOCK": lambda self: self._parse_locking(), + "LOCKING": lambda self: self._parse_locking(), + "LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"), + "MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty), + "MAX": lambda self: self._parse_datablocksize(), + "MAXIMUM": lambda self: self._parse_datablocksize(), "MERGEBLOCKRATIO": lambda self: self._parse_mergeblockratio( no=self._prev.text.upper() == "NO", default=self._prev.text.upper() == "DEFAULT" ), "MIN": lambda self: self._parse_datablocksize(), "MINIMUM": lambda self: self._parse_datablocksize(), - "MAX": lambda self: self._parse_datablocksize(), - "MAXIMUM": lambda self: self._parse_datablocksize(), - "DATABLOCKSIZE": lambda self: self._parse_datablocksize( - default=self._prev.text.upper() == "DEFAULT" + "MULTISET": lambda self: self.expression(exp.SetProperty, multi=True), + "NO": lambda self: self._parse_noprimaryindex(), + "NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False), + "ON": lambda self: self._parse_oncommit(), + "PARTITION BY": lambda self: self._parse_partitioned_by(), + "PARTITIONED BY": lambda self: self._parse_partitioned_by(), + "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), + "RETURNS": lambda self: self._parse_returns(), + "ROW": lambda self: self._parse_row(), + "SET": lambda self: self.expression(exp.SetProperty, multi=False), + "SORTKEY": lambda self: self._parse_sortkey(), + "STABLE": lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("STABLE") ), - "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), - "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), - "DEFINER": lambda self: self._parse_definer(), - "LOCK": lambda self: self._parse_locking(), - "LOCKING": lambda self: self._parse_locking(), + "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty), + "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty), + "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), + "TEMPORARY": lambda self: self._parse_temporary(global_=False), + "TRANSIENT": lambda self: self.expression(exp.TransientProperty), + "USING": lambda self: 
self._parse_property_assignment(exp.TableFormatProperty), + "VOLATILE": lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") + ), + "WITH": lambda self: self._parse_with_property(), } CONSTRAINT_PARSERS = { @@ -979,15 +989,7 @@ class Parser(metaclass=_Parser): replace = self._prev.text.upper() == "REPLACE" or self._match_pair( TokenType.OR, TokenType.REPLACE ) - set_ = self._match(TokenType.SET) # Teradata - multiset = self._match_text_seq("MULTISET") # Teradata - global_temporary = self._match_text_seq("GLOBAL", "TEMPORARY") # Teradata - volatile = self._match(TokenType.VOLATILE) # Teradata - temporary = self._match(TokenType.TEMPORARY) - transient = self._match_text_seq("TRANSIENT") - external = self._match_text_seq("EXTERNAL") unique = self._match(TokenType.UNIQUE) - materialized = self._match(TokenType.MATERIALIZED) if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): self._match(TokenType.TABLE) @@ -1005,16 +1007,17 @@ class Parser(metaclass=_Parser): exists = self._parse_exists(not_=True) this = None expression = None - data = None - statistics = None - no_primary_index = None indexes = None no_schema_binding = None begin = None if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function(kind=create_token.token_type) - properties = self._parse_properties() + temp_properties = self._parse_properties() + if properties and temp_properties: + properties.expressions.extend(temp_properties.expressions) + elif temp_properties: + properties = temp_properties self._match(TokenType.ALIAS) begin = self._match(TokenType.BEGIN) @@ -1036,7 +1039,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.COMMA): temp_properties = self._parse_properties(before=True) if properties and temp_properties: - properties.expressions.append(temp_properties.expressions) + properties.expressions.extend(temp_properties.expressions) elif temp_properties: properties = temp_properties @@ -1045,7 +1048,7 @@ class Parser(metaclass=_Parser): # exp.Properties.Location.POST_SCHEMA and POST_WITH temp_properties = self._parse_properties() if properties and temp_properties: - properties.expressions.append(temp_properties.expressions) + properties.expressions.extend(temp_properties.expressions) elif temp_properties: properties = temp_properties @@ -1059,24 +1062,19 @@ class Parser(metaclass=_Parser): ): temp_properties = self._parse_properties() if properties and temp_properties: - properties.expressions.append(temp_properties.expressions) + properties.expressions.extend(temp_properties.expressions) elif temp_properties: properties = temp_properties expression = self._parse_ddl_select() if create_token.token_type == TokenType.TABLE: - if self._match_text_seq("WITH", "DATA"): - data = True - elif self._match_text_seq("WITH", "NO", "DATA"): - data = False - - if self._match_text_seq("AND", "STATISTICS"): - statistics = True - elif self._match_text_seq("AND", "NO", "STATISTICS"): - statistics = False - - no_primary_index = self._match_text_seq("NO", "PRIMARY", "INDEX") + # exp.Properties.Location.POST_EXPRESSION + temp_properties = self._parse_properties() + if properties and temp_properties: + properties.expressions.extend(temp_properties.expressions) + elif temp_properties: + properties = temp_properties indexes = [] while True: @@ -1086,7 +1084,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.PARTITION_BY, advance=False): temp_properties = self._parse_properties() if properties and 
temp_properties: - properties.expressions.append(temp_properties.expressions) + properties.expressions.extend(temp_properties.expressions) elif temp_properties: properties = temp_properties @@ -1102,22 +1100,11 @@ class Parser(metaclass=_Parser): exp.Create, this=this, kind=create_token.text, + unique=unique, expression=expression, - set=set_, - multiset=multiset, - global_temporary=global_temporary, - volatile=volatile, exists=exists, properties=properties, - temporary=temporary, - transient=transient, - external=external, replace=replace, - unique=unique, - materialized=materialized, - data=data, - statistics=statistics, - no_primary_index=no_primary_index, indexes=indexes, no_schema_binding=no_schema_binding, begin=begin, @@ -1196,15 +1183,21 @@ class Parser(metaclass=_Parser): def _parse_with_property( self, ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]: + self._match(TokenType.WITH) if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_property) + if self._match_text_seq("JOURNAL"): + return self._parse_withjournaltable() + + if self._match_text_seq("DATA"): + return self._parse_withdata(no=False) + elif self._match_text_seq("NO", "DATA"): + return self._parse_withdata(no=True) + if not self._next: return None - if self._next.text.upper() == "JOURNAL": - return self._parse_withjournaltable() - return self._parse_withisolatedloading() # https://dev.mysql.com/doc/refman/8.0/en/create-view.html @@ -1221,7 +1214,7 @@ class Parser(metaclass=_Parser): return exp.DefinerProperty(this=f"{user}@{host}") def _parse_withjournaltable(self) -> exp.Expression: - self._match_text_seq("WITH", "JOURNAL", "TABLE") + self._match(TokenType.TABLE) self._match(TokenType.EQ) return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts()) @@ -1319,7 +1312,6 @@ class Parser(metaclass=_Parser): ) def _parse_withisolatedloading(self) -> exp.Expression: - self._match(TokenType.WITH) no = self._match_text_seq("NO") concurrent = self._match_text_seq("CONCURRENT") self._match_text_seq("ISOLATED", "LOADING") @@ -1397,6 +1389,24 @@ class Parser(metaclass=_Parser): this=self._parse_schema() or self._parse_bracket(self._parse_field()), ) + def _parse_withdata(self, no=False) -> exp.Expression: + if self._match_text_seq("AND", "STATISTICS"): + statistics = True + elif self._match_text_seq("AND", "NO", "STATISTICS"): + statistics = False + else: + statistics = None + + return self.expression(exp.WithDataProperty, no=no, statistics=statistics) + + def _parse_noprimaryindex(self) -> exp.Expression: + self._match_text_seq("PRIMARY", "INDEX") + return exp.NoPrimaryIndexProperty() + + def _parse_oncommit(self) -> exp.Expression: + self._match_text_seq("COMMIT", "PRESERVE", "ROWS") + return exp.OnCommitProperty() + def _parse_distkey(self) -> exp.Expression: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) @@ -1450,6 +1460,10 @@ class Parser(metaclass=_Parser): return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) + def _parse_temporary(self, global_=False) -> exp.Expression: + self._match(TokenType.TEMPORARY) # in case calling from "GLOBAL" + return self.expression(exp.TemporaryProperty, global_=global_) + def _parse_describe(self) -> exp.Expression: kind = self._match_set(self.CREATABLES) and self._prev.text this = self._parse_table() @@ -2042,6 +2056,9 @@ class Parser(metaclass=_Parser): if alias: this.set("alias", alias) + if not this.args.get("pivots"): + 
this.set("pivots", self._parse_pivots()) + if self._match_pair(TokenType.WITH, TokenType.L_PAREN): this.set( "hints", @@ -2182,7 +2199,12 @@ class Parser(metaclass=_Parser): self._match_r_paren() - return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) + pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) + + if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False): + pivot.set("alias", self._parse_table_alias()) + + return pivot def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]: if not skip_where_token and not self._match(TokenType.WHERE): @@ -3783,12 +3805,13 @@ class Parser(metaclass=_Parser): return None - def _match_set(self, types): + def _match_set(self, types, advance=True): if not self._curr: return None if self._curr.token_type in types: - self._advance() + if advance: + self._advance() return True return None @@ -3816,9 +3839,10 @@ class Parser(metaclass=_Parser): if expression and self._prev_comments: expression.comments = self._prev_comments - def _match_texts(self, texts): + def _match_texts(self, texts, advance=True): if self._curr and self._curr.text.upper() in texts: - self._advance() + if advance: + self._advance() return True return False diff --git a/sqlglot/serde.py b/sqlglot/serde.py index a47ffdb..c5203a7 100644 --- a/sqlglot/serde.py +++ b/sqlglot/serde.py @@ -32,6 +32,9 @@ def dump(node: Node) -> JSON: obj["type"] = node.type.sql() if node.comments: obj["comments"] = node.comments + if node._meta is not None: + obj["meta"] = node._meta + return obj return node @@ -57,11 +60,9 @@ def load(obj: JSON) -> Node: klass = getattr(module, class_name) expression = klass(**{k: load(v) for k, v in obj["args"].items()}) - type_ = obj.get("type") - if type_: - expression.type = exp.DataType.build(type_) - comments = obj.get("comments") - if comments: - expression.comments = load(comments) + expression.type = obj.get("type") + expression.comments = obj.get("comments") + expression._meta = obj.get("meta") + return expression return obj diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 9b29c12..f3f1a70 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -115,6 +115,7 @@ class TokenType(AutoName): IMAGE = auto() VARIANT = auto() OBJECT = auto() + INET = auto() # keywords ALIAS = auto() @@ -437,16 +438,8 @@ class Tokenizer(metaclass=_Tokenizer): _IDENTIFIER_ESCAPES: t.Set[str] = set() KEYWORDS = { - **{ - f"{key}{postfix}": TokenType.BLOCK_START - for key in ("{%", "{#") - for postfix in ("", "+", "-") - }, - **{ - f"{prefix}{key}": TokenType.BLOCK_END - for key in ("%}", "#}") - for prefix in ("", "+", "-") - }, + **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")}, + **{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")}, "{{+": TokenType.BLOCK_START, "{{-": TokenType.BLOCK_START, "+}}": TokenType.BLOCK_END, @@ -533,6 +526,7 @@ class Tokenizer(metaclass=_Tokenizer): "IGNORE NULLS": TokenType.IGNORE_NULLS, "IN": TokenType.IN, "INDEX": TokenType.INDEX, + "INET": TokenType.INET, "INNER": TokenType.INNER, "INSERT": TokenType.INSERT, "INTERVAL": TokenType.INTERVAL, @@ -701,7 +695,7 @@ class Tokenizer(metaclass=_Tokenizer): "VACUUM": TokenType.COMMAND, } - WHITE_SPACE = { + WHITE_SPACE: t.Dict[str, TokenType] = { " ": TokenType.SPACE, "\t": TokenType.SPACE, "\n": TokenType.BREAK, @@ -723,7 +717,7 @@ class Tokenizer(metaclass=_Tokenizer): NUMERIC_LITERALS: t.Dict[str, str] = {} ENCODE: t.Optional[str] = None - 
COMMENTS = ["--", ("/*", "*/")] + COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")] KEYWORD_TRIE = None # autofilled IDENTIFIER_CAN_START_WITH_DIGIT = False @@ -778,22 +772,16 @@ class Tokenizer(metaclass=_Tokenizer): self._start = self._current self._advance() - if not self._char: + if self._char is None: break - white_space = self.WHITE_SPACE.get(self._char) # type: ignore - identifier_end = self._IDENTIFIERS.get(self._char) # type: ignore - - if white_space: - if white_space == TokenType.BREAK: - self._col = 1 - self._line += 1 - elif self._char.isdigit(): # type:ignore - self._scan_number() - elif identifier_end: - self._scan_identifier(identifier_end) - else: - self._scan_keywords() + if self._char not in self.WHITE_SPACE: + if self._char.isdigit(): + self._scan_number() + elif self._char in self._IDENTIFIERS: + self._scan_identifier(self._IDENTIFIERS[self._char]) + else: + self._scan_keywords() if until and until(): break @@ -807,13 +795,23 @@ class Tokenizer(metaclass=_Tokenizer): return self.sql[start:end] return "" + def _line_break(self, char: t.Optional[str]) -> bool: + return self.WHITE_SPACE.get(char) == TokenType.BREAK # type: ignore + def _advance(self, i: int = 1) -> None: + if self._line_break(self._char): + self._set_new_line() + self._col += i self._current += i self._end = self._current >= self.size # type: ignore self._char = self.sql[self._current - 1] # type: ignore self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore + def _set_new_line(self) -> None: + self._col = 1 + self._line += 1 + @property def _text(self) -> str: return self.sql[self._start : self._current] @@ -917,7 +915,7 @@ class Tokenizer(metaclass=_Tokenizer): self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore self._advance(comment_end_size - 1) else: - while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore + while not self._end and not self._line_break(self._peek): self._advance() self._comments.append(self._text[comment_start_size:]) # type: ignore @@ -926,6 +924,7 @@ class Tokenizer(metaclass=_Tokenizer): if comment_start_line == self._prev_token_line: self.tokens[-1].comments.extend(self._comments) self._comments = [] + self._prev_token_line = self._line return True -- cgit v1.2.3
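For reference, a minimal usage sketch of the new `matchings` parameter that this patch adds to `sqlglot.diff`. The queries are illustrative assumptions, not part of the commit; note the docstring requirement that pre-matched pairs reference the exact node objects from the input trees:

    import sqlglot
    from sqlglot import diff

    source = sqlglot.parse_one("SELECT a + b, c FROM t")
    target = sqlglot.parse_one("SELECT a + b, d FROM t")

    # Heuristic matching only, as before this patch:
    edit_script = diff(source, target)

    # Pre-match the two roots to guide the change distiller; the tuples must
    # reference the same node objects that appear in the source/target trees:
    edit_script = diff(source, target, matchings=[(source, target)])

    for edit in edit_script:  # Insert / Remove / Move / Update / Keep steps
        print(edit)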
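Likewise, a short sketch of the new `Expression.meta` slot round-tripping through `sqlglot.serde`, which this patch extends to dump and load `_meta`. The "origin" key is a hypothetical example, not something sqlglot defines:

    import sqlglot
    from sqlglot.serde import dump, load

    node = sqlglot.parse_one("SELECT a FROM t")
    node.meta["origin"] = "example"  # accessing .meta lazily creates the _meta dict

    restored = load(dump(node))
    assert restored.meta["origin"] == "example"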
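Finally, a sketch of the new Snowflake DATE_TRUNC parser entry, which maps the leading unit argument into an exp.DateTrunc node; the query here is an assumed example:

    import sqlglot
    from sqlglot import exp

    expr = sqlglot.parse_one("SELECT DATE_TRUNC('month', d) FROM t", read="snowflake")
    trunc = expr.find(exp.DateTrunc)
    assert trunc is not None and trunc.text("unit") == "month"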