From c51a9844b869fd7cd69e5cc7658d34f61a865185 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 1 Nov 2023 06:12:42 +0100 Subject: Merging upstream version 19.0.1. Signed-off-by: Daniel Baumann --- sqlglot/generator.py | 125 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 76 insertions(+), 49 deletions(-) (limited to 'sqlglot/generator.py') diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 0d6778a..4916cf8 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -230,6 +230,12 @@ class Generator: # Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle) DATA_TYPE_SPECIFIERS_ALLOWED = False + # Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed + SUPPORTS_NESTED_CTES = True + + # Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs + CTE_RECURSIVE_KEYWORD_REQUIRED = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -304,6 +310,7 @@ class Generator: exp.Order: exp.Properties.Location.POST_SCHEMA, exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, + exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA, exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA, exp.Property: exp.Properties.Location.POST_WITH, exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA, @@ -407,7 +414,6 @@ class Generator: "unsupported_messages", "_escaped_quote_end", "_escaped_identifier_end", - "_cache", ) def __init__( @@ -447,30 +453,38 @@ class Generator: self._escaped_identifier_end: str = ( self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END ) - self._cache: t.Optional[t.Dict[int, str]] = None - def generate( - self, - expression: t.Optional[exp.Expression], - cache: t.Optional[t.Dict[int, str]] = None, - ) -> str: + def generate(self, expression: exp.Expression, copy: bool = True) -> str: """ Generates the SQL string corresponding to the given syntax tree. Args: expression: The syntax tree. - cache: An optional sql string cache. This leverages the hash of an Expression - which can be slow to compute, so only use it if you set _hash on each node. + copy: Whether or not to copy the expression. The generator performs mutations so + it is safer to copy. Returns: The SQL string corresponding to `expression`. """ - if cache is not None: - self._cache = cache + if copy: + expression = expression.copy() + + # Some dialects only support CTEs at the top level expression, so we need to bubble up nested + # CTEs to that level in order to produce a syntactically valid expression. This transformation + # happens here to minimize code duplication, since many expressions support CTEs. + if ( + not self.SUPPORTS_NESTED_CTES + and isinstance(expression, exp.Expression) + and not expression.parent + and "with" in expression.arg_types + and any(node.parent is not expression for node in expression.find_all(exp.With)) + ): + from sqlglot.transforms import move_ctes_to_top_level + + expression = move_ctes_to_top_level(expression) self.unsupported_messages = [] sql = self.sql(expression).strip() - self._cache = None if self.unsupported_level == ErrorLevel.IGNORE: return sql @@ -595,12 +609,6 @@ class Generator: return self.sql(value) return "" - if self._cache is not None: - expression_id = hash(expression) - - if expression_id in self._cache: - return self._cache[expression_id] - transform = self.TRANSFORMS.get(expression.__class__) if callable(transform): @@ -621,11 +629,7 @@ class Generator: else: raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") - sql = self.maybe_comment(sql, expression) if self.comments and comment else sql - - if self._cache is not None: - self._cache[expression_id] = sql - return sql + return self.maybe_comment(sql, expression) if self.comments and comment else sql def uncache_sql(self, expression: exp.Uncache) -> str: table = self.sql(expression, "this") @@ -879,7 +883,11 @@ class Generator: def with_sql(self, expression: exp.With) -> str: sql = self.expressions(expression, flat=True) - recursive = "RECURSIVE " if expression.args.get("recursive") else "" + recursive = ( + "RECURSIVE " + if self.CTE_RECURSIVE_KEYWORD_REQUIRED and expression.args.get("recursive") + else "" + ) return f"WITH {recursive}{sql}" @@ -1022,7 +1030,7 @@ class Generator: where = self.sql(expression, "expression").strip() return f"{this} FILTER({where})" - agg = expression.this.copy() + agg = expression.this agg_arg = agg.this cond = expression.expression.this agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy())) @@ -1088,9 +1096,9 @@ class Generator: for p in expression.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] if p_loc == exp.Properties.Location.POST_WITH: - with_properties.append(p.copy()) + with_properties.append(p) elif p_loc == exp.Properties.Location.POST_SCHEMA: - root_properties.append(p.copy()) + root_properties.append(p) return self.root_properties( exp.Properties(expressions=root_properties) @@ -1124,7 +1132,7 @@ class Generator: for p in properties.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] if p_loc != exp.Properties.Location.UNSUPPORTED: - properties_locs[p_loc].append(p.copy()) + properties_locs[p_loc].append(p) else: self.unsupported(f"Unsupported property {p.key}") @@ -1238,6 +1246,29 @@ class Generator: for_ = " FOR NONE" return f"WITH{no}{concurrent} ISOLATED LOADING{for_}" + def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str: + if isinstance(expression.this, list): + return f"IN ({self.expressions(expression, key='this', flat=True)})" + if expression.this: + modulus = self.sql(expression, "this") + remainder = self.sql(expression, "expression") + return f"WITH (MODULUS {modulus}, REMAINDER {remainder})" + + from_expressions = self.expressions(expression, key="from_expressions", flat=True) + to_expressions = self.expressions(expression, key="to_expressions", flat=True) + return f"FROM ({from_expressions}) TO ({to_expressions})" + + def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str: + this = self.sql(expression, "this") + + for_values_or_default = expression.expression + if isinstance(for_values_or_default, exp.PartitionBoundSpec): + for_values_or_default = f" FOR VALUES {self.sql(for_values_or_default)}" + else: + for_values_or_default = " DEFAULT" + + return f"PARTITION OF {this}{for_values_or_default}" + def lockingproperty_sql(self, expression: exp.LockingProperty) -> str: kind = expression.args.get("kind") this = f" {self.sql(expression, 'this')}" if expression.this else "" @@ -1385,7 +1416,12 @@ class Generator: index = self.sql(expression, "index") index = f" AT {index}" if index else "" - return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}" + ordinality = expression.args.get("ordinality") or "" + if ordinality: + ordinality = f" WITH ORDINALITY{alias}" + alias = "" + + return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " @@ -1489,7 +1525,6 @@ class Generator: return f"{values} AS {alias}" if alias else values # Converts `VALUES...` expression into a series of select unions. - expression = expression.copy() alias_node = expression.args.get("alias") column_names = alias_node and alias_node.columns @@ -1972,8 +2007,7 @@ class Generator: if self.UNNEST_WITH_ORDINALITY: if alias and isinstance(offset, exp.Expression): - alias = alias.copy() - alias.append("columns", offset.copy()) + alias.append("columns", offset) if alias and self.UNNEST_COLUMN_ONLY: columns = alias.columns @@ -2138,7 +2172,6 @@ class Generator: return f"PRIMARY KEY ({expressions}){options}" def if_sql(self, expression: exp.If) -> str: - expression = expression.copy() return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: @@ -2367,7 +2400,9 @@ class Generator: def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: format_sql = self.sql(expression, "format") format_sql = f" FORMAT {format_sql}" if format_sql else "" - return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})" + to_sql = self.sql(expression, "to") + to_sql = f" {to_sql}" if to_sql else "" + return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})" def currentdate_sql(self, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") @@ -2510,7 +2545,7 @@ class Generator: def intdiv_sql(self, expression: exp.IntDiv) -> str: return self.sql( exp.Cast( - this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()), + this=exp.Div(this=expression.this, expression=expression.expression), to=exp.DataType(this=exp.DataType.Type.INT), ) ) @@ -2779,7 +2814,6 @@ class Generator: hints = table.args.get("hints") if hints and table.alias and isinstance(hints[0], exp.WithTableHint): # T-SQL syntax is MERGE ... [WITH ()] [[AS] table_alias] - table = table.copy() table_alias = f" AS {self.sql(table.args['alias'].pop())}" this = self.sql(table) @@ -2787,7 +2821,9 @@ class Generator: on = f"ON {self.sql(expression, 'on')}" expressions = self.expressions(expression, sep=" ") - return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}" + return self.prepend_ctes( + expression, f"MERGE INTO {this}{table_alias} {using} {on} {expressions}" + ) def tochar_sql(self, expression: exp.ToChar) -> str: if expression.args.get("format"): @@ -2896,12 +2932,12 @@ class Generator: case = exp.Case().when( expression.this.is_(exp.null()).not_(copy=False), - expression.args["true"].copy(), + expression.args["true"], copy=False, ) else_cond = expression.args.get("false") if else_cond: - case.else_(else_cond.copy(), copy=False) + case.else_(else_cond, copy=False) return self.sql(case) @@ -2931,15 +2967,6 @@ class Generator: if not isinstance(expression, exp.Literal): from sqlglot.optimizer.simplify import simplify - expression = simplify(expression.copy()) + expression = simplify(expression) return expression - - -def cached_generator( - cache: t.Optional[t.Dict[int, str]] = None -) -> t.Callable[[exp.Expression], str]: - """Returns a cached generator.""" - cache = {} if cache is None else cache - generator = Generator(normalize=True, identify="safe") - return lambda e: generator.generate(e, cache) -- cgit v1.2.3