From 90150543f9314be683d22a16339effd774192f6d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 22 Sep 2022 06:31:28 +0200 Subject: Merging upstream version 6.1.1. Signed-off-by: Daniel Baumann --- sqlglot/generator.py | 167 +++++++++++++++++++++++++++------------------------ 1 file changed, 89 insertions(+), 78 deletions(-) (limited to 'sqlglot/generator.py') diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 793cff0..a445178 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -41,6 +41,8 @@ class Generator: max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError. This is only relevant if unsupported_level is ErrorLevel.RAISE. Default: 3 + leading_comma (bool): if the the comma is leading or trailing in select statements + Default: False """ TRANSFORMS = { @@ -108,6 +110,7 @@ class Generator: "_indent", "_replace_backslash", "_escaped_quote_end", + "_leading_comma", ) def __init__( @@ -131,6 +134,7 @@ class Generator: unsupported_level=ErrorLevel.WARN, null_ordering=None, max_unsupported=3, + leading_comma=False, ): import sqlglot @@ -157,6 +161,7 @@ class Generator: self._indent = indent self._replace_backslash = self.escape == "\\" self._escaped_quote_end = self.escape + self.quote_end + self._leading_comma = leading_comma def generate(self, expression): """ @@ -178,9 +183,7 @@ class Generator: for msg in self.unsupported_messages: logger.warning(msg) elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: - raise UnsupportedError( - concat_errors(self.unsupported_messages, self.max_unsupported) - ) + raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported)) return sql @@ -197,9 +200,7 @@ class Generator: def wrap(self, expression): this_sql = self.indent( - self.sql(expression) - if isinstance(expression, (exp.Select, exp.Union)) - else self.sql(expression, "this"), + self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"), level=1, pad=0, ) @@ -251,9 +252,7 @@ class Generator: return transform if not isinstance(expression, exp.Expression): - raise ValueError( - f"Expected an Expression. Received {type(expression)}: {expression}" - ) + raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") exp_handler_name = f"{expression.key}_sql" if hasattr(self, exp_handler_name): @@ -276,11 +275,7 @@ class Generator: lazy = " LAZY" if expression.args.get("lazy") else "" table = self.sql(expression, "this") options = expression.args.get("options") - options = ( - f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" - if options - else "" - ) + options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else "" sql = self.sql(expression, "expression") sql = f" AS{self.sep()}{sql}" if sql else "" sql = f"CACHE{lazy} TABLE {table}{options}{sql}" @@ -306,9 +301,7 @@ class Generator: def columndef_sql(self, expression): column = self.sql(expression, "this") kind = self.sql(expression, "kind") - constraints = self.expressions( - expression, key="constraints", sep=" ", flat=True - ) + constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) if not constraints: return f"{column} {kind}" @@ -338,6 +331,9 @@ class Generator: default = self.sql(expression, "this") return f"DEFAULT {default}" + def generatedasidentitycolumnconstraint_sql(self, expression): + return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY" + def notnullcolumnconstraint_sql(self, _): return "NOT NULL" @@ -384,7 +380,10 @@ class Generator: return f"{alias}{columns}" def bitstring_sql(self, expression): - return f"b'{self.sql(expression, 'this')}'" + return self.sql(expression, "this") + + def hexstring_sql(self, expression): + return self.sql(expression, "this") def datatype_sql(self, expression): type_value = expression.this @@ -452,10 +451,7 @@ class Generator: def partition_sql(self, expression): keys = csv( - *[ - f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] - for k, v in expression.args.get("this") - ] + *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")] ) return f"PARTITION({keys})" @@ -470,9 +466,9 @@ class Generator: elif p_class in self.WITH_PROPERTIES: with_properties.append(p) - return self.root_properties( - exp.Properties(expressions=root_properties) - ) + self.with_properties(exp.Properties(expressions=with_properties)) + return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties( + exp.Properties(expressions=with_properties) + ) def root_properties(self, properties): if properties.expressions: @@ -508,11 +504,7 @@ class Generator: kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO" this = self.sql(expression, "this") exists = " IF EXISTS " if expression.args.get("exists") else " " - partition_sql = ( - self.sql(expression, "partition") - if expression.args.get("partition") - else "" - ) + partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else "" expression_sql = self.sql(expression, "expression") sep = self.sep() if partition_sql else "" sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}" @@ -531,7 +523,7 @@ class Generator: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" def table_sql(self, expression): - return ".".join( + table = ".".join( part for part in [ self.sql(expression, "catalog"), @@ -541,6 +533,10 @@ class Generator: if part ) + laterals = self.expressions(expression, key="laterals", sep="") + joins = self.expressions(expression, key="joins", sep="") + return f"{table}{laterals}{joins}" + def tablesample_sql(self, expression): if self.alias_post_tablesample and isinstance(expression.this, exp.Alias): this = self.sql(expression.this, "this") @@ -586,11 +582,7 @@ class Generator: def group_sql(self, expression): group_by = self.op_expressions("GROUP BY", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) - grouping_sets = ( - f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" - if grouping_sets - else "" - ) + grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" cube = self.expressions(expression, key="cube", indent=False) cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" rollup = self.expressions(expression, key="rollup", indent=False) @@ -603,7 +595,16 @@ class Generator: def join_sql(self, expression): op_sql = self.seg( - " ".join(op for op in (expression.side, expression.kind, "JOIN") if op) + " ".join( + op + for op in ( + "NATURAL" if expression.args.get("natural") else None, + expression.side, + expression.kind, + "JOIN", + ) + if op + ) ) on_sql = self.sql(expression, "on") using = expression.args.get("using") @@ -630,9 +631,9 @@ class Generator: def lateral_sql(self, expression): this = self.sql(expression, "this") - op_sql = self.seg( - f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}" - ) + if isinstance(expression.this, exp.Subquery): + return f"LATERAL{self.sep()}{this}" + op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") alias = expression.args["alias"] table = alias.name table = f" {table}" if table else table @@ -688,21 +689,13 @@ class Generator: sort_order = " DESC" if desc else "" nulls_sort_change = "" - if nulls_first and ( - (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last - ): + if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last): nulls_sort_change = " NULLS FIRST" - elif ( - nulls_last - and ((asc and nulls_are_small) or (desc and nulls_are_large)) - and not nulls_are_last - ): + elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last: nulls_sort_change = " NULLS LAST" if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: - self.unsupported( - "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect" - ) + self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect") nulls_sort_change = "" return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" @@ -798,14 +791,20 @@ class Generator: def window_sql(self, expression): this = self.sql(expression, "this") + partition = self.expressions(expression, key="partition_by", flat=True) partition = f"PARTITION BY {partition}" if partition else "" + order = expression.args.get("order") order_sql = self.order_sql(order, flat=True) if order else "" + partition_sql = partition + " " if partition and order else partition + spec = expression.args.get("spec") spec_sql = " " + self.window_spec_sql(spec) if spec else "" + alias = self.sql(expression, "alias") + if expression.arg_key == "window": this = this = f"{self.seg('WINDOW')} {this} AS" else: @@ -818,13 +817,8 @@ class Generator: def window_spec_sql(self, expression): kind = self.sql(expression, "kind") - start = csv( - self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" " - ) - end = ( - csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") - or "CURRENT ROW" - ) + start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") + end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW" return f"{kind} BETWEEN {start} AND {end}" def withingroup_sql(self, expression): @@ -879,6 +873,17 @@ class Generator: expression_sql = self.sql(expression, "expression") return f"EXTRACT({this} FROM {expression_sql})" + def trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + + if trim_type == "LEADING": + return f"LTRIM({target})" + elif trim_type == "TRAILING": + return f"RTRIM({target})" + else: + return f"TRIM({target})" + def check_sql(self, expression): this = self.sql(expression, key="this") return f"CHECK ({this})" @@ -898,9 +903,7 @@ class Generator: return f"UNIQUE ({columns})" def if_sql(self, expression): - return self.case_sql( - exp.Case(ifs=[expression], default=expression.args.get("false")) - ) + return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) def in_sql(self, expression): query = expression.args.get("query") @@ -917,7 +920,9 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression): - return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}" + unit = self.sql(expression, "unit") + unit = f" {unit}" if unit else "" + return f"INTERVAL {self.sql(expression, 'this')}{unit}" def reference_sql(self, expression): this = self.sql(expression, "this") @@ -925,9 +930,7 @@ class Generator: return f"REFERENCES {this}({expressions})" def anonymous_sql(self, expression): - args = self.indent( - self.expressions(expression, flat=True), skip_first=True, skip_last=True - ) + args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True) return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" def paren_sql(self, expression): @@ -1006,6 +1009,9 @@ class Generator: def ignorenulls_sql(self, expression): return f"{self.sql(expression, 'this')} IGNORE NULLS" + def respectnulls_sql(self, expression): + return f"{self.sql(expression, 'this')} RESPECT NULLS" + def intdiv_sql(self, expression): return self.sql( exp.Cast( @@ -1023,6 +1029,9 @@ class Generator: def div_sql(self, expression): return self.binary(expression, "/") + def distance_sql(self, expression): + return self.binary(expression, "<->") + def dot_sql(self, expression): return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" @@ -1047,6 +1056,9 @@ class Generator: def like_sql(self, expression): return self.binary(expression, "LIKE") + def similarto_sql(self, expression): + return self.binary(expression, "SIMILAR TO") + def lt_sql(self, expression): return self.binary(expression, "<") @@ -1069,14 +1081,10 @@ class Generator: return self.binary(expression, "-") def trycast_sql(self, expression): - return ( - f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" - ) + return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" def binary(self, expression, op): - return ( - f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" - ) + return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" def function_fallback_sql(self, expression): args = [] @@ -1089,9 +1097,7 @@ class Generator: return f"{self.normalize_func(expression.sql_name())}({args_str})" def format_time(self, expression): - return format_time( - self.sql(expression, "format"), self.time_mapping, self.time_trie - ) + return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): expressions = expression.args.get(key or "expressions") @@ -1102,7 +1108,14 @@ class Generator: if flat: return sep.join(self.sql(e) for e in expressions) - expressions = self.sep(sep).join(self.sql(e) for e in expressions) + sql = (self.sql(e) for e in expressions) + # the only time leading_comma changes the output is if pretty print is enabled + if self._leading_comma and self.pretty: + pad = " " * self.pad + expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql)) + else: + expressions = self.sep(sep).join(sql) + if indent: return self.indent(expressions, skip_first=False) return expressions @@ -1116,9 +1129,7 @@ class Generator: def set_operation(self, expression, op): this = self.sql(expression, "this") op = self.seg(op) - return self.query_modifiers( - expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" - ) + return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}") def token_sql(self, token_type): return self.TOKEN_MAPPING.get(token_type, token_type.name) -- cgit v1.2.3