diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 187 |
1 files changed, 131 insertions, 56 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 3f3365a..b95e9bc 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -67,6 +67,7 @@ class Generator: exp.VolatilityProperty: lambda self, e: e.name, exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", + exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", } # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed @@ -75,6 +76,9 @@ class Generator: # Whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True + # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported + LOCKING_READS_SUPPORTED = False + # Always do union distinct or union all EXPLICIT_UNION = False @@ -99,34 +103,42 @@ class Generator: STRUCT_DELIMITER = ("<", ">") - BEFORE_PROPERTIES = { - exp.FallbackProperty, - exp.WithJournalTableProperty, - exp.LogProperty, - exp.JournalProperty, - exp.AfterJournalProperty, - exp.ChecksumProperty, - exp.FreespaceProperty, - exp.MergeBlockRatioProperty, - exp.DataBlocksizeProperty, - exp.BlockCompressionProperty, - exp.IsolatedLoadingProperty, - } - - ROOT_PROPERTIES = { - exp.ReturnsProperty, - exp.LanguageProperty, - exp.DistStyleProperty, - exp.DistKeyProperty, - exp.SortKeyProperty, - exp.LikeProperty, - } - - WITH_PROPERTIES = { - exp.Property, - exp.FileFormatProperty, - exp.PartitionedByProperty, - exp.TableFormatProperty, + PROPERTIES_LOCATION = { + exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA, + exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, + exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA, + exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA, + exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA, + exp.DefinerProperty: exp.Properties.Location.POST_CREATE, + exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA, + exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA, + exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA, + exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA, + exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.LogProperty: exp.Properties.Location.PRE_SCHEMA, + exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.Property: exp.Properties.Location.POST_SCHEMA_WITH, + exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, + exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, + exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA, } WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) @@ -284,10 +296,10 @@ class Generator: ) return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" - def no_identify(self, func: t.Callable[[], str]) -> str: + def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str: original = self.identify self.identify = False - result = func() + result = func(*args, **kwargs) self.identify = original return result @@ -455,19 +467,33 @@ class Generator: def create_sql(self, expression: exp.Create) -> str: kind = self.sql(expression, "kind").upper() - has_before_properties = expression.args.get("properties") - has_before_properties = ( - has_before_properties.args.get("before") if has_before_properties else None - ) - if kind == "TABLE" and has_before_properties: + properties = expression.args.get("properties") + properties_exp = expression.copy() + properties_locs = self.locate_properties(properties) if properties else {} + if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get( + exp.Properties.Location.POST_SCHEMA_WITH + ): + properties_exp.set( + "properties", + exp.Properties( + expressions=[ + *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT], + *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH], + ] + ), + ) + if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA): this_name = self.sql(expression.this, "this") - this_properties = self.sql(expression, "properties") + this_properties = self.properties( + exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]), + wrapped=False, + ) this_schema = f"({self.expressions(expression.this)})" this = f"{this_name}, {this_properties} {this_schema}" - properties = "" + properties_sql = "" else: this = self.sql(expression, "this") - properties = self.sql(expression, "properties") + properties_sql = self.sql(properties_exp, "properties") begin = " BEGIN" if expression.args.get("begin") else "" expression_sql = self.sql(expression, "expression") expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else "" @@ -514,11 +540,31 @@ class Generator: if index.args.get("columns") else "" ) + if index.args.get("primary") and properties_locs.get( + exp.Properties.Location.POST_INDEX + ): + postindex_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[exp.Properties.Location.POST_INDEX] + ), + wrapped=False, + ) + ind_columns = f"{ind_columns} {postindex_props_sql}" + indexes_sql.append( f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}" ) index_sql = "".join(indexes_sql) + postcreate_props_sql = "" + if properties_locs.get(exp.Properties.Location.POST_CREATE): + postcreate_props_sql = self.properties( + exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]), + sep=" ", + prefix=" ", + wrapped=False, + ) + modifiers = "".join( ( replace, @@ -531,6 +577,7 @@ class Generator: multiset, global_temporary, volatile, + postcreate_props_sql, ) ) no_schema_binding = ( @@ -539,7 +586,7 @@ class Generator: post_expression_modifiers = "".join((data, statistics, no_primary_index)) - expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}" + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}" return self.prepend_ctes(expression, expression_sql) def describe_sql(self, expression: exp.Describe) -> str: @@ -665,24 +712,19 @@ class Generator: return f"PARTITION({self.expressions(expression)})" def properties_sql(self, expression: exp.Properties) -> str: - before_properties = [] root_properties = [] with_properties = [] for p in expression.expressions: - p_class = p.__class__ - if p_class in self.BEFORE_PROPERTIES: - before_properties.append(p) - elif p_class in self.WITH_PROPERTIES: + p_loc = self.PROPERTIES_LOCATION[p.__class__] + if p_loc == exp.Properties.Location.POST_SCHEMA_WITH: with_properties.append(p) - elif p_class in self.ROOT_PROPERTIES: + elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: root_properties.append(p) - return ( - self.properties(exp.Properties(expressions=before_properties), before=True) - + self.root_properties(exp.Properties(expressions=root_properties)) - + self.with_properties(exp.Properties(expressions=with_properties)) - ) + return self.root_properties( + exp.Properties(expressions=root_properties) + ) + self.with_properties(exp.Properties(expressions=with_properties)) def root_properties(self, properties: exp.Properties) -> str: if properties.expressions: @@ -695,17 +737,41 @@ class Generator: prefix: str = "", sep: str = ", ", suffix: str = "", - before: bool = False, + wrapped: bool = True, ) -> str: if properties.expressions: expressions = self.expressions(properties, sep=sep, indent=False) - expressions = expressions if before else self.wrap(expressions) + expressions = self.wrap(expressions) if wrapped else expressions return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}" return "" def with_properties(self, properties: exp.Properties) -> str: return self.properties(properties, prefix=self.seg("WITH")) + def locate_properties( + self, properties: exp.Properties + ) -> t.Dict[exp.Properties.Location, list[exp.Property]]: + properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = { + key: [] for key in exp.Properties.Location + } + + for p in properties.expressions: + p_loc = self.PROPERTIES_LOCATION[p.__class__] + if p_loc == exp.Properties.Location.PRE_SCHEMA: + properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p) + elif p_loc == exp.Properties.Location.POST_INDEX: + properties_locs[exp.Properties.Location.POST_INDEX].append(p) + elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: + properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p) + elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH: + properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p) + elif p_loc == exp.Properties.Location.POST_CREATE: + properties_locs[exp.Properties.Location.POST_CREATE].append(p) + elif p_loc == exp.Properties.Location.UNSUPPORTED: + self.unsupported(f"Unsupported property {p.key}") + + return properties_locs + def property_sql(self, expression: exp.Property) -> str: property_cls = expression.__class__ if property_cls == exp.Property: @@ -713,7 +779,7 @@ class Generator: property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls) if not property_name: - self.unsupported(f"Unsupported property {property_name}") + self.unsupported(f"Unsupported property {expression.key}") return f"{property_name}={self.sql(expression, 'this')}" @@ -975,7 +1041,7 @@ class Generator: rollup = self.expressions(expression, key="rollup", indent=False) rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" - return f"{group_by}{grouping_sets}{cube}{rollup}" + return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}" def having_sql(self, expression: exp.Having) -> str: this = self.indent(self.sql(expression, "this")) @@ -1015,7 +1081,7 @@ class Generator: def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: args = self.expressions(expression, flat=True) args = f"({args})" if len(args.split(",")) > 1 else args - return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}") + return f"{args} {arrow_sep} {self.sql(expression, 'this')}" def lateral_sql(self, expression: exp.Lateral) -> str: this = self.sql(expression, "this") @@ -1043,6 +1109,14 @@ class Generator: this = self.sql(expression, "this") return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" + def lock_sql(self, expression: exp.Lock) -> str: + if self.LOCKING_READS_SUPPORTED: + lock_type = "UPDATE" if expression.args["update"] else "SHARE" + return self.seg(f"FOR {lock_type}") + + self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported") + return "" + def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: @@ -1163,6 +1237,7 @@ class Generator: self.sql(expression, "order"), self.sql(expression, "limit"), self.sql(expression, "offset"), + self.sql(expression, "lock"), sep="", ) @@ -1773,7 +1848,7 @@ class Generator: def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: this = self.sql(expression, "this") - expressions = self.no_identify(lambda: self.expressions(expression)) + expressions = self.no_identify(self.expressions, expression) expressions = ( self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}" ) |