diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 124 |
1 files changed, 95 insertions, 29 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index b7e295d..bb7fd71 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -2,7 +2,7 @@ import logging from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors -from sqlglot.helper import apply_index_offset, csv, ensure_list +from sqlglot.helper import apply_index_offset, csv from sqlglot.time import format_time from sqlglot.tokens import TokenType @@ -43,14 +43,18 @@ class Generator: Default: 3 leading_comma (bool): if the the comma is leading or trailing in select statements Default: False + max_text_width: The max number of characters in a segment before creating new lines in pretty mode. + The default is on the smaller end because the length only represents a segment and not the true + line length. + Default: 80 """ TRANSFORMS = { exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", - exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", - exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", - exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})", + exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", + exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", + exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", + exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e), @@ -111,6 +115,7 @@ class Generator: "_replace_backslash", "_escaped_quote_end", "_leading_comma", + "_max_text_width", ) def __init__( @@ -135,6 +140,7 @@ class Generator: null_ordering=None, max_unsupported=3, leading_comma=False, + max_text_width=80, ): import sqlglot @@ -162,6 +168,7 @@ class Generator: self._replace_backslash = self.escape == "\\" self._escaped_quote_end = self.escape + self.quote_end self._leading_comma = leading_comma + self._max_text_width = max_text_width def generate(self, expression): """ @@ -268,7 +275,7 @@ class Generator: raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") def annotation_sql(self, expression): - return self.sql(expression, "expression") + return f"{self.sql(expression, 'expression')} # {expression.name.strip()}" def uncache_sql(self, expression): table = self.sql(expression, "this") @@ -364,6 +371,9 @@ class Generator: ) return self.prepend_ctes(expression, expression_sql) + def describe_sql(self, expression): + return f"DESCRIBE {self.sql(expression, 'this')}" + def prepend_ctes(self, expression, sql): with_ = self.sql(expression, "with") if with_: @@ -405,6 +415,12 @@ class Generator: ) return f"{type_sql}{nested}" + def directory_sql(self, expression): + local = "LOCAL " if expression.args.get("local") else "" + row_format = self.sql(expression, "row_format") + row_format = f" {row_format}" if row_format else "" + return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}" + def delete_sql(self, expression): this = self.sql(expression, "this") where_sql = self.sql(expression, "where") @@ -513,13 +529,19 @@ class Generator: return f"{key}={value}" def insert_sql(self, expression): - kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO" - this = self.sql(expression, "this") + overwrite = expression.args.get("overwrite") + + if isinstance(expression.this, exp.Directory): + this = "OVERWRITE " if overwrite else "INTO " + else: + this = "OVERWRITE TABLE " if overwrite else "INTO " + + this = f"{this}{self.sql(expression, 'this')}" exists = " IF EXISTS " if expression.args.get("exists") else " " partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else "" expression_sql = self.sql(expression, "expression") sep = self.sep() if partition_sql else "" - sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}" + sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression): @@ -534,6 +556,21 @@ class Generator: def introducer_sql(self, expression): return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def rowformat_sql(self, expression): + fields = expression.args.get("fields") + fields = f" FIELDS TERMINATED BY {fields}" if fields else "" + escaped = expression.args.get("escaped") + escaped = f" ESCAPED BY {escaped}" if escaped else "" + items = expression.args.get("collection_items") + items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else "" + keys = expression.args.get("map_keys") + keys = f" MAP KEYS TERMINATED BY {keys}" if keys else "" + lines = expression.args.get("lines") + lines = f" LINES TERMINATED BY {lines}" if lines else "" + null = expression.args.get("null") + null = f" NULL DEFINED AS {null}" if null else "" + return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}" + def table_sql(self, expression): table = ".".join( part @@ -688,6 +725,19 @@ class Generator: return f"{self.quote_start}{text}{self.quote_end}" return text + def loaddata_sql(self, expression): + local = " LOCAL" if expression.args.get("local") else "" + inpath = f" INPATH {self.sql(expression, 'inpath')}" + overwrite = " OVERWRITE" if expression.args.get("overwrite") else "" + this = f" INTO TABLE {self.sql(expression, 'this')}" + partition = self.sql(expression, "partition") + partition = f" {partition}" if partition else "" + input_format = self.sql(expression, "input_format") + input_format = f" INPUTFORMAT {input_format}" if input_format else "" + serde = self.sql(expression, "serde") + serde = f" SERDE {serde}" if serde else "" + return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}" + def null_sql(self, *_): return "NULL" @@ -885,20 +935,24 @@ class Generator: return f"EXISTS{self.wrap(expression)}" def case_sql(self, expression): - this = self.indent(self.sql(expression, "this"), skip_first=True) - this = f" {this}" if this else "" - ifs = [] + this = self.sql(expression, "this") + statements = [f"CASE {this}" if this else "CASE"] for e in expression.args["ifs"]: - ifs.append(self.indent(f"WHEN {self.sql(e, 'this')}")) - ifs.append(self.indent(f"THEN {self.sql(e, 'true')}")) + statements.append(f"WHEN {self.sql(e, 'this')}") + statements.append(f"THEN {self.sql(e, 'true')}") + + default = self.sql(expression, "default") + + if default: + statements.append(f"ELSE {default}") - if expression.args.get("default") is not None: - ifs.append(self.indent(f"ELSE {self.sql(expression, 'default')}")) + statements.append("END") - ifs = "".join(self.seg(self.indent(e, skip_first=True)) for e in ifs) - statement = f"CASE{this}{ifs}{self.seg('END')}" - return statement + if self.pretty and self.text_width(statements) > self._max_text_width: + return self.indent("\n".join(statements), skip_first=True, skip_last=True) + + return " ".join(statements) def constraint_sql(self, expression): this = self.sql(expression, "this") @@ -970,7 +1024,7 @@ class Generator: return f"REFERENCES {this}({expressions})" def anonymous_sql(self, expression): - args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True) + args = self.format_args(*expression.expressions) return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" def paren_sql(self, expression): @@ -1008,7 +1062,9 @@ class Generator: if not self.pretty: return self.binary(expression, op) - return f"\n{op} ".join(self.sql(e) for e in expression.flatten(unnest=False)) + sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False)) + sep = "\n" if self.text_width(sqls) > self._max_text_width else " " + return f"{sep}{op} ".join(sqls) def bitwiseand_sql(self, expression): return self.binary(expression, "&") @@ -1039,7 +1095,7 @@ class Generator: return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" def distinct_sql(self, expression): - this = self.sql(expression, "this") + this = self.expressions(expression, flat=True) this = f" {this}" if this else "" on = self.sql(expression, "on") @@ -1128,13 +1184,23 @@ class Generator: def function_fallback_sql(self, expression): args = [] - for arg_key in expression.arg_types: - arg_value = ensure_list(expression.args.get(arg_key) or []) - for a in arg_value: - args.append(self.sql(a)) - - args_str = self.indent(", ".join(args), skip_first=True, skip_last=True) - return f"{self.normalize_func(expression.sql_name())}({args_str})" + for arg_value in expression.args.values(): + if isinstance(arg_value, list): + for value in arg_value: + args.append(value) + elif arg_value: + args.append(arg_value) + + return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})" + + def format_args(self, *args): + args = tuple(self.sql(arg) for arg in args if arg is not None) + if self.pretty and self.text_width(args) > self._max_text_width: + return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True) + return ", ".join(args) + + def text_width(self, args): + return sum(len(arg) for arg in args) def format_time(self, expression): return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) |