diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 205 |
1 files changed, 152 insertions, 53 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 8a49d55..bd12d54 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -76,11 +76,13 @@ class Generator: exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY", exp.TransientProperty: lambda self, e: "TRANSIENT", - exp.VolatilityProperty: lambda self, e: e.name, + exp.StabilityProperty: lambda self, e: e.name, + exp.VolatileProperty: lambda self, e: "VOLATILE", exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", + exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}", exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE", exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", @@ -110,8 +112,19 @@ class Generator: # Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed MATCHED_BY_SOURCE = True - # Whether or not limit and fetch are supported - # "ALL", "LIMIT", "FETCH" + # Whether or not the INTERVAL expression works only with values like '1 day' + SINGLE_STRING_INTERVAL = False + + # Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs + INTERVAL_ALLOWS_PLURAL_FORM = True + + # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI + TABLESAMPLE_WITH_METHOD = True + + # Whether or not to treat the number in TABLESAMPLE (50) as a percentage + TABLESAMPLE_SIZE_IS_PERCENT = False + + # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" TYPE_MAPPING = { @@ -129,6 +142,18 @@ class Generator: "replace": "REPLACE", } + TIME_PART_SINGULARS = { + "microseconds": "microsecond", + "seconds": "second", + "minutes": "minute", + "hours": "hour", + "days": "day", + "weeks": "week", + "months": "month", + "quarters": "quarter", + "years": "year", + } + TOKEN_MAPPING: t.Dict[TokenType, str] = {} STRUCT_DELIMITER = ("<", ">") @@ -168,6 +193,7 @@ class Generator: exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, exp.Property: exp.Properties.Location.POST_WITH, exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, + exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, @@ -175,15 +201,22 @@ class Generator: exp.SetProperty: exp.Properties.Location.POST_CREATE, exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, + exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA, exp.TableFormatProperty: exp.Properties.Location.POST_WITH, exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, exp.TransientProperty: exp.Properties.Location.POST_CREATE, - exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.POST_CREATE, exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, } - WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) + JOIN_HINTS = True + TABLE_HINTS = True + + RESERVED_KEYWORDS: t.Set[str] = set() + WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With) + UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column) + SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" __slots__ = ( @@ -322,10 +355,15 @@ class Generator: comment = comment + " " if comment[-1].strip() else comment return comment - def maybe_comment(self, sql: str, expression: exp.Expression) -> str: - comments = expression.comments if self._comments else None + def maybe_comment( + self, + sql: str, + expression: t.Optional[exp.Expression] = None, + comments: t.Optional[t.List[str]] = None, + ) -> str: + comments = (comments or (expression and expression.comments)) if self._comments else None # type: ignore - if not comments: + if not comments or isinstance(expression, exp.Binary): return sql sep = "\n" if self.pretty else " " @@ -621,7 +659,6 @@ class Generator: replace = " OR REPLACE" if expression.args.get("replace") else "" unique = " UNIQUE" if expression.args.get("unique") else "" - volatile = " VOLATILE" if expression.args.get("volatile") else "" postcreate_props_sql = "" if properties_locs.get(exp.Properties.Location.POST_CREATE): @@ -632,7 +669,7 @@ class Generator: wrapped=False, ) - modifiers = "".join((replace, unique, volatile, postcreate_props_sql)) + modifiers = "".join((replace, unique, postcreate_props_sql)) postexpression_props_sql = "" if properties_locs.get(exp.Properties.Location.POST_EXPRESSION): @@ -684,6 +721,9 @@ class Generator: def hexstring_sql(self, expression: exp.HexString) -> str: return self.sql(expression, "this") + def bytestring_sql(self, expression: exp.ByteString) -> str: + return self.sql(expression, "this") + def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) @@ -695,9 +735,7 @@ class Generator: nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" if expression.args.get("values") is not None: delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")") - values = ( - f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}" - ) + values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}" else: nested = f"({interior})" @@ -713,7 +751,7 @@ class Generator: this = self.sql(expression, "this") this = f" FROM {this}" if this else "" using_sql = ( - f" USING {self.expressions(expression, 'using', sep=', USING ')}" + f" USING {self.expressions(expression, key='using', sep=', USING ')}" if expression.args.get("using") else "" ) @@ -730,7 +768,10 @@ class Generator: materialized = " MATERIALIZED" if expression.args.get("materialized") else "" cascade = " CASCADE" if expression.args.get("cascade") else "" constraints = " CONSTRAINTS" if expression.args.get("constraints") else "" - return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}" + purge = " PURGE" if expression.args.get("purge") else "" + return ( + f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}" + ) def except_sql(self, expression: exp.Except) -> str: return self.prepend_ctes( @@ -746,7 +787,10 @@ class Generator: direction = f" {direction.upper()}" if direction else "" count = expression.args.get("count") count = f" {count}" if count else "" - return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY" + if expression.args.get("percent"): + count = f"{count} PERCENT" + with_ties_or_only = "WITH TIES" if expression.args.get("with_ties") else "ONLY" + return f"{self.seg('FETCH')}{direction}{count} ROWS {with_ties_or_only}" def filter_sql(self, expression: exp.Filter) -> str: this = self.sql(expression, "this") @@ -766,12 +810,24 @@ class Generator: def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name - text = text.lower() if self.normalize and not expression.quoted else text + lower = text.lower() + text = lower if self.normalize and not expression.quoted else text text = text.replace(self.identifier_end, self._escaped_identifier_end) - if expression.quoted or should_identify(text, self.identify): + if ( + expression.quoted + or should_identify(text, self.identify) + or lower in self.RESERVED_KEYWORDS + ): text = f"{self.identifier_start}{text}{self.identifier_end}" return text + def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str: + input_format = self.sql(expression, "input_format") + input_format = f"INPUTFORMAT {input_format}" if input_format else "" + output_format = self.sql(expression, "output_format") + output_format = f"OUTPUTFORMAT {output_format}" if output_format else "" + return self.sep().join((input_format, output_format)) + def national_sql(self, expression: exp.National) -> str: return f"N{self.sql(expression, 'this')}" @@ -984,9 +1040,10 @@ class Generator: self.sql(expression, "partition") if expression.args.get("partition") else "" ) expression_sql = self.sql(expression, "expression") + conflict = self.sql(expression, "conflict") returning = self.sql(expression, "returning") sep = self.sep() if partition_sql else "" - sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{returning}" + sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{conflict}{returning}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: @@ -1004,6 +1061,19 @@ class Generator: def pseudotype_sql(self, expression: exp.PseudoType) -> str: return expression.name.upper() + def onconflict_sql(self, expression: exp.OnConflict) -> str: + conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT" + constraint = self.sql(expression, "constraint") + if constraint: + constraint = f"ON CONSTRAINT {constraint}" + key = self.expressions(expression, key="key", flat=True) + do = "" if expression.args.get("duplicate") else " DO " + nothing = "NOTHING" if expression.args.get("nothing") else "" + expressions = self.expressions(expression, flat=True) + if expressions: + expressions = f"UPDATE SET {expressions}" + return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}" + def returning_sql(self, expression: exp.Returning) -> str: return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}" @@ -1036,7 +1106,7 @@ class Generator: alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" hints = self.expressions(expression, key="hints", sep=", ", flat=True) - hints = f" WITH ({hints})" if hints else "" + hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else "" laterals = self.expressions(expression, key="laterals", sep="") joins = self.expressions(expression, key="joins", sep="") pivots = self.expressions(expression, key="pivots", sep="") @@ -1053,7 +1123,7 @@ class Generator: this = self.sql(expression, "this") alias = "" method = self.sql(expression, "method") - method = f"{method.upper()} " if method else "" + method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else "" numerator = self.sql(expression, "bucket_numerator") denominator = self.sql(expression, "bucket_denominator") field = self.sql(expression, "bucket_field") @@ -1064,6 +1134,8 @@ class Generator: rows = self.sql(expression, "rows") rows = f"{rows} ROWS" if rows else "" size = self.sql(expression, "size") + if size and self.TABLESAMPLE_SIZE_IS_PERCENT: + size = f"{size} PERCENT" seed = self.sql(expression, "seed") seed = f" {seed_prefix} ({seed})" if seed else "" kind = expression.args.get("kind", "TABLESAMPLE") @@ -1154,6 +1226,7 @@ class Generator: "NATURAL" if expression.args.get("natural") else None, expression.side, expression.kind, + expression.hint if self.JOIN_HINTS else None, "JOIN", ) if op @@ -1311,16 +1384,20 @@ class Generator: def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str: partition = self.partition_by_sql(expression) order = self.sql(expression, "order") - measures = self.sql(expression, "measures") - measures = self.seg(f"MEASURES {measures}") if measures else "" + measures = self.expressions(expression, key="measures") + measures = self.seg(f"MEASURES{self.seg(measures)}") if measures else "" rows = self.sql(expression, "rows") rows = self.seg(rows) if rows else "" after = self.sql(expression, "after") after = self.seg(after) if after else "" pattern = self.sql(expression, "pattern") pattern = self.seg(f"PATTERN ({pattern})") if pattern else "" - define = self.sql(expression, "define") - define = self.seg(f"DEFINE {define}") if define else "" + definition_sqls = [ + f"{self.sql(definition, 'alias')} AS {self.sql(definition, 'this')}" + for definition in expression.args.get("define", []) + ] + definitions = self.expressions(sqls=definition_sqls) + define = self.seg(f"DEFINE{self.seg(definitions)}") if definitions else "" body = "".join( ( partition, @@ -1332,7 +1409,9 @@ class Generator: define, ) ) - return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}" + alias = self.sql(expression, "alias") + alias = f" {alias}" if alias else "" + return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}" def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: limit = expression.args.get("limit") @@ -1353,7 +1432,7 @@ class Generator: self.sql(expression, "group"), self.sql(expression, "having"), self.sql(expression, "qualify"), - self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True) + self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) if expression.args.get("windows") else "", self.sql(expression, "distribute"), @@ -1471,15 +1550,21 @@ class Generator: partition_sql = partition + " " if partition and order else partition spec = expression.args.get("spec") - spec_sql = " " + self.window_spec_sql(spec) if spec else "" + spec_sql = " " + self.windowspec_sql(spec) if spec else "" alias = self.sql(expression, "alias") - this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}" + over = self.sql(expression, "over") or "OVER" + this = f"{this} {'AS' if expression.arg_key == 'windows' else over}" + + first = expression.args.get("first") + if first is not None: + first = " FIRST " if first else " LAST " + first = first or "" if not partition and not order and not spec and alias: return f"{this} {alias}" - window_args = alias + partition_sql + order_sql + spec_sql + window_args = alias + first + partition_sql + order_sql + spec_sql return f"{this} ({window_args.strip()})" @@ -1487,7 +1572,7 @@ class Generator: partition = self.expressions(expression, key="partition_by", flat=True) return f"PARTITION BY {partition}" if partition else "" - def window_spec_sql(self, expression: exp.WindowSpec) -> str: + def windowspec_sql(self, expression: exp.WindowSpec) -> str: kind = self.sql(expression, "kind") start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") end = ( @@ -1508,7 +1593,7 @@ class Generator: return f"{this} BETWEEN {low} AND {high}" def bracket_sql(self, expression: exp.Bracket) -> str: - expressions = apply_index_offset(expression.expressions, self.index_offset) + expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset) expressions_sql = ", ".join(self.sql(e) for e in expressions) return f"{self.sql(expression, 'this')}[{expressions_sql}]" @@ -1550,6 +1635,11 @@ class Generator: expressions = self.expressions(expression, flat=True) return f"CONSTRAINT {this} {expressions}" + def nextvaluefor_sql(self, expression: exp.NextValueFor) -> str: + order = expression.args.get("order") + order = f" OVER ({self.order_sql(order, flat=True)})" if order else "" + return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}" + def extract_sql(self, expression: exp.Extract) -> str: this = self.sql(expression, "this") expression_sql = self.sql(expression, "expression") @@ -1586,7 +1676,7 @@ class Generator: def primarykey_sql(self, expression: exp.ForeignKey) -> str: expressions = self.expressions(expression, flat=True) - options = self.expressions(expression, "options", flat=True, sep=" ") + options = self.expressions(expression, key="options", flat=True, sep=" ") options = f" {options}" if options else "" return f"PRIMARY KEY ({expressions}){options}" @@ -1644,17 +1734,20 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression: exp.Interval) -> str: - this = expression.args.get("this") - if this: - this = ( - f" {this}" - if isinstance(this, exp.Literal) or isinstance(this, exp.Paren) - else f" ({this})" - ) - else: - this = "" unit = self.sql(expression, "unit") + if not self.INTERVAL_ALLOWS_PLURAL_FORM: + unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit) unit = f" {unit}" if unit else "" + + if self.SINGLE_STRING_INTERVAL: + this = expression.this.name if expression.this else "" + return f"INTERVAL '{this}{unit}'" + + this = self.sql(expression, "this") + if this: + unwrapped = isinstance(expression.this, self.UNWRAPPED_INTERVAL_VALUES) + this = f" {this}" if unwrapped else f" ({this})" + return f"INTERVAL{this}{unit}" def return_sql(self, expression: exp.Return) -> str: @@ -1664,7 +1757,7 @@ class Generator: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) expressions = f"({expressions})" if expressions else "" - options = self.expressions(expression, "options", flat=True, sep=" ") + options = self.expressions(expression, key="options", flat=True, sep=" ") options = f" {options}" if options else "" return f"REFERENCES {this}{expressions}{options}" @@ -1690,9 +1783,9 @@ class Generator: return f"NOT {self.sql(expression, 'this')}" def alias_sql(self, expression: exp.Alias) -> str: - to_sql = self.sql(expression, "alias") - to_sql = f" AS {to_sql}" if to_sql else "" - return f"{self.sql(expression, 'this')}{to_sql}" + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" + return f"{self.sql(expression, 'this')}{alias}" def aliases_sql(self, expression: exp.Aliases) -> str: return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" @@ -1712,7 +1805,11 @@ class Generator: if not self.pretty: return self.binary(expression, op) - sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False)) + sqls = tuple( + self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e) + for i, e in enumerate(expression.flatten(unnest=False)) + ) + sep = "\n" if self.text_width(sqls) > self._max_text_width else " " return f"{sep}{op} ".join(sqls) @@ -1797,13 +1894,13 @@ class Generator: actions = expression.args["actions"] if isinstance(actions[0], exp.ColumnDef): - actions = self.expressions(expression, "actions", prefix="ADD COLUMN ") + actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ") elif isinstance(actions[0], exp.Schema): - actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") + actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ") elif isinstance(actions[0], exp.Delete): - actions = self.expressions(expression, "actions", flat=True) + actions = self.expressions(expression, key="actions", flat=True) else: - actions = self.expressions(expression, "actions") + actions = self.expressions(expression, key="actions") exists = " IF EXISTS" if expression.args.get("exists") else "" return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" @@ -1935,6 +2032,7 @@ class Generator: return f"USE{kind}{this}" def binary(self, expression: exp.Binary, op: str) -> str: + op = self.maybe_comment(op, comments=expression.comments) return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" def function_fallback_sql(self, expression: exp.Func) -> str: @@ -1965,14 +2063,15 @@ class Generator: def expressions( self, - expression: exp.Expression, + expression: t.Optional[exp.Expression] = None, key: t.Optional[str] = None, + sqls: t.Optional[t.List[str]] = None, flat: bool = False, indent: bool = True, sep: str = ", ", prefix: str = "", ) -> str: - expressions = expression.args.get(key or "expressions") + expressions = expression.args.get(key or "expressions") if expression else sqls if not expressions: return "" |