summary refs log tree commit diff stats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- sqlglot/generator.py 124
1 files changed, 95 insertions, 29 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index b7e295d..bb7fd71 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -2,7 +2,7 @@ import logging
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
-from sqlglot.helper import apply_index_offset, csv, ensure_list
+from sqlglot.helper import apply_index_offset, csv
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
@@ -43,14 +43,18 @@ class Generator:
Default: 3
leading_comma (bool): if the the comma is leading or trailing in select statements
Default: False
+ max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
+ The default is on the smaller end because the length only represents a segment and not the true
+ line length.
+ Default: 80
"""
TRANSFORMS = {
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
- exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
- exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
- exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})",
+ exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
+ exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
+ exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
+ exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@@ -111,6 +115,7 @@ class Generator:
"_replace_backslash",
"_escaped_quote_end",
"_leading_comma",
+ "_max_text_width",
)
def __init__(
@@ -135,6 +140,7 @@ class Generator:
null_ordering=None,
max_unsupported=3,
leading_comma=False,
+ max_text_width=80,
):
import sqlglot
@@ -162,6 +168,7 @@ class Generator:
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
self._leading_comma = leading_comma
+ self._max_text_width = max_text_width
def generate(self, expression):
"""
@@ -268,7 +275,7 @@ class Generator:
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
def annotation_sql(self, expression):
- return self.sql(expression, "expression")
+ return f"{self.sql(expression, 'expression')} # {expression.name.strip()}"
def uncache_sql(self, expression):
table = self.sql(expression, "this")
@@ -364,6 +371,9 @@ class Generator:
)
return self.prepend_ctes(expression, expression_sql)
+ def describe_sql(self, expression):
+ return f"DESCRIBE {self.sql(expression, 'this')}"
+
def prepend_ctes(self, expression, sql):
with_ = self.sql(expression, "with")
if with_:
@@ -405,6 +415,12 @@ class Generator:
)
return f"{type_sql}{nested}"
+ def directory_sql(self, expression):
+ local = "LOCAL " if expression.args.get("local") else ""
+ row_format = self.sql(expression, "row_format")
+ row_format = f" {row_format}" if row_format else ""
+ return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}"
+
def delete_sql(self, expression):
this = self.sql(expression, "this")
where_sql = self.sql(expression, "where")
@@ -513,13 +529,19 @@ class Generator:
return f"{key}={value}"
def insert_sql(self, expression):
- kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
- this = self.sql(expression, "this")
+ overwrite = expression.args.get("overwrite")
+
+ if isinstance(expression.this, exp.Directory):
+ this = "OVERWRITE " if overwrite else "INTO "
+ else:
+ this = "OVERWRITE TABLE " if overwrite else "INTO "
+
+ this = f"{this}{self.sql(expression, 'this')}"
exists = " IF EXISTS " if expression.args.get("exists") else " "
partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
- sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
+ sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression):
@@ -534,6 +556,21 @@ class Generator:
def introducer_sql(self, expression):
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
+ def rowformat_sql(self, expression):
+ fields = expression.args.get("fields")
+ fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
+ escaped = expression.args.get("escaped")
+ escaped = f" ESCAPED BY {escaped}" if escaped else ""
+ items = expression.args.get("collection_items")
+ items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else ""
+ keys = expression.args.get("map_keys")
+ keys = f" MAP KEYS TERMINATED BY {keys}" if keys else ""
+ lines = expression.args.get("lines")
+ lines = f" LINES TERMINATED BY {lines}" if lines else ""
+ null = expression.args.get("null")
+ null = f" NULL DEFINED AS {null}" if null else ""
+ return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
+
def table_sql(self, expression):
table = ".".join(
part
@@ -688,6 +725,19 @@ class Generator:
return f"{self.quote_start}{text}{self.quote_end}"
return text
+ def loaddata_sql(self, expression):
+ local = " LOCAL" if expression.args.get("local") else ""
+ inpath = f" INPATH {self.sql(expression, 'inpath')}"
+ overwrite = " OVERWRITE" if expression.args.get("overwrite") else ""
+ this = f" INTO TABLE {self.sql(expression, 'this')}"
+ partition = self.sql(expression, "partition")
+ partition = f" {partition}" if partition else ""
+ input_format = self.sql(expression, "input_format")
+ input_format = f" INPUTFORMAT {input_format}" if input_format else ""
+ serde = self.sql(expression, "serde")
+ serde = f" SERDE {serde}" if serde else ""
+ return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}"
+
def null_sql(self, *_):
return "NULL"
@@ -885,20 +935,24 @@ class Generator:
return f"EXISTS{self.wrap(expression)}"
def case_sql(self, expression):
- this = self.indent(self.sql(expression, "this"), skip_first=True)
- this = f" {this}" if this else ""
- ifs = []
+ this = self.sql(expression, "this")
+ statements = [f"CASE {this}" if this else "CASE"]
for e in expression.args["ifs"]:
- ifs.append(self.indent(f"WHEN {self.sql(e, 'this')}"))
- ifs.append(self.indent(f"THEN {self.sql(e, 'true')}"))
+ statements.append(f"WHEN {self.sql(e, 'this')}")
+ statements.append(f"THEN {self.sql(e, 'true')}")
+
+ default = self.sql(expression, "default")
+
+ if default:
+ statements.append(f"ELSE {default}")
- if expression.args.get("default") is not None:
- ifs.append(self.indent(f"ELSE {self.sql(expression, 'default')}"))
+ statements.append("END")
- ifs = "".join(self.seg(self.indent(e, skip_first=True)) for e in ifs)
- statement = f"CASE{this}{ifs}{self.seg('END')}"
- return statement
+ if self.pretty and self.text_width(statements) > self._max_text_width:
+ return self.indent("\n".join(statements), skip_first=True, skip_last=True)
+
+ return " ".join(statements)
def constraint_sql(self, expression):
this = self.sql(expression, "this")
@@ -970,7 +1024,7 @@ class Generator:
return f"REFERENCES {this}({expressions})"
def anonymous_sql(self, expression):
- args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True)
+ args = self.format_args(*expression.expressions)
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
def paren_sql(self, expression):
@@ -1008,7 +1062,9 @@ class Generator:
if not self.pretty:
return self.binary(expression, op)
- return f"\n{op} ".join(self.sql(e) for e in expression.flatten(unnest=False))
+ sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False))
+ sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
+ return f"{sep}{op} ".join(sqls)
def bitwiseand_sql(self, expression):
return self.binary(expression, "&")
@@ -1039,7 +1095,7 @@ class Generator:
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
def distinct_sql(self, expression):
- this = self.sql(expression, "this")
+ this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
on = self.sql(expression, "on")
@@ -1128,13 +1184,23 @@ class Generator:
def function_fallback_sql(self, expression):
args = []
- for arg_key in expression.arg_types:
- arg_value = ensure_list(expression.args.get(arg_key) or [])
- for a in arg_value:
- args.append(self.sql(a))
-
- args_str = self.indent(", ".join(args), skip_first=True, skip_last=True)
- return f"{self.normalize_func(expression.sql_name())}({args_str})"
+ for arg_value in expression.args.values():
+ if isinstance(arg_value, list):
+ for value in arg_value:
+ args.append(value)
+ elif arg_value:
+ args.append(arg_value)
+
+ return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
+
+ def format_args(self, *args):
+ args = tuple(self.sql(arg) for arg in args if arg is not None)
+ if self.pretty and self.text_width(args) > self._max_text_width:
+ return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True)
+ return ", ".join(args)
+
+ def text_width(self, args):
+ return sum(len(arg) for arg in args)
def format_time(self, expression):
return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)