diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 214 |
1 files changed, 147 insertions, 67 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index ca14425..11d9073 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,4 +1,8 @@ +from __future__ import annotations + import logging +import re +import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors @@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") +NEWLINE_RE = re.compile("\r\n?|\n") + class Generator: """ @@ -47,8 +53,7 @@ class Generator: The default is on the smaller end because the length only represents a segment and not the true line length. Default: 80 - annotations: Whether or not to show annotations in the SQL when `pretty` is True. - Annotations can only be shown in pretty mode otherwise they may clobber resulting sql. + comments: Whether or not to preserve comments in the ouput SQL code. Default: True """ @@ -65,14 +70,16 @@ class Generator: exp.VolatilityProperty: lambda self, e: self.sql(e.name), } - # whether 'CREATE ... TRANSIENT ... TABLE' is allowed - # can override in dialects + # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed CREATE_TRANSIENT = False - # whether or not null ordering is supported in order by + + # Whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True - # always do union distinct or union all + + # Always do union distinct or union all EXPLICIT_UNION = False - # wrap derived values in parens, usually standard but spark doesn't support it + + # Wrap derived values in parens, usually standard but spark doesn't support it WRAP_DERIVED_VALUES = True TYPE_MAPPING = { @@ -80,7 +87,7 @@ class Generator: exp.DataType.Type.NVARCHAR: "VARCHAR", } - TOKEN_MAPPING = {} + TOKEN_MAPPING: t.Dict[TokenType, str] = {} STRUCT_DELIMITER = ("<", ">") @@ -96,6 +103,8 @@ class Generator: exp.TableFormatProperty, } + WITH_SEPARATED_COMMENTS = (exp.Select,) + __slots__ = ( "time_mapping", "time_trie", @@ -122,7 +131,7 @@ class Generator: "_escaped_quote_end", "_leading_comma", "_max_text_width", - "_annotations", + "_comments", ) def __init__( @@ -148,7 +157,7 @@ class Generator: max_unsupported=3, leading_comma=False, max_text_width=80, - annotations=True, + comments=True, ): import sqlglot @@ -177,7 +186,7 @@ class Generator: self._escaped_quote_end = self.escape + self.quote_end self._leading_comma = leading_comma self._max_text_width = max_text_width - self._annotations = annotations + self._comments = comments def generate(self, expression): """ @@ -204,7 +213,6 @@ class Generator: return sql def unsupported(self, message): - if self.unsupported_level == ErrorLevel.IMMEDIATE: raise UnsupportedError(message) self.unsupported_messages.append(message) @@ -215,9 +223,31 @@ class Generator: def seg(self, sql, sep=" "): return f"{self.sep(sep)}{sql}" + def maybe_comment(self, sql, expression, single_line=False): + comment = expression.comment if self._comments else None + + if not comment: + return sql + + comment = " " + comment if comment[0].strip() else comment + comment = comment + " " if comment[-1].strip() else comment + + if isinstance(expression, self.WITH_SEPARATED_COMMENTS): + return f"/*{comment}*/{self.sep()}{sql}" + + if not self.pretty: + return f"{sql} /*{comment}*/" + + if not NEWLINE_RE.search(comment): + return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/" + + return f"/*{comment}*/\n{sql}" + def wrap(self, expression): this_sql = self.indent( - self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"), + self.sql(expression) + if isinstance(expression, (exp.Select, exp.Union)) + else self.sql(expression, "this"), level=1, pad=0, ) @@ -251,7 +281,7 @@ class Generator: for i, line in enumerate(lines) ) - def sql(self, expression, key=None): + def sql(self, expression, key=None, comment=True): if not expression: return "" @@ -264,29 +294,24 @@ class Generator: transform = self.TRANSFORMS.get(expression.__class__) if callable(transform): - return transform(self, expression) - if transform: - return transform - - if not isinstance(expression, exp.Expression): + sql = transform(self, expression) + elif transform: + sql = transform + elif isinstance(expression, exp.Expression): + exp_handler_name = f"{expression.key}_sql" + + if hasattr(self, exp_handler_name): + sql = getattr(self, exp_handler_name)(expression) + elif isinstance(expression, exp.Func): + sql = self.function_fallback_sql(expression) + elif isinstance(expression, exp.Property): + sql = self.property_sql(expression) + else: + raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") + else: raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") - exp_handler_name = f"{expression.key}_sql" - if hasattr(self, exp_handler_name): - return getattr(self, exp_handler_name)(expression) - - if isinstance(expression, exp.Func): - return self.function_fallback_sql(expression) - - if isinstance(expression, exp.Property): - return self.property_sql(expression) - - raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") - - def annotation_sql(self, expression): - if self._annotations and self.pretty: - return f"{self.sql(expression, 'expression')} # {expression.name}" - return self.sql(expression, "expression") + return self.maybe_comment(sql, expression) if self._comments and comment else sql def uncache_sql(self, expression): table = self.sql(expression, "this") @@ -371,7 +396,9 @@ class Generator: expression_sql = self.sql(expression, "expression") expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" temporary = " TEMPORARY" if expression.args.get("temporary") else "" - transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" + transient = ( + " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" + ) replace = " OR REPLACE" if expression.args.get("replace") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" @@ -434,7 +461,9 @@ class Generator: def delete_sql(self, expression): this = self.sql(expression, "this") using_sql = ( - f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else "" + f" USING {self.expressions(expression, 'using', sep=', USING ')}" + if expression.args.get("using") + else "" ) where_sql = self.sql(expression, "where") sql = f"DELETE FROM {this}{using_sql}{where_sql}" @@ -481,15 +510,18 @@ class Generator: return f"{this} ON {table} {columns}" def identifier_sql(self, expression): - value = expression.name - value = value.lower() if self.normalize else value + text = expression.name + text = text.lower() if self.normalize else text if expression.args.get("quoted") or self.identify: - return f"{self.identifier_start}{value}{self.identifier_end}" - return value + text = f"{self.identifier_start}{text}{self.identifier_end}" + return text def partition_sql(self, expression): keys = csv( - *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")] + *[ + f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name + for prop in expression.this + ] ) return f"PARTITION({keys})" @@ -504,9 +536,9 @@ class Generator: elif p_class in self.ROOT_PROPERTIES: root_properties.append(p) - return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties( - exp.Properties(expressions=with_properties) - ) + return self.root_properties( + exp.Properties(expressions=root_properties) + ) + self.with_properties(exp.Properties(expressions=with_properties)) def root_properties(self, properties): if properties.expressions: @@ -551,7 +583,9 @@ class Generator: this = f"{this}{self.sql(expression, 'this')}" exists = " IF EXISTS " if expression.args.get("exists") else " " - partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else "" + partition_sql = ( + self.sql(expression, "partition") if expression.args.get("partition") else "" + ) expression_sql = self.sql(expression, "expression") sep = self.sep() if partition_sql else "" sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" @@ -669,7 +703,9 @@ class Generator: def group_sql(self, expression): group_by = self.op_expressions("GROUP BY", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) - grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" + grouping_sets = ( + f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" + ) cube = self.expressions(expression, key="cube", indent=False) cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" rollup = self.expressions(expression, key="rollup", indent=False) @@ -711,10 +747,10 @@ class Generator: this_sql = self.sql(expression, "this") return f"{expression_sql}{op_sql} {this_sql}{on_sql}" - def lambda_sql(self, expression): + def lambda_sql(self, expression, arrow_sep="->"): args = self.expressions(expression, flat=True) args = f"({args})" if len(args.split(",")) > 1 else args - return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}") + return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}") def lateral_sql(self, expression): this = self.sql(expression, "this") @@ -748,7 +784,7 @@ class Generator: if self._replace_backslash: text = text.replace("\\", "\\\\") text = text.replace(self.quote_end, self._escaped_quote_end) - return f"{self.quote_start}{text}{self.quote_end}" + text = f"{self.quote_start}{text}{self.quote_end}" return text def loaddata_sql(self, expression): @@ -796,13 +832,21 @@ class Generator: sort_order = " DESC" if desc else "" nulls_sort_change = "" - if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last): + if nulls_first and ( + (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last + ): nulls_sort_change = " NULLS FIRST" - elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last: + elif ( + nulls_last + and ((asc and nulls_are_small) or (desc and nulls_are_large)) + and not nulls_are_last + ): nulls_sort_change = " NULLS LAST" if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: - self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect") + self.unsupported( + "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect" + ) nulls_sort_change = "" return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" @@ -835,7 +879,7 @@ class Generator: sql = self.query_modifiers( expression, f"SELECT{hint}{distinct}{expressions}", - self.sql(expression, "from"), + self.sql(expression, "from", comment=False), ) return self.prepend_ctes(expression, sql) @@ -858,6 +902,13 @@ class Generator: def parameter_sql(self, expression): return f"@{self.sql(expression, 'this')}" + def sessionparameter_sql(self, expression): + this = self.sql(expression, "this") + kind = expression.text("kind") + if kind: + kind = f"{kind}." + return f"@@{kind}{this}" + def placeholder_sql(self, expression): return f":{expression.name}" if expression.name else "?" @@ -931,7 +982,10 @@ class Generator: def window_spec_sql(self, expression): kind = self.sql(expression, "kind") start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") - end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW" + end = ( + csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") + or "CURRENT ROW" + ) return f"{kind} BETWEEN {start} AND {end}" def withingroup_sql(self, expression): @@ -1020,7 +1074,9 @@ class Generator: return f"UNIQUE ({columns})" def if_sql(self, expression): - return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))) + return self.case_sql( + exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) + ) def in_sql(self, expression): query = expression.args.get("query") @@ -1196,6 +1252,12 @@ class Generator: def neq_sql(self, expression): return self.binary(expression, "<>") + def nullsafeeq_sql(self, expression): + return self.binary(expression, "IS NOT DISTINCT FROM") + + def nullsafeneq_sql(self, expression): + return self.binary(expression, "IS DISTINCT FROM") + def or_sql(self, expression): return self.connector_sql(expression, "OR") @@ -1205,6 +1267,9 @@ class Generator: def trycast_sql(self, expression): return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" + def use_sql(self, expression): + return f"USE {self.sql(expression, 'this')}" + def binary(self, expression, op): return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" @@ -1240,17 +1305,27 @@ class Generator: if flat: return sep.join(self.sql(e) for e in expressions) - sql = (self.sql(e) for e in expressions) - # the only time leading_comma changes the output is if pretty print is enabled - if self._leading_comma and self.pretty: - pad = " " * self.pad - expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql)) - else: - expressions = self.sep(sep).join(sql) + num_sqls = len(expressions) + + # These are calculated once in case we have the leading_comma / pretty option set, correspondingly + pad = " " * self.pad + stripped_sep = sep.strip() - if indent: - return self.indent(expressions, skip_first=False) - return expressions + result_sqls = [] + for i, e in enumerate(expressions): + sql = self.sql(e, comment=False) + comment = self.maybe_comment("", e, single_line=True) + + if self.pretty: + if self._leading_comma: + result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}") + else: + result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}") + else: + result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}") + + result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) + return self.indent(result_sqls, skip_first=False) if indent else result_sqls def op_expressions(self, op, expression, flat=False): expressions_sql = self.expressions(expression, flat=flat) @@ -1264,7 +1339,9 @@ class Generator: def set_operation(self, expression, op): this = self.sql(expression, "this") op = self.seg(op) - return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}") + return self.query_modifiers( + expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" + ) def token_sql(self, token_type): return self.TOKEN_MAPPING.get(token_type, token_type.name) @@ -1283,3 +1360,6 @@ class Generator: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"{this}({expressions})" + + def kwarg_sql(self, expression): + return self.binary(expression, "=>") |