diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 264 |
1 files changed, 148 insertions, 116 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 0d72fe3..1479e28 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,19 +1,16 @@ from __future__ import annotations import logging -import re import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages -from sqlglot.helper import apply_index_offset, csv +from sqlglot.helper import apply_index_offset, csv, seq_get from sqlglot.time import format_time from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") -BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)") - class Generator: """ @@ -59,10 +56,14 @@ class Generator: """ TRANSFORMS = { - exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", - exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", - exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", - exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})", + exp.DateAdd: lambda self, e: self.func( + "DATE_ADD", e.this, e.expression, e.args.get("unit") + ), + exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), + exp.TsOrDsAdd: lambda self, e: self.func( + "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit") + ), + exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), @@ -72,6 +73,17 @@ class Generator: exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", + exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", + exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", + exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", + exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE", + exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", + exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", + exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})", + exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", + exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", + exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", + exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}", } # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed @@ -89,8 +101,8 @@ class Generator: # Wrap derived values in parens, usually standard but spark doesn't support it WRAP_DERIVED_VALUES = True - # Whether or not create function uses an AS before the def. - CREATE_FUNCTION_AS = True + # Whether or not create function uses an AS before the RETURN + CREATE_FUNCTION_RETURN_AS = True TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", @@ -110,42 +122,46 @@ class Generator: STRUCT_DELIMITER = ("<", ">") + PARAMETER_TOKEN = "@" + PROPERTIES_LOCATION = { - exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA, + exp.AfterJournalProperty: exp.Properties.Location.POST_NAME, exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, - exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA, - exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA, - exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA, + exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA, + exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME, + exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA, + exp.ChecksumProperty: exp.Properties.Location.POST_NAME, + exp.CollateProperty: exp.Properties.Location.POST_SCHEMA, + exp.Cluster: exp.Properties.Location.POST_SCHEMA, + exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME, exp.DefinerProperty: exp.Properties.Location.POST_CREATE, - exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA, - exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, - exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA, - exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA, - exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA, - exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.LogProperty: exp.Properties.Location.PRE_SCHEMA, - exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH, - exp.Property: exp.Properties.Location.POST_SCHEMA_WITH, - exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT, + exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA, + exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA, + exp.EngineProperty: exp.Properties.Location.POST_SCHEMA, + exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA, + exp.FallbackProperty: exp.Properties.Location.POST_NAME, + exp.FileFormatProperty: exp.Properties.Location.POST_WITH, + exp.FreespaceProperty: exp.Properties.Location.POST_NAME, + exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME, + exp.JournalProperty: exp.Properties.Location.POST_NAME, + exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA, + exp.LikeProperty: exp.Properties.Location.POST_SCHEMA, + exp.LocationProperty: exp.Properties.Location.POST_SCHEMA, + exp.LockingProperty: exp.Properties.Location.POST_ALIAS, + exp.LogProperty: exp.Properties.Location.POST_NAME, + exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME, + exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, + exp.Property: exp.Properties.Location.POST_WITH, + exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, + exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA, + exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, + exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, + exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA, + exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, - exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH, - exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT, - exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA, + exp.TableFormatProperty: exp.Properties.Location.POST_WITH, + exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA, + exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, } WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) @@ -173,7 +189,6 @@ class Generator: "null_ordering", "max_unsupported", "_indent", - "_replace_backslash", "_escaped_quote_end", "_escaped_identifier_end", "_leading_comma", @@ -230,7 +245,6 @@ class Generator: self.max_unsupported = max_unsupported self.null_ordering = null_ordering self._indent = indent - self._replace_backslash = self.string_escape == "\\" self._escaped_quote_end = self.string_escape + self.quote_end self._escaped_identifier_end = self.identifier_escape + self.identifier_end self._leading_comma = leading_comma @@ -403,12 +417,13 @@ class Generator: def column_sql(self, expression: exp.Column) -> str: return ".".join( - part - for part in [ - self.sql(expression, "db"), - self.sql(expression, "table"), - self.sql(expression, "this"), - ] + self.sql(part) + for part in ( + expression.args.get("catalog"), + expression.args.get("db"), + expression.args.get("table"), + expression.args.get("this"), + ) if part ) @@ -430,26 +445,6 @@ class Generator: def autoincrementcolumnconstraint_sql(self, _) -> str: return self.token_sql(TokenType.AUTO_INCREMENT) - def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str: - this = self.sql(expression, "this") - return f"CHECK ({this})" - - def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str: - comment = self.sql(expression, "this") - return f"COMMENT {comment}" - - def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str: - collate = self.sql(expression, "this") - return f"COLLATE {collate}" - - def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str: - encode = self.sql(expression, "this") - return f"ENCODE {encode}" - - def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str: - default = self.sql(expression, "this") - return f"DEFAULT {default}" - def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: @@ -459,10 +454,19 @@ class Generator: start = expression.args.get("start") start = f"START WITH {start}" if start else "" increment = expression.args.get("increment") - increment = f"INCREMENT BY {increment}" if increment else "" + increment = f" INCREMENT BY {increment}" if increment else "" + minvalue = expression.args.get("minvalue") + minvalue = f" MINVALUE {minvalue}" if minvalue else "" + maxvalue = expression.args.get("maxvalue") + maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" + cycle = expression.args.get("cycle") + cycle_sql = "" + if cycle is not None: + cycle_sql = f"{' NO' if not cycle else ''} CYCLE" + cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql sequence_opts = "" - if start or increment: - sequence_opts = f"{start} {increment}" + if start or increment or cycle_sql: + sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}" sequence_opts = f" ({sequence_opts.strip()})" return f"GENERATED{this}AS IDENTITY{sequence_opts}" @@ -483,22 +487,22 @@ class Generator: properties = expression.args.get("properties") properties_exp = expression.copy() properties_locs = self.locate_properties(properties) if properties else {} - if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get( - exp.Properties.Location.POST_SCHEMA_WITH + if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get( + exp.Properties.Location.POST_WITH ): properties_exp.set( "properties", exp.Properties( expressions=[ - *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT], - *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH], + *properties_locs[exp.Properties.Location.POST_SCHEMA], + *properties_locs[exp.Properties.Location.POST_WITH], ] ), ) - if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA): + if kind == "TABLE" and properties_locs.get(exp.Properties.Location.POST_NAME): this_name = self.sql(expression.this, "this") this_properties = self.properties( - exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]), + exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_NAME]), wrapped=False, ) this_schema = f"({self.expressions(expression.this)})" @@ -512,8 +516,17 @@ class Generator: if expression_sql: expression_sql = f"{begin}{self.sep()}{expression_sql}" - if self.CREATE_FUNCTION_AS or kind != "FUNCTION": - expression_sql = f" AS{expression_sql}" + if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return): + if properties_locs.get(exp.Properties.Location.POST_ALIAS): + postalias_props_sql = self.properties( + exp.Properties( + expressions=properties_locs[exp.Properties.Location.POST_ALIAS] + ), + wrapped=False, + ) + expression_sql = f" AS {postalias_props_sql}{expression_sql}" + else: + expression_sql = f" AS{expression_sql}" temporary = " TEMPORARY" if expression.args.get("temporary") else "" transient = ( @@ -736,9 +749,9 @@ class Generator: for p in expression.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] - if p_loc == exp.Properties.Location.POST_SCHEMA_WITH: + if p_loc == exp.Properties.Location.POST_WITH: with_properties.append(p) - elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: + elif p_loc == exp.Properties.Location.POST_SCHEMA: root_properties.append(p) return self.root_properties( @@ -776,16 +789,18 @@ class Generator: for p in properties.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] - if p_loc == exp.Properties.Location.PRE_SCHEMA: - properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p) + if p_loc == exp.Properties.Location.POST_NAME: + properties_locs[exp.Properties.Location.POST_NAME].append(p) elif p_loc == exp.Properties.Location.POST_INDEX: properties_locs[exp.Properties.Location.POST_INDEX].append(p) - elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT: - properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p) - elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH: - properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p) + elif p_loc == exp.Properties.Location.POST_SCHEMA: + properties_locs[exp.Properties.Location.POST_SCHEMA].append(p) + elif p_loc == exp.Properties.Location.POST_WITH: + properties_locs[exp.Properties.Location.POST_WITH].append(p) elif p_loc == exp.Properties.Location.POST_CREATE: properties_locs[exp.Properties.Location.POST_CREATE].append(p) + elif p_loc == exp.Properties.Location.POST_ALIAS: + properties_locs[exp.Properties.Location.POST_ALIAS].append(p) elif p_loc == exp.Properties.Location.UNSUPPORTED: self.unsupported(f"Unsupported property {p.key}") @@ -899,6 +914,14 @@ class Generator: for_ = " FOR NONE" return f"WITH{no}{concurrent} ISOLATED LOADING{for_}" + def lockingproperty_sql(self, expression: exp.LockingProperty) -> str: + kind = expression.args.get("kind") + this: str = f" {this}" if expression.this else "" + for_or_in = expression.args.get("for_or_in") + lock_type = expression.args.get("lock_type") + override = " OVERRIDE" if expression.args.get("override") else "" + return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}" + def insert_sql(self, expression: exp.Insert) -> str: overwrite = expression.args.get("overwrite") @@ -907,14 +930,17 @@ class Generator: else: this = "OVERWRITE TABLE " if overwrite else "INTO " + alternative = expression.args.get("alternative") + alternative = f" OR {alternative} " if alternative else " " this = f"{this}{self.sql(expression, 'this')}" + exists = " IF EXISTS " if expression.args.get("exists") else " " partition_sql = ( self.sql(expression, "partition") if expression.args.get("partition") else "" ) expression_sql = self.sql(expression, "expression") sep = self.sep() if partition_sql else "" - sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" + sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: @@ -1046,21 +1072,26 @@ class Generator: f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" ) - cube = expression.args.get("cube") - if cube is True: - cube = self.seg("WITH CUBE") + cube = expression.args.get("cube", []) + if seq_get(cube, 0) is True: + return f"{group_by}{self.seg('WITH CUBE')}" else: - cube = self.expressions(expression, key="cube", indent=False) - cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" + cube_sql = self.expressions(expression, key="cube", indent=False) + cube_sql = f"{self.seg('CUBE')} {self.wrap(cube_sql)}" if cube_sql else "" - rollup = expression.args.get("rollup") - if rollup is True: - rollup = self.seg("WITH ROLLUP") + rollup = expression.args.get("rollup", []) + if seq_get(rollup, 0) is True: + return f"{group_by}{self.seg('WITH ROLLUP')}" else: - rollup = self.expressions(expression, key="rollup", indent=False) - rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" + rollup_sql = self.expressions(expression, key="rollup", indent=False) + rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else "" + + groupings = csv(grouping_sets, cube_sql, rollup_sql, sep=",") - return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}" + if expression.args.get("expressions") and groupings: + group_by = f"{group_by}," + + return f"{group_by}{groupings}" def having_sql(self, expression: exp.Having) -> str: this = self.indent(self.sql(expression, "this")) @@ -1139,8 +1170,6 @@ class Generator: def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: - if self._replace_backslash: - text = BACKSLASH_RE.sub(r"\\\\", text) text = text.replace(self.quote_end, self._escaped_quote_end) if self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) @@ -1291,7 +1320,9 @@ class Generator: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" def parameter_sql(self, expression: exp.Parameter) -> str: - return f"@{self.sql(expression, 'this')}" + this = self.sql(expression, "this") + this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}" + return f"{self.PARAMETER_TOKEN}{this}" def sessionparameter_sql(self, expression: exp.SessionParameter) -> str: this = self.sql(expression, "this") @@ -1405,7 +1436,10 @@ class Generator: return f"ALL {self.wrap(expression)}" def any_sql(self, expression: exp.Any) -> str: - return f"ANY {self.wrap(expression)}" + this = self.sql(expression, "this") + if isinstance(expression.this, exp.Subqueryable): + this = self.wrap(this) + return f"ANY {this}" def exists_sql(self, expression: exp.Exists) -> str: return f"EXISTS{self.wrap(expression)}" @@ -1444,11 +1478,11 @@ class Generator: trim_type = self.sql(expression, "position") if trim_type == "LEADING": - return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})" + return self.func("LTRIM", expression.this) elif trim_type == "TRAILING": - return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})" + return self.func("RTRIM", expression.this) else: - return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})" + return self.func("TRIM", expression.this, expression.expression) def concat_sql(self, expression: exp.Concat) -> str: if len(expression.expressions) == 1: @@ -1530,8 +1564,7 @@ class Generator: return f"REFERENCES {this}{expressions}{options}" def anonymous_sql(self, expression: exp.Anonymous) -> str: - args = self.format_args(*expression.expressions) - return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" + return self.func(expression.name, *expression.expressions) def paren_sql(self, expression: exp.Paren) -> str: if isinstance(expression.unnest(), exp.Select): @@ -1792,7 +1825,10 @@ class Generator: else: args.append(arg_value) - return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})" + return self.func(expression.sql_name(), *args) + + def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str: + return f"{self.normalize_func(name)}({self.format_args(*args)})" def format_args(self, *args: t.Optional[str | exp.Expression]) -> str: arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None) @@ -1848,6 +1884,7 @@ class Generator: return self.indent(result_sql, skip_first=False) if indent else result_sql def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str: + flat = flat or isinstance(expression.parent, exp.Properties) expressions_sql = self.expressions(expression, flat=flat) if flat: return f"{op} {expressions_sql}" @@ -1880,11 +1917,6 @@ class Generator: ) return f"{this}{expressions}" - def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str: - this = self.sql(expression, "this") - kind = self.sql(expression, "kind") - return f"{this} {kind}" - def joinhint_sql(self, expression: exp.JoinHint) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) |