summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py452
1 files changed, 255 insertions, 197 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 47774fc..beffb91 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -189,12 +189,12 @@ class Generator:
self._max_text_width = max_text_width
self._comments = comments
- def generate(self, expression):
+ def generate(self, expression: t.Optional[exp.Expression]) -> str:
"""
Generates a SQL string by interpreting the given syntax tree.
Args
- expression (Expression): the syntax tree.
+ expression: the syntax tree.
Returns
the SQL string.
@@ -213,23 +213,23 @@ class Generator:
return sql
- def unsupported(self, message):
+ def unsupported(self, message: str) -> None:
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
- def sep(self, sep=" "):
+ def sep(self, sep: str = " ") -> str:
return f"{sep.strip()}\n" if self.pretty else sep
- def seg(self, sql, sep=" "):
+ def seg(self, sql: str, sep: str = " ") -> str:
return f"{self.sep(sep)}{sql}"
- def pad_comment(self, comment):
+ def pad_comment(self, comment: str) -> str:
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
return comment
- def maybe_comment(self, sql, expression):
+ def maybe_comment(self, sql: str, expression: exp.Expression) -> str:
comments = expression.comments if self._comments else None
if not comments:
@@ -243,7 +243,7 @@ class Generator:
return f"{sql} {comments}"
- def wrap(self, expression):
+ def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent(
self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
@@ -253,21 +253,28 @@ class Generator:
)
return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"
- def no_identify(self, func):
+ def no_identify(self, func: t.Callable[[], str]) -> str:
original = self.identify
self.identify = False
result = func()
self.identify = original
return result
- def normalize_func(self, name):
+ def normalize_func(self, name: str) -> str:
if self.normalize_functions == "upper":
return name.upper()
if self.normalize_functions == "lower":
return name.lower()
return name
- def indent(self, sql, level=0, pad=None, skip_first=False, skip_last=False):
+ def indent(
+ self,
+ sql: str,
+ level: int = 0,
+ pad: t.Optional[int] = None,
+ skip_first: bool = False,
+ skip_last: bool = False,
+ ) -> str:
if not self.pretty:
return sql
@@ -281,7 +288,12 @@ class Generator:
for i, line in enumerate(lines)
)
- def sql(self, expression, key=None, comment=True):
+ def sql(
+ self,
+ expression: t.Optional[str | exp.Expression],
+ key: t.Optional[str] = None,
+ comment: bool = True,
+ ) -> str:
if not expression:
return ""
@@ -313,12 +325,12 @@ class Generator:
return self.maybe_comment(sql, expression) if self._comments and comment else sql
- def uncache_sql(self, expression):
+ def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
exists_sql = " IF EXISTS" if expression.args.get("exists") else ""
return f"UNCACHE TABLE{exists_sql} {table}"
- def cache_sql(self, expression):
+ def cache_sql(self, expression: exp.Cache) -> str:
lazy = " LAZY" if expression.args.get("lazy") else ""
table = self.sql(expression, "this")
options = expression.args.get("options")
@@ -328,13 +340,13 @@ class Generator:
sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
return self.prepend_ctes(expression, sql)
- def characterset_sql(self, expression):
+ def characterset_sql(self, expression: exp.CharacterSet) -> str:
if isinstance(expression.parent, exp.Cast):
return f"CHAR CHARACTER SET {self.sql(expression, 'this')}"
default = "DEFAULT " if expression.args.get("default") else ""
return f"{default}CHARACTER SET={self.sql(expression, 'this')}"
- def column_sql(self, expression):
+ def column_sql(self, expression: exp.Column) -> str:
return ".".join(
part
for part in [
@@ -345,7 +357,7 @@ class Generator:
if part
)
- def columndef_sql(self, expression):
+ def columndef_sql(self, expression: exp.ColumnDef) -> str:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
@@ -354,46 +366,52 @@ class Generator:
return f"{column} {kind}"
return f"{column} {kind} {constraints}"
- def columnconstraint_sql(self, expression):
+ def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
kind_sql = self.sql(expression, "kind")
return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql
- def autoincrementcolumnconstraint_sql(self, _):
+ def autoincrementcolumnconstraint_sql(self, _) -> str:
return self.token_sql(TokenType.AUTO_INCREMENT)
- def checkcolumnconstraint_sql(self, expression):
+ def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str:
this = self.sql(expression, "this")
return f"CHECK ({this})"
- def commentcolumnconstraint_sql(self, expression):
+ def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str:
comment = self.sql(expression, "this")
return f"COMMENT {comment}"
- def collatecolumnconstraint_sql(self, expression):
+ def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str:
collate = self.sql(expression, "this")
return f"COLLATE {collate}"
- def defaultcolumnconstraint_sql(self, expression):
+ def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str:
+ encode = self.sql(expression, "this")
+ return f"ENCODE {encode}"
+
+ def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str:
default = self.sql(expression, "this")
return f"DEFAULT {default}"
- def generatedasidentitycolumnconstraint_sql(self, expression):
+ def generatedasidentitycolumnconstraint_sql(
+ self, expression: exp.GeneratedAsIdentityColumnConstraint
+ ) -> str:
return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
- def notnullcolumnconstraint_sql(self, _):
- return "NOT NULL"
+ def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
+ return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
- def primarykeycolumnconstraint_sql(self, expression):
+ def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str:
desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
- def uniquecolumnconstraint_sql(self, _):
+ def uniquecolumnconstraint_sql(self, _) -> str:
return "UNIQUE"
- def create_sql(self, expression):
+ def create_sql(self, expression: exp.Create) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper()
expression_sql = self.sql(expression, "expression")
@@ -402,47 +420,58 @@ class Generator:
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
)
+ external = " EXTERNAL" if expression.args.get("external") else ""
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties")
- expression_sql = f"CREATE{replace}{temporary}{transient}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}"
+ modifiers = "".join(
+ (
+ replace,
+ temporary,
+ transient,
+ external,
+ unique,
+ materialized,
+ )
+ )
+ expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}"
return self.prepend_ctes(expression, expression_sql)
- def describe_sql(self, expression):
+ def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
- def prepend_ctes(self, expression, sql):
+ def prepend_ctes(self, expression: exp.Expression, sql: str) -> str:
with_ = self.sql(expression, "with")
if with_:
sql = f"{with_}{self.sep()}{sql}"
return sql
- def with_sql(self, expression):
+ def with_sql(self, expression: exp.With) -> str:
sql = self.expressions(expression, flat=True)
recursive = "RECURSIVE " if expression.args.get("recursive") else ""
return f"WITH {recursive}{sql}"
- def cte_sql(self, expression):
+ def cte_sql(self, expression: exp.CTE) -> str:
alias = self.sql(expression, "alias")
return f"{alias} AS {self.wrap(expression)}"
- def tablealias_sql(self, expression):
+ def tablealias_sql(self, expression: exp.TableAlias) -> str:
alias = self.sql(expression, "this")
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
return f"{alias}{columns}"
- def bitstring_sql(self, expression):
+ def bitstring_sql(self, expression: exp.BitString) -> str:
return self.sql(expression, "this")
- def hexstring_sql(self, expression):
+ def hexstring_sql(self, expression: exp.HexString) -> str:
return self.sql(expression, "this")
- def datatype_sql(self, expression):
+ def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
nested = ""
@@ -455,13 +484,13 @@ class Generator:
)
return f"{type_sql}{nested}"
- def directory_sql(self, expression):
+ def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
row_format = self.sql(expression, "row_format")
row_format = f" {row_format}" if row_format else ""
return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}"
- def delete_sql(self, expression):
+ def delete_sql(self, expression: exp.Delete) -> str:
this = self.sql(expression, "this")
using_sql = (
f" USING {self.expressions(expression, 'using', sep=', USING ')}"
@@ -472,7 +501,7 @@ class Generator:
sql = f"DELETE FROM {this}{using_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
- def drop_sql(self, expression):
+ def drop_sql(self, expression: exp.Drop) -> str:
this = self.sql(expression, "this")
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
@@ -481,46 +510,46 @@ class Generator:
cascade = " CASCADE" if expression.args.get("cascade") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
- def except_sql(self, expression):
+ def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.except_op(expression)),
)
- def except_op(self, expression):
+ def except_op(self, expression: exp.Except) -> str:
return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}"
- def fetch_sql(self, expression):
+ def fetch_sql(self, expression: exp.Fetch) -> str:
direction = expression.args.get("direction")
direction = f" {direction.upper()}" if direction else ""
count = expression.args.get("count")
count = f" {count}" if count else ""
return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY"
- def filter_sql(self, expression):
+ def filter_sql(self, expression: exp.Filter) -> str:
this = self.sql(expression, "this")
where = self.sql(expression, "expression")[1:] # where has a leading space
return f"{this} FILTER({where})"
- def hint_sql(self, expression):
+ def hint_sql(self, expression: exp.Hint) -> str:
if self.sql(expression, "this"):
self.unsupported("Hints are not supported")
return ""
- def index_sql(self, expression):
+ def index_sql(self, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
return f"{this} ON {table} {columns}"
- def identifier_sql(self, expression):
+ def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
text = text.lower() if self.normalize else text
if expression.args.get("quoted") or self.identify:
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
- def partition_sql(self, expression):
+ def partition_sql(self, expression: exp.Partition) -> str:
keys = csv(
*[
f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
@@ -529,7 +558,7 @@ class Generator:
)
return f"PARTITION({keys})"
- def properties_sql(self, expression):
+ def properties_sql(self, expression: exp.Properties) -> str:
root_properties = []
with_properties = []
@@ -544,21 +573,21 @@ class Generator:
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
- def root_properties(self, properties):
+ def root_properties(self, properties: exp.Properties) -> str:
if properties.expressions:
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
- def properties(self, properties, prefix="", sep=", "):
+ def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
- return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
+ return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}"
return ""
- def with_properties(self, properties):
- return self.properties(properties, prefix="WITH")
+ def with_properties(self, properties: exp.Properties) -> str:
+ return self.properties(properties, prefix=self.seg("WITH"))
- def property_sql(self, expression):
+ def property_sql(self, expression: exp.Property) -> str:
property_cls = expression.__class__
if property_cls == exp.Property:
return f"{expression.name}={self.sql(expression, 'value')}"
@@ -569,12 +598,12 @@ class Generator:
return f"{property_name}={self.sql(expression, 'this')}"
- def likeproperty_sql(self, expression):
+ def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
- def insert_sql(self, expression):
+ def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")
if isinstance(expression.this, exp.Directory):
@@ -592,19 +621,19 @@ class Generator:
sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
return self.prepend_ctes(expression, sql)
- def intersect_sql(self, expression):
+ def intersect_sql(self, expression: exp.Intersect) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.intersect_op(expression)),
)
- def intersect_op(self, expression):
+ def intersect_op(self, expression: exp.Intersect) -> str:
return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}"
- def introducer_sql(self, expression):
+ def introducer_sql(self, expression: exp.Introducer) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
- def rowformat_sql(self, expression):
+ def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
fields = expression.args.get("fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
escaped = expression.args.get("escaped")
@@ -619,7 +648,7 @@ class Generator:
null = f" NULL DEFINED AS {null}" if null else ""
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
- def table_sql(self, expression, sep=" AS "):
+ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = ".".join(
part
for part in [
@@ -642,7 +671,7 @@ class Generator:
return f"{table}{alias}{laterals}{joins}{pivots}"
- def tablesample_sql(self, expression):
+ def tablesample_sql(self, expression: exp.TableSample) -> str:
if self.alias_post_tablesample and expression.this.alias:
this = self.sql(expression.this, "this")
alias = f" AS {self.sql(expression.this, 'alias')}"
@@ -665,7 +694,7 @@ class Generator:
seed = f" SEED ({seed})" if seed else ""
return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}"
- def pivot_sql(self, expression):
+ def pivot_sql(self, expression: exp.Pivot) -> str:
this = self.sql(expression, "this")
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
@@ -673,10 +702,10 @@ class Generator:
field = self.sql(expression, "field")
return f"{this} {direction}({expressions} FOR {field})"
- def tuple_sql(self, expression):
+ def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
- def update_sql(self, expression):
+ def update_sql(self, expression: exp.Update) -> str:
this = self.sql(expression, "this")
set_sql = self.expressions(expression, flat=True)
from_sql = self.sql(expression, "from")
@@ -684,7 +713,7 @@ class Generator:
sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
- def values_sql(self, expression):
+ def values_sql(self, expression: exp.Values) -> str:
alias = self.sql(expression, "alias")
args = self.expressions(expression)
if not alias:
@@ -694,19 +723,19 @@ class Generator:
return f"(VALUES{self.seg('')}{args}){alias}"
return f"VALUES{self.seg('')}{args}{alias}"
- def var_sql(self, expression):
+ def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
- def into_sql(self, expression):
+ def into_sql(self, expression: exp.Into) -> str:
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
- def from_sql(self, expression):
+ def from_sql(self, expression: exp.From) -> str:
expressions = self.expressions(expression, flat=True)
return f"{self.seg('FROM')} {expressions}"
- def group_sql(self, expression):
+ def group_sql(self, expression: exp.Group) -> str:
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
grouping_sets = (
@@ -718,11 +747,11 @@ class Generator:
rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
return f"{group_by}{grouping_sets}{cube}{rollup}"
- def having_sql(self, expression):
+ def having_sql(self, expression: exp.Having) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('HAVING')}{self.sep()}{this}"
- def join_sql(self, expression):
+ def join_sql(self, expression: exp.Join) -> str:
op_sql = self.seg(
" ".join(
op
@@ -753,12 +782,12 @@ class Generator:
this_sql = self.sql(expression, "this")
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
- def lambda_sql(self, expression, arrow_sep="->"):
+ def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
- def lateral_sql(self, expression):
+ def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
if isinstance(expression.this, exp.Subquery):
@@ -776,15 +805,15 @@ class Generator:
return f"LATERAL {this}{table}{columns}"
- def limit_sql(self, expression):
+ def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}"
- def offset_sql(self, expression):
+ def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
- def literal_sql(self, expression):
+ def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
if self._replace_backslash:
@@ -793,7 +822,7 @@ class Generator:
text = f"{self.quote_start}{text}{self.quote_end}"
return text
- def loaddata_sql(self, expression):
+ def loaddata_sql(self, expression: exp.LoadData) -> str:
local = " LOCAL" if expression.args.get("local") else ""
inpath = f" INPATH {self.sql(expression, 'inpath')}"
overwrite = " OVERWRITE" if expression.args.get("overwrite") else ""
@@ -806,27 +835,27 @@ class Generator:
serde = f" SERDE {serde}" if serde else ""
return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}"
- def null_sql(self, *_):
+ def null_sql(self, *_) -> str:
return "NULL"
- def boolean_sql(self, expression):
+ def boolean_sql(self, expression: exp.Boolean) -> str:
return "TRUE" if expression.this else "FALSE"
- def order_sql(self, expression, flat=False):
+ def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else this
- return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat)
+ return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
- def cluster_sql(self, expression):
+ def cluster_sql(self, expression: exp.Cluster) -> str:
return self.op_expressions("CLUSTER BY", expression)
- def distribute_sql(self, expression):
+ def distribute_sql(self, expression: exp.Distribute) -> str:
return self.op_expressions("DISTRIBUTE BY", expression)
- def sort_sql(self, expression):
+ def sort_sql(self, expression: exp.Sort) -> str:
return self.op_expressions("SORT BY", expression)
- def ordered_sql(self, expression):
+ def ordered_sql(self, expression: exp.Ordered) -> str:
desc = expression.args.get("desc")
asc = not desc
@@ -857,7 +886,7 @@ class Generator:
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
- def query_modifiers(self, expression, *sqls):
+ def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("laterals", [])],
@@ -876,7 +905,7 @@ class Generator:
sep="",
)
- def select_sql(self, expression):
+ def select_sql(self, expression: exp.Select) -> str:
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
@@ -890,36 +919,36 @@ class Generator:
)
return self.prepend_ctes(expression, sql)
- def schema_sql(self, expression):
+ def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else ""
sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
return f"{this}{sql}"
- def star_sql(self, expression):
+ def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else ""
replace = self.expressions(expression, key="replace", flat=True)
replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
return f"*{except_}{replace}"
- def structkwarg_sql(self, expression):
+ def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
- def parameter_sql(self, expression):
+ def parameter_sql(self, expression: exp.Parameter) -> str:
return f"@{self.sql(expression, 'this')}"
- def sessionparameter_sql(self, expression):
+ def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
this = self.sql(expression, "this")
kind = expression.text("kind")
if kind:
kind = f"{kind}."
return f"@@{kind}{this}"
- def placeholder_sql(self, expression):
+ def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f":{expression.name}" if expression.name else "?"
- def subquery_sql(self, expression):
+ def subquery_sql(self, expression: exp.Subquery) -> str:
alias = self.sql(expression, "alias")
sql = self.query_modifiers(
@@ -931,22 +960,22 @@ class Generator:
return self.prepend_ctes(expression, sql)
- def qualify_sql(self, expression):
+ def qualify_sql(self, expression: exp.Qualify) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
- def union_sql(self, expression):
+ def union_sql(self, expression: exp.Union) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.union_op(expression)),
)
- def union_op(self, expression):
+ def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
kind = kind if expression.args.get("distinct") else " ALL"
return f"UNION{kind}"
- def unnest_sql(self, expression):
+ def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
alias = expression.args.get("alias")
if alias and self.unnest_column_only:
@@ -958,11 +987,11 @@ class Generator:
ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
return f"UNNEST({args}){ordinality}{alias}"
- def where_sql(self, expression):
+ def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('WHERE')}{self.sep()}{this}"
- def window_sql(self, expression):
+ def window_sql(self, expression: exp.Window) -> str:
this = self.sql(expression, "this")
partition = self.expressions(expression, key="partition_by", flat=True)
@@ -988,7 +1017,7 @@ class Generator:
return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})"
- def window_spec_sql(self, expression):
+ def window_spec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = (
@@ -997,33 +1026,33 @@ class Generator:
)
return f"{kind} BETWEEN {start} AND {end}"
- def withingroup_sql(self, expression):
+ def withingroup_sql(self, expression: exp.WithinGroup) -> str:
this = self.sql(expression, "this")
- expression = self.sql(expression, "expression")[1:] # order has a leading space
- return f"{this} WITHIN GROUP ({expression})"
+ expression_sql = self.sql(expression, "expression")[1:] # order has a leading space
+ return f"{this} WITHIN GROUP ({expression_sql})"
- def between_sql(self, expression):
+ def between_sql(self, expression: exp.Between) -> str:
this = self.sql(expression, "this")
low = self.sql(expression, "low")
high = self.sql(expression, "high")
return f"{this} BETWEEN {low} AND {high}"
- def bracket_sql(self, expression):
+ def bracket_sql(self, expression: exp.Bracket) -> str:
expressions = apply_index_offset(expression.expressions, self.index_offset)
- expressions = ", ".join(self.sql(e) for e in expressions)
+ expressions_sql = ", ".join(self.sql(e) for e in expressions)
- return f"{self.sql(expression, 'this')}[{expressions}]"
+ return f"{self.sql(expression, 'this')}[{expressions_sql}]"
- def all_sql(self, expression):
+ def all_sql(self, expression: exp.All) -> str:
return f"ALL {self.wrap(expression)}"
- def any_sql(self, expression):
+ def any_sql(self, expression: exp.Any) -> str:
return f"ANY {self.wrap(expression)}"
- def exists_sql(self, expression):
+ def exists_sql(self, expression: exp.Exists) -> str:
return f"EXISTS{self.wrap(expression)}"
- def case_sql(self, expression):
+ def case_sql(self, expression: exp.Case) -> str:
this = self.sql(expression, "this")
statements = [f"CASE {this}" if this else "CASE"]
@@ -1043,17 +1072,17 @@ class Generator:
return " ".join(statements)
- def constraint_sql(self, expression):
+ def constraint_sql(self, expression: exp.Constraint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"CONSTRAINT {this} {expressions}"
- def extract_sql(self, expression):
+ def extract_sql(self, expression: exp.Extract) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
- def trim_sql(self, expression):
+ def trim_sql(self, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
@@ -1064,16 +1093,16 @@ class Generator:
else:
return f"TRIM({target})"
- def concat_sql(self, expression):
+ def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
return self.sql(expression.expressions[0])
return self.function_fallback_sql(expression)
- def check_sql(self, expression):
+ def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
return f"CHECK ({this})"
- def foreignkey_sql(self, expression):
+ def foreignkey_sql(self, expression: exp.ForeignKey) -> str:
expressions = self.expressions(expression, flat=True)
reference = self.sql(expression, "reference")
reference = f" {reference}" if reference else ""
@@ -1083,16 +1112,16 @@ class Generator:
update = f" ON UPDATE {update}" if update else ""
return f"FOREIGN KEY ({expressions}){reference}{delete}{update}"
- def unique_sql(self, expression):
+ def unique_sql(self, expression: exp.Unique) -> str:
columns = self.expressions(expression, key="expressions")
return f"UNIQUE ({columns})"
- def if_sql(self, expression):
+ def if_sql(self, expression: exp.If) -> str:
return self.case_sql(
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
- def in_sql(self, expression):
+ def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
field = expression.args.get("field")
@@ -1106,24 +1135,24 @@ class Generator:
in_sql = f"({self.expressions(expression, flat=True)})"
return f"{self.sql(expression, 'this')} IN {in_sql}"
- def in_unnest_op(self, unnest):
+ def in_unnest_op(self, unnest: exp.Unnest) -> str:
return f"(SELECT {self.sql(unnest)})"
- def interval_sql(self, expression):
+ def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL {self.sql(expression, 'this')}{unit}"
- def reference_sql(self, expression):
+ def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"REFERENCES {this}({expressions})"
- def anonymous_sql(self, expression):
+ def anonymous_sql(self, expression: exp.Anonymous) -> str:
args = self.format_args(*expression.expressions)
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
- def paren_sql(self, expression):
+ def paren_sql(self, expression: exp.Paren) -> str:
if isinstance(expression.unnest(), exp.Select):
sql = self.wrap(expression)
else:
@@ -1132,35 +1161,35 @@ class Generator:
return self.prepend_ctes(expression, sql)
- def neg_sql(self, expression):
+ def neg_sql(self, expression: exp.Neg) -> str:
# This makes sure we don't convert "- - 5" to "--5", which is a comment
this_sql = self.sql(expression, "this")
sep = " " if this_sql[0] == "-" else ""
return f"-{sep}{this_sql}"
- def not_sql(self, expression):
+ def not_sql(self, expression: exp.Not) -> str:
return f"NOT {self.sql(expression, 'this')}"
- def alias_sql(self, expression):
+ def alias_sql(self, expression: exp.Alias) -> str:
to_sql = self.sql(expression, "alias")
to_sql = f" AS {to_sql}" if to_sql else ""
return f"{self.sql(expression, 'this')}{to_sql}"
- def aliases_sql(self, expression):
+ def aliases_sql(self, expression: exp.Aliases) -> str:
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
- def attimezone_sql(self, expression):
+ def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
this = self.sql(expression, "this")
zone = self.sql(expression, "zone")
return f"{this} AT TIME ZONE {zone}"
- def add_sql(self, expression):
+ def add_sql(self, expression: exp.Add) -> str:
return self.binary(expression, "+")
- def and_sql(self, expression):
+ def and_sql(self, expression: exp.And) -> str:
return self.connector_sql(expression, "AND")
- def connector_sql(self, expression, op):
+ def connector_sql(self, expression: exp.Connector, op: str) -> str:
if not self.pretty:
return self.binary(expression, op)
@@ -1168,53 +1197,53 @@ class Generator:
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
return f"{sep}{op} ".join(sqls)
- def bitwiseand_sql(self, expression):
+ def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
return self.binary(expression, "&")
- def bitwiseleftshift_sql(self, expression):
+ def bitwiseleftshift_sql(self, expression: exp.BitwiseLeftShift) -> str:
return self.binary(expression, "<<")
- def bitwisenot_sql(self, expression):
+ def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str:
return f"~{self.sql(expression, 'this')}"
- def bitwiseor_sql(self, expression):
+ def bitwiseor_sql(self, expression: exp.BitwiseOr) -> str:
return self.binary(expression, "|")
- def bitwiserightshift_sql(self, expression):
+ def bitwiserightshift_sql(self, expression: exp.BitwiseRightShift) -> str:
return self.binary(expression, ">>")
- def bitwisexor_sql(self, expression):
+ def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str:
return self.binary(expression, "^")
- def cast_sql(self, expression):
+ def cast_sql(self, expression: exp.Cast) -> str:
return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
- def currentdate_sql(self, expression):
+ def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
- def collate_sql(self, expression):
+ def collate_sql(self, expression: exp.Collate) -> str:
return self.binary(expression, "COLLATE")
- def command_sql(self, expression):
+ def command_sql(self, expression: exp.Command) -> str:
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
- def transaction_sql(self, *_):
+ def transaction_sql(self, *_) -> str:
return "BEGIN"
- def commit_sql(self, expression):
+ def commit_sql(self, expression: exp.Commit) -> str:
chain = expression.args.get("chain")
if chain is not None:
chain = " AND CHAIN" if chain else " AND NO CHAIN"
return f"COMMIT{chain or ''}"
- def rollback_sql(self, expression):
+ def rollback_sql(self, expression: exp.Rollback) -> str:
savepoint = expression.args.get("savepoint")
savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}"
- def distinct_sql(self, expression):
+ def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@@ -1222,13 +1251,13 @@ class Generator:
on = f" ON {on}" if on else ""
return f"DISTINCT{this}{on}"
- def ignorenulls_sql(self, expression):
+ def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
return f"{self.sql(expression, 'this')} IGNORE NULLS"
- def respectnulls_sql(self, expression):
+ def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
return f"{self.sql(expression, 'this')} RESPECT NULLS"
- def intdiv_sql(self, expression):
+ def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
this=exp.Div(this=expression.this, expression=expression.expression),
@@ -1236,79 +1265,79 @@ class Generator:
)
)
- def dpipe_sql(self, expression):
+ def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.binary(expression, "||")
- def div_sql(self, expression):
+ def div_sql(self, expression: exp.Div) -> str:
return self.binary(expression, "/")
- def distance_sql(self, expression):
+ def distance_sql(self, expression: exp.Distance) -> str:
return self.binary(expression, "<->")
- def dot_sql(self, expression):
+ def dot_sql(self, expression: exp.Dot) -> str:
return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
- def eq_sql(self, expression):
+ def eq_sql(self, expression: exp.EQ) -> str:
return self.binary(expression, "=")
- def escape_sql(self, expression):
+ def escape_sql(self, expression: exp.Escape) -> str:
return self.binary(expression, "ESCAPE")
- def gt_sql(self, expression):
+ def gt_sql(self, expression: exp.GT) -> str:
return self.binary(expression, ">")
- def gte_sql(self, expression):
+ def gte_sql(self, expression: exp.GTE) -> str:
return self.binary(expression, ">=")
- def ilike_sql(self, expression):
+ def ilike_sql(self, expression: exp.ILike) -> str:
return self.binary(expression, "ILIKE")
- def is_sql(self, expression):
+ def is_sql(self, expression: exp.Is) -> str:
return self.binary(expression, "IS")
- def like_sql(self, expression):
+ def like_sql(self, expression: exp.Like) -> str:
return self.binary(expression, "LIKE")
- def similarto_sql(self, expression):
+ def similarto_sql(self, expression: exp.SimilarTo) -> str:
return self.binary(expression, "SIMILAR TO")
- def lt_sql(self, expression):
+ def lt_sql(self, expression: exp.LT) -> str:
return self.binary(expression, "<")
- def lte_sql(self, expression):
+ def lte_sql(self, expression: exp.LTE) -> str:
return self.binary(expression, "<=")
- def mod_sql(self, expression):
+ def mod_sql(self, expression: exp.Mod) -> str:
return self.binary(expression, "%")
- def mul_sql(self, expression):
+ def mul_sql(self, expression: exp.Mul) -> str:
return self.binary(expression, "*")
- def neq_sql(self, expression):
+ def neq_sql(self, expression: exp.NEQ) -> str:
return self.binary(expression, "<>")
- def nullsafeeq_sql(self, expression):
+ def nullsafeeq_sql(self, expression: exp.NullSafeEQ) -> str:
return self.binary(expression, "IS NOT DISTINCT FROM")
- def nullsafeneq_sql(self, expression):
+ def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str:
return self.binary(expression, "IS DISTINCT FROM")
- def or_sql(self, expression):
+ def or_sql(self, expression: exp.Or) -> str:
return self.connector_sql(expression, "OR")
- def sub_sql(self, expression):
+ def sub_sql(self, expression: exp.Sub) -> str:
return self.binary(expression, "-")
- def trycast_sql(self, expression):
+ def trycast_sql(self, expression: exp.TryCast) -> str:
return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
- def use_sql(self, expression):
+ def use_sql(self, expression: exp.Use) -> str:
return f"USE {self.sql(expression, 'this')}"
- def binary(self, expression, op):
+ def binary(self, expression: exp.Binary, op: str) -> str:
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
- def function_fallback_sql(self, expression):
+ def function_fallback_sql(self, expression: exp.Func) -> str:
args = []
for arg_value in expression.args.values():
if isinstance(arg_value, list):
@@ -1319,19 +1348,26 @@ class Generator:
return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
- def format_args(self, *args):
- args = tuple(self.sql(arg) for arg in args if arg is not None)
- if self.pretty and self.text_width(args) > self._max_text_width:
- return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True)
- return ", ".join(args)
+ def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
+ arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
+ if self.pretty and self.text_width(arg_sqls) > self._max_text_width:
+ return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
+ return ", ".join(arg_sqls)
- def text_width(self, args):
+ def text_width(self, args: t.Iterable) -> int:
return sum(len(arg) for arg in args)
- def format_time(self, expression):
+ def format_time(self, expression: exp.Expression) -> t.Optional[str]:
return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
- def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
+ def expressions(
+ self,
+ expression: exp.Expression,
+ key: t.Optional[str] = None,
+ flat: bool = False,
+ indent: bool = True,
+ sep: str = ", ",
+ ) -> str:
expressions = expression.args.get(key or "expressions")
if not expressions:
@@ -1359,45 +1395,67 @@ class Generator:
else:
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
- result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
- return self.indent(result_sqls, skip_first=False) if indent else result_sqls
+ result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
+ return self.indent(result_sql, skip_first=False) if indent else result_sql
- def op_expressions(self, op, expression, flat=False):
+ def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str:
expressions_sql = self.expressions(expression, flat=flat)
if flat:
return f"{op} {expressions_sql}"
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
- def naked_property(self, expression):
+ def naked_property(self, expression: exp.Property) -> str:
property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
if not property_name:
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
- def set_operation(self, expression, op):
+ def set_operation(self, expression: exp.Expression, op: str) -> str:
this = self.sql(expression, "this")
op = self.seg(op)
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
- def token_sql(self, token_type):
+ def token_sql(self, token_type: TokenType) -> str:
return self.TOKEN_MAPPING.get(token_type, token_type.name)
- def userdefinedfunction_sql(self, expression):
+ def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression))
return f"{this}({expressions})"
- def userdefinedfunctionkwarg_sql(self, expression):
+ def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
return f"{this} {kind}"
- def joinhint_sql(self, expression):
+ def joinhint_sql(self, expression: exp.JoinHint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"{this}({expressions})"
- def kwarg_sql(self, expression):
+ def kwarg_sql(self, expression: exp.Kwarg) -> str:
return self.binary(expression, "=>")
+
+ def when_sql(self, expression: exp.When) -> str:
+ this = self.sql(expression, "this")
+ then_expression = expression.args.get("then")
+ if isinstance(then_expression, exp.Insert):
+ then = f"INSERT {self.sql(then_expression, 'this')}"
+ if "expression" in then_expression.args:
+ then += f" VALUES {self.sql(then_expression, 'expression')}"
+ elif isinstance(then_expression, exp.Update):
+ if isinstance(then_expression.args.get("expressions"), exp.Star):
+ then = f"UPDATE {self.sql(then_expression, 'expressions')}"
+ else:
+ then = f"UPDATE SET {self.expressions(then_expression, flat=True)}"
+ else:
+ then = self.sql(then_expression)
+ return f"WHEN {this} THEN {then}"
+
+ def merge_sql(self, expression: exp.Merge) -> str:
+ this = self.sql(expression, "this")
+ using = f"USING {self.sql(expression, 'using')}"
+ on = f"ON {self.sql(expression, 'on')}"
+ return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}"