summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py167
1 files changed, 89 insertions, 78 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 793cff0..a445178 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -41,6 +41,8 @@ class Generator:
max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
+ leading_comma (bool): if the the comma is leading or trailing in select statements
+ Default: False
"""
TRANSFORMS = {
@@ -108,6 +110,7 @@ class Generator:
"_indent",
"_replace_backslash",
"_escaped_quote_end",
+ "_leading_comma",
)
def __init__(
@@ -131,6 +134,7 @@ class Generator:
unsupported_level=ErrorLevel.WARN,
null_ordering=None,
max_unsupported=3,
+ leading_comma=False,
):
import sqlglot
@@ -157,6 +161,7 @@ class Generator:
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
+ self._leading_comma = leading_comma
def generate(self, expression):
"""
@@ -178,9 +183,7 @@ class Generator:
for msg in self.unsupported_messages:
logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
- raise UnsupportedError(
- concat_errors(self.unsupported_messages, self.max_unsupported)
- )
+ raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
return sql
@@ -197,9 +200,7 @@ class Generator:
def wrap(self, expression):
this_sql = self.indent(
- self.sql(expression)
- if isinstance(expression, (exp.Select, exp.Union))
- else self.sql(expression, "this"),
+ self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
level=1,
pad=0,
)
@@ -251,9 +252,7 @@ class Generator:
return transform
if not isinstance(expression, exp.Expression):
- raise ValueError(
- f"Expected an Expression. Received {type(expression)}: {expression}"
- )
+ raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
exp_handler_name = f"{expression.key}_sql"
if hasattr(self, exp_handler_name):
@@ -276,11 +275,7 @@ class Generator:
lazy = " LAZY" if expression.args.get("lazy") else ""
table = self.sql(expression, "this")
options = expression.args.get("options")
- options = (
- f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})"
- if options
- else ""
- )
+ options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else ""
sql = self.sql(expression, "expression")
sql = f" AS{self.sep()}{sql}" if sql else ""
sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
@@ -306,9 +301,7 @@ class Generator:
def columndef_sql(self, expression):
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
- constraints = self.expressions(
- expression, key="constraints", sep=" ", flat=True
- )
+ constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
if not constraints:
return f"{column} {kind}"
@@ -338,6 +331,9 @@ class Generator:
default = self.sql(expression, "this")
return f"DEFAULT {default}"
+ def generatedasidentitycolumnconstraint_sql(self, expression):
+ return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
+
def notnullcolumnconstraint_sql(self, _):
return "NOT NULL"
@@ -384,7 +380,10 @@ class Generator:
return f"{alias}{columns}"
def bitstring_sql(self, expression):
- return f"b'{self.sql(expression, 'this')}'"
+ return self.sql(expression, "this")
+
+ def hexstring_sql(self, expression):
+ return self.sql(expression, "this")
def datatype_sql(self, expression):
type_value = expression.this
@@ -452,10 +451,7 @@ class Generator:
def partition_sql(self, expression):
keys = csv(
- *[
- f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"]
- for k, v in expression.args.get("this")
- ]
+ *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
)
return f"PARTITION({keys})"
@@ -470,9 +466,9 @@ class Generator:
elif p_class in self.WITH_PROPERTIES:
with_properties.append(p)
- return self.root_properties(
- exp.Properties(expressions=root_properties)
- ) + self.with_properties(exp.Properties(expressions=with_properties))
+ return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
+ exp.Properties(expressions=with_properties)
+ )
def root_properties(self, properties):
if properties.expressions:
@@ -508,11 +504,7 @@ class Generator:
kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
this = self.sql(expression, "this")
exists = " IF EXISTS " if expression.args.get("exists") else " "
- partition_sql = (
- self.sql(expression, "partition")
- if expression.args.get("partition")
- else ""
- )
+ partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
@@ -531,7 +523,7 @@ class Generator:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def table_sql(self, expression):
- return ".".join(
+ table = ".".join(
part
for part in [
self.sql(expression, "catalog"),
@@ -541,6 +533,10 @@ class Generator:
if part
)
+ laterals = self.expressions(expression, key="laterals", sep="")
+ joins = self.expressions(expression, key="joins", sep="")
+ return f"{table}{laterals}{joins}"
+
def tablesample_sql(self, expression):
if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
this = self.sql(expression.this, "this")
@@ -586,11 +582,7 @@ class Generator:
def group_sql(self, expression):
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
- grouping_sets = (
- f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}"
- if grouping_sets
- else ""
- )
+ grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = self.expressions(expression, key="rollup", indent=False)
@@ -603,7 +595,16 @@ class Generator:
def join_sql(self, expression):
op_sql = self.seg(
- " ".join(op for op in (expression.side, expression.kind, "JOIN") if op)
+ " ".join(
+ op
+ for op in (
+ "NATURAL" if expression.args.get("natural") else None,
+ expression.side,
+ expression.kind,
+ "JOIN",
+ )
+ if op
+ )
)
on_sql = self.sql(expression, "on")
using = expression.args.get("using")
@@ -630,9 +631,9 @@ class Generator:
def lateral_sql(self, expression):
this = self.sql(expression, "this")
- op_sql = self.seg(
- f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}"
- )
+ if isinstance(expression.this, exp.Subquery):
+ return f"LATERAL{self.sep()}{this}"
+ op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
alias = expression.args["alias"]
table = alias.name
table = f" {table}" if table else table
@@ -688,21 +689,13 @@ class Generator:
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
- if nulls_first and (
- (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
- ):
+ if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
nulls_sort_change = " NULLS FIRST"
- elif (
- nulls_last
- and ((asc and nulls_are_small) or (desc and nulls_are_large))
- and not nulls_are_last
- ):
+ elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
nulls_sort_change = " NULLS LAST"
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
- self.unsupported(
- "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
- )
+ self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@@ -798,14 +791,20 @@ class Generator:
def window_sql(self, expression):
this = self.sql(expression, "this")
+
partition = self.expressions(expression, key="partition_by", flat=True)
partition = f"PARTITION BY {partition}" if partition else ""
+
order = expression.args.get("order")
order_sql = self.order_sql(order, flat=True) if order else ""
+
partition_sql = partition + " " if partition and order else partition
+
spec = expression.args.get("spec")
spec_sql = " " + self.window_spec_sql(spec) if spec else ""
+
alias = self.sql(expression, "alias")
+
if expression.arg_key == "window":
this = this = f"{self.seg('WINDOW')} {this} AS"
else:
@@ -818,13 +817,8 @@ class Generator:
def window_spec_sql(self, expression):
kind = self.sql(expression, "kind")
- start = csv(
- self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" "
- )
- end = (
- csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
- or "CURRENT ROW"
- )
+ start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
+ end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
return f"{kind} BETWEEN {start} AND {end}"
def withingroup_sql(self, expression):
@@ -879,6 +873,17 @@ class Generator:
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
+ def trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+
+ if trim_type == "LEADING":
+ return f"LTRIM({target})"
+ elif trim_type == "TRAILING":
+ return f"RTRIM({target})"
+ else:
+ return f"TRIM({target})"
+
def check_sql(self, expression):
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@@ -898,9 +903,7 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
- return self.case_sql(
- exp.Case(ifs=[expression], default=expression.args.get("false"))
- )
+ return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def in_sql(self, expression):
query = expression.args.get("query")
@@ -917,7 +920,9 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression):
- return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}"
+ unit = self.sql(expression, "unit")
+ unit = f" {unit}" if unit else ""
+ return f"INTERVAL {self.sql(expression, 'this')}{unit}"
def reference_sql(self, expression):
this = self.sql(expression, "this")
@@ -925,9 +930,7 @@ class Generator:
return f"REFERENCES {this}({expressions})"
def anonymous_sql(self, expression):
- args = self.indent(
- self.expressions(expression, flat=True), skip_first=True, skip_last=True
- )
+ args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True)
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
def paren_sql(self, expression):
@@ -1006,6 +1009,9 @@ class Generator:
def ignorenulls_sql(self, expression):
return f"{self.sql(expression, 'this')} IGNORE NULLS"
+ def respectnulls_sql(self, expression):
+ return f"{self.sql(expression, 'this')} RESPECT NULLS"
+
def intdiv_sql(self, expression):
return self.sql(
exp.Cast(
@@ -1023,6 +1029,9 @@ class Generator:
def div_sql(self, expression):
return self.binary(expression, "/")
+ def distance_sql(self, expression):
+ return self.binary(expression, "<->")
+
def dot_sql(self, expression):
return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
@@ -1047,6 +1056,9 @@ class Generator:
def like_sql(self, expression):
return self.binary(expression, "LIKE")
+ def similarto_sql(self, expression):
+ return self.binary(expression, "SIMILAR TO")
+
def lt_sql(self, expression):
return self.binary(expression, "<")
@@ -1069,14 +1081,10 @@ class Generator:
return self.binary(expression, "-")
def trycast_sql(self, expression):
- return (
- f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
- )
+ return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
def binary(self, expression, op):
- return (
- f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
- )
+ return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
def function_fallback_sql(self, expression):
args = []
@@ -1089,9 +1097,7 @@ class Generator:
return f"{self.normalize_func(expression.sql_name())}({args_str})"
def format_time(self, expression):
- return format_time(
- self.sql(expression, "format"), self.time_mapping, self.time_trie
- )
+ return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
expressions = expression.args.get(key or "expressions")
@@ -1102,7 +1108,14 @@ class Generator:
if flat:
return sep.join(self.sql(e) for e in expressions)
- expressions = self.sep(sep).join(self.sql(e) for e in expressions)
+ sql = (self.sql(e) for e in expressions)
+ # the only time leading_comma changes the output is if pretty print is enabled
+ if self._leading_comma and self.pretty:
+ pad = " " * self.pad
+ expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
+ else:
+ expressions = self.sep(sep).join(sql)
+
if indent:
return self.indent(expressions, skip_first=False)
return expressions
@@ -1116,9 +1129,7 @@ class Generator:
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
- return self.query_modifiers(
- expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
- )
+ return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name)