diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-12-12 15:42:33 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-12-12 15:42:33 +0000 |
commit | 579e404567dfff42e64325a8c79f03ac627ea341 (patch) | |
tree | 12d101aa5d1b70a69132e5cbd3307741c00d097f /sqlglot/generator.py | |
parent | Adding upstream version 10.1.3. (diff) | |
download | sqlglot-579e404567dfff42e64325a8c79f03ac627ea341.tar.xz sqlglot-579e404567dfff42e64325a8c79f03ac627ea341.zip |
Adding upstream version 10.2.6.upstream/10.2.6
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 452 |
1 files changed, 255 insertions, 197 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 47774fc..beffb91 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -189,12 +189,12 @@ class Generator: self._max_text_width = max_text_width self._comments = comments - def generate(self, expression): + def generate(self, expression: t.Optional[exp.Expression]) -> str: """ Generates a SQL string by interpreting the given syntax tree. Args - expression (Expression): the syntax tree. + expression: the syntax tree. Returns the SQL string. @@ -213,23 +213,23 @@ class Generator: return sql - def unsupported(self, message): + def unsupported(self, message: str) -> None: if self.unsupported_level == ErrorLevel.IMMEDIATE: raise UnsupportedError(message) self.unsupported_messages.append(message) - def sep(self, sep=" "): + def sep(self, sep: str = " ") -> str: return f"{sep.strip()}\n" if self.pretty else sep - def seg(self, sql, sep=" "): + def seg(self, sql: str, sep: str = " ") -> str: return f"{self.sep(sep)}{sql}" - def pad_comment(self, comment): + def pad_comment(self, comment: str) -> str: comment = " " + comment if comment[0].strip() else comment comment = comment + " " if comment[-1].strip() else comment return comment - def maybe_comment(self, sql, expression): + def maybe_comment(self, sql: str, expression: exp.Expression) -> str: comments = expression.comments if self._comments else None if not comments: @@ -243,7 +243,7 @@ class Generator: return f"{sql} {comments}" - def wrap(self, expression): + def wrap(self, expression: exp.Expression | str) -> str: this_sql = self.indent( self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) @@ -253,21 +253,28 @@ class Generator: ) return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" - def no_identify(self, func): + def no_identify(self, func: t.Callable[[], str]) -> str: original = self.identify self.identify = False result = func() self.identify = original return result - def normalize_func(self, name): + def normalize_func(self, name: str) -> str: if self.normalize_functions == "upper": return name.upper() if self.normalize_functions == "lower": return name.lower() return name - def indent(self, sql, level=0, pad=None, skip_first=False, skip_last=False): + def indent( + self, + sql: str, + level: int = 0, + pad: t.Optional[int] = None, + skip_first: bool = False, + skip_last: bool = False, + ) -> str: if not self.pretty: return sql @@ -281,7 +288,12 @@ class Generator: for i, line in enumerate(lines) ) - def sql(self, expression, key=None, comment=True): + def sql( + self, + expression: t.Optional[str | exp.Expression], + key: t.Optional[str] = None, + comment: bool = True, + ) -> str: if not expression: return "" @@ -313,12 +325,12 @@ class Generator: return self.maybe_comment(sql, expression) if self._comments and comment else sql - def uncache_sql(self, expression): + def uncache_sql(self, expression: exp.Uncache) -> str: table = self.sql(expression, "this") exists_sql = " IF EXISTS" if expression.args.get("exists") else "" return f"UNCACHE TABLE{exists_sql} {table}" - def cache_sql(self, expression): + def cache_sql(self, expression: exp.Cache) -> str: lazy = " LAZY" if expression.args.get("lazy") else "" table = self.sql(expression, "this") options = expression.args.get("options") @@ -328,13 +340,13 @@ class Generator: sql = f"CACHE{lazy} TABLE {table}{options}{sql}" return self.prepend_ctes(expression, sql) - def characterset_sql(self, expression): + def characterset_sql(self, expression: exp.CharacterSet) -> str: if isinstance(expression.parent, exp.Cast): return f"CHAR CHARACTER SET {self.sql(expression, 'this')}" default = "DEFAULT " if expression.args.get("default") else "" return f"{default}CHARACTER SET={self.sql(expression, 'this')}" - def column_sql(self, expression): + def column_sql(self, expression: exp.Column) -> str: return ".".join( part for part in [ @@ -345,7 +357,7 @@ class Generator: if part ) - def columndef_sql(self, expression): + def columndef_sql(self, expression: exp.ColumnDef) -> str: column = self.sql(expression, "this") kind = self.sql(expression, "kind") constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) @@ -354,46 +366,52 @@ class Generator: return f"{column} {kind}" return f"{column} {kind} {constraints}" - def columnconstraint_sql(self, expression): + def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: this = self.sql(expression, "this") kind_sql = self.sql(expression, "kind") return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql - def autoincrementcolumnconstraint_sql(self, _): + def autoincrementcolumnconstraint_sql(self, _) -> str: return self.token_sql(TokenType.AUTO_INCREMENT) - def checkcolumnconstraint_sql(self, expression): + def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str: this = self.sql(expression, "this") return f"CHECK ({this})" - def commentcolumnconstraint_sql(self, expression): + def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str: comment = self.sql(expression, "this") return f"COMMENT {comment}" - def collatecolumnconstraint_sql(self, expression): + def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str: collate = self.sql(expression, "this") return f"COLLATE {collate}" - def defaultcolumnconstraint_sql(self, expression): + def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str: + encode = self.sql(expression, "this") + return f"ENCODE {encode}" + + def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str: default = self.sql(expression, "this") return f"DEFAULT {default}" - def generatedasidentitycolumnconstraint_sql(self, expression): + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY" - def notnullcolumnconstraint_sql(self, _): - return "NOT NULL" + def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: + return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" - def primarykeycolumnconstraint_sql(self, expression): + def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str: desc = expression.args.get("desc") if desc is not None: return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" return f"PRIMARY KEY" - def uniquecolumnconstraint_sql(self, _): + def uniquecolumnconstraint_sql(self, _) -> str: return "UNIQUE" - def create_sql(self, expression): + def create_sql(self, expression: exp.Create) -> str: this = self.sql(expression, "this") kind = self.sql(expression, "kind").upper() expression_sql = self.sql(expression, "expression") @@ -402,47 +420,58 @@ class Generator: transient = ( " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" ) + external = " EXTERNAL" if expression.args.get("external") else "" replace = " OR REPLACE" if expression.args.get("replace") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" properties = self.sql(expression, "properties") - expression_sql = f"CREATE{replace}{temporary}{transient}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}" + modifiers = "".join( + ( + replace, + temporary, + transient, + external, + unique, + materialized, + ) + ) + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}" return self.prepend_ctes(expression, expression_sql) - def describe_sql(self, expression): + def describe_sql(self, expression: exp.Describe) -> str: return f"DESCRIBE {self.sql(expression, 'this')}" - def prepend_ctes(self, expression, sql): + def prepend_ctes(self, expression: exp.Expression, sql: str) -> str: with_ = self.sql(expression, "with") if with_: sql = f"{with_}{self.sep()}{sql}" return sql - def with_sql(self, expression): + def with_sql(self, expression: exp.With) -> str: sql = self.expressions(expression, flat=True) recursive = "RECURSIVE " if expression.args.get("recursive") else "" return f"WITH {recursive}{sql}" - def cte_sql(self, expression): + def cte_sql(self, expression: exp.CTE) -> str: alias = self.sql(expression, "alias") return f"{alias} AS {self.wrap(expression)}" - def tablealias_sql(self, expression): + def tablealias_sql(self, expression: exp.TableAlias) -> str: alias = self.sql(expression, "this") columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" return f"{alias}{columns}" - def bitstring_sql(self, expression): + def bitstring_sql(self, expression: exp.BitString) -> str: return self.sql(expression, "this") - def hexstring_sql(self, expression): + def hexstring_sql(self, expression: exp.HexString) -> str: return self.sql(expression, "this") - def datatype_sql(self, expression): + def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) nested = "" @@ -455,13 +484,13 @@ class Generator: ) return f"{type_sql}{nested}" - def directory_sql(self, expression): + def directory_sql(self, expression: exp.Directory) -> str: local = "LOCAL " if expression.args.get("local") else "" row_format = self.sql(expression, "row_format") row_format = f" {row_format}" if row_format else "" return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}" - def delete_sql(self, expression): + def delete_sql(self, expression: exp.Delete) -> str: this = self.sql(expression, "this") using_sql = ( f" USING {self.expressions(expression, 'using', sep=', USING ')}" @@ -472,7 +501,7 @@ class Generator: sql = f"DELETE FROM {this}{using_sql}{where_sql}" return self.prepend_ctes(expression, sql) - def drop_sql(self, expression): + def drop_sql(self, expression: exp.Drop) -> str: this = self.sql(expression, "this") kind = expression.args["kind"] exists_sql = " IF EXISTS " if expression.args.get("exists") else " " @@ -481,46 +510,46 @@ class Generator: cascade = " CASCADE" if expression.args.get("cascade") else "" return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}" - def except_sql(self, expression): + def except_sql(self, expression: exp.Except) -> str: return self.prepend_ctes( expression, self.set_operation(expression, self.except_op(expression)), ) - def except_op(self, expression): + def except_op(self, expression: exp.Except) -> str: return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}" - def fetch_sql(self, expression): + def fetch_sql(self, expression: exp.Fetch) -> str: direction = expression.args.get("direction") direction = f" {direction.upper()}" if direction else "" count = expression.args.get("count") count = f" {count}" if count else "" return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY" - def filter_sql(self, expression): + def filter_sql(self, expression: exp.Filter) -> str: this = self.sql(expression, "this") where = self.sql(expression, "expression")[1:] # where has a leading space return f"{this} FILTER({where})" - def hint_sql(self, expression): + def hint_sql(self, expression: exp.Hint) -> str: if self.sql(expression, "this"): self.unsupported("Hints are not supported") return "" - def index_sql(self, expression): + def index_sql(self, expression: exp.Index) -> str: this = self.sql(expression, "this") table = self.sql(expression, "table") columns = self.sql(expression, "columns") return f"{this} ON {table} {columns}" - def identifier_sql(self, expression): + def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name text = text.lower() if self.normalize else text if expression.args.get("quoted") or self.identify: text = f"{self.identifier_start}{text}{self.identifier_end}" return text - def partition_sql(self, expression): + def partition_sql(self, expression: exp.Partition) -> str: keys = csv( *[ f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name @@ -529,7 +558,7 @@ class Generator: ) return f"PARTITION({keys})" - def properties_sql(self, expression): + def properties_sql(self, expression: exp.Properties) -> str: root_properties = [] with_properties = [] @@ -544,21 +573,21 @@ class Generator: exp.Properties(expressions=root_properties) ) + self.with_properties(exp.Properties(expressions=with_properties)) - def root_properties(self, properties): + def root_properties(self, properties: exp.Properties) -> str: if properties.expressions: return self.sep() + self.expressions(properties, indent=False, sep=" ") return "" - def properties(self, properties, prefix="", sep=", "): + def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str: if properties.expressions: expressions = self.expressions(properties, sep=sep, indent=False) - return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}" + return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}" return "" - def with_properties(self, properties): - return self.properties(properties, prefix="WITH") + def with_properties(self, properties: exp.Properties) -> str: + return self.properties(properties, prefix=self.seg("WITH")) - def property_sql(self, expression): + def property_sql(self, expression: exp.Property) -> str: property_cls = expression.__class__ if property_cls == exp.Property: return f"{expression.name}={self.sql(expression, 'value')}" @@ -569,12 +598,12 @@ class Generator: return f"{property_name}={self.sql(expression, 'this')}" - def likeproperty_sql(self, expression): + def likeproperty_sql(self, expression: exp.LikeProperty) -> str: options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions) options = f" {options}" if options else "" return f"LIKE {self.sql(expression, 'this')}{options}" - def insert_sql(self, expression): + def insert_sql(self, expression: exp.Insert) -> str: overwrite = expression.args.get("overwrite") if isinstance(expression.this, exp.Directory): @@ -592,19 +621,19 @@ class Generator: sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" return self.prepend_ctes(expression, sql) - def intersect_sql(self, expression): + def intersect_sql(self, expression: exp.Intersect) -> str: return self.prepend_ctes( expression, self.set_operation(expression, self.intersect_op(expression)), ) - def intersect_op(self, expression): + def intersect_op(self, expression: exp.Intersect) -> str: return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}" - def introducer_sql(self, expression): + def introducer_sql(self, expression: exp.Introducer) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" - def rowformat_sql(self, expression): + def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: fields = expression.args.get("fields") fields = f" FIELDS TERMINATED BY {fields}" if fields else "" escaped = expression.args.get("escaped") @@ -619,7 +648,7 @@ class Generator: null = f" NULL DEFINED AS {null}" if null else "" return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}" - def table_sql(self, expression, sep=" AS "): + def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: table = ".".join( part for part in [ @@ -642,7 +671,7 @@ class Generator: return f"{table}{alias}{laterals}{joins}{pivots}" - def tablesample_sql(self, expression): + def tablesample_sql(self, expression: exp.TableSample) -> str: if self.alias_post_tablesample and expression.this.alias: this = self.sql(expression.this, "this") alias = f" AS {self.sql(expression.this, 'alias')}" @@ -665,7 +694,7 @@ class Generator: seed = f" SEED ({seed})" if seed else "" return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}" - def pivot_sql(self, expression): + def pivot_sql(self, expression: exp.Pivot) -> str: this = self.sql(expression, "this") unpivot = expression.args.get("unpivot") direction = "UNPIVOT" if unpivot else "PIVOT" @@ -673,10 +702,10 @@ class Generator: field = self.sql(expression, "field") return f"{this} {direction}({expressions} FOR {field})" - def tuple_sql(self, expression): + def tuple_sql(self, expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" - def update_sql(self, expression): + def update_sql(self, expression: exp.Update) -> str: this = self.sql(expression, "this") set_sql = self.expressions(expression, flat=True) from_sql = self.sql(expression, "from") @@ -684,7 +713,7 @@ class Generator: sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}" return self.prepend_ctes(expression, sql) - def values_sql(self, expression): + def values_sql(self, expression: exp.Values) -> str: alias = self.sql(expression, "alias") args = self.expressions(expression) if not alias: @@ -694,19 +723,19 @@ class Generator: return f"(VALUES{self.seg('')}{args}){alias}" return f"VALUES{self.seg('')}{args}{alias}" - def var_sql(self, expression): + def var_sql(self, expression: exp.Var) -> str: return self.sql(expression, "this") - def into_sql(self, expression): + def into_sql(self, expression: exp.Into) -> str: temporary = " TEMPORARY" if expression.args.get("temporary") else "" unlogged = " UNLOGGED" if expression.args.get("unlogged") else "" return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}" - def from_sql(self, expression): + def from_sql(self, expression: exp.From) -> str: expressions = self.expressions(expression, flat=True) return f"{self.seg('FROM')} {expressions}" - def group_sql(self, expression): + def group_sql(self, expression: exp.Group) -> str: group_by = self.op_expressions("GROUP BY", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) grouping_sets = ( @@ -718,11 +747,11 @@ class Generator: rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" return f"{group_by}{grouping_sets}{cube}{rollup}" - def having_sql(self, expression): + def having_sql(self, expression: exp.Having) -> str: this = self.indent(self.sql(expression, "this")) return f"{self.seg('HAVING')}{self.sep()}{this}" - def join_sql(self, expression): + def join_sql(self, expression: exp.Join) -> str: op_sql = self.seg( " ".join( op @@ -753,12 +782,12 @@ class Generator: this_sql = self.sql(expression, "this") return f"{expression_sql}{op_sql} {this_sql}{on_sql}" - def lambda_sql(self, expression, arrow_sep="->"): + def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: args = self.expressions(expression, flat=True) args = f"({args})" if len(args.split(",")) > 1 else args return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}") - def lateral_sql(self, expression): + def lateral_sql(self, expression: exp.Lateral) -> str: this = self.sql(expression, "this") if isinstance(expression.this, exp.Subquery): @@ -776,15 +805,15 @@ class Generator: return f"LATERAL {this}{table}{columns}" - def limit_sql(self, expression): + def limit_sql(self, expression: exp.Limit) -> str: this = self.sql(expression, "this") return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}" - def offset_sql(self, expression): + def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" - def literal_sql(self, expression): + def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: if self._replace_backslash: @@ -793,7 +822,7 @@ class Generator: text = f"{self.quote_start}{text}{self.quote_end}" return text - def loaddata_sql(self, expression): + def loaddata_sql(self, expression: exp.LoadData) -> str: local = " LOCAL" if expression.args.get("local") else "" inpath = f" INPATH {self.sql(expression, 'inpath')}" overwrite = " OVERWRITE" if expression.args.get("overwrite") else "" @@ -806,27 +835,27 @@ class Generator: serde = f" SERDE {serde}" if serde else "" return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}" - def null_sql(self, *_): + def null_sql(self, *_) -> str: return "NULL" - def boolean_sql(self, expression): + def boolean_sql(self, expression: exp.Boolean) -> str: return "TRUE" if expression.this else "FALSE" - def order_sql(self, expression, flat=False): + def order_sql(self, expression: exp.Order, flat: bool = False) -> str: this = self.sql(expression, "this") this = f"{this} " if this else this - return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) + return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore - def cluster_sql(self, expression): + def cluster_sql(self, expression: exp.Cluster) -> str: return self.op_expressions("CLUSTER BY", expression) - def distribute_sql(self, expression): + def distribute_sql(self, expression: exp.Distribute) -> str: return self.op_expressions("DISTRIBUTE BY", expression) - def sort_sql(self, expression): + def sort_sql(self, expression: exp.Sort) -> str: return self.op_expressions("SORT BY", expression) - def ordered_sql(self, expression): + def ordered_sql(self, expression: exp.Ordered) -> str: desc = expression.args.get("desc") asc = not desc @@ -857,7 +886,7 @@ class Generator: return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" - def query_modifiers(self, expression, *sqls): + def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, *[self.sql(sql) for sql in expression.args.get("laterals", [])], @@ -876,7 +905,7 @@ class Generator: sep="", ) - def select_sql(self, expression): + def select_sql(self, expression: exp.Select) -> str: hint = self.sql(expression, "hint") distinct = self.sql(expression, "distinct") distinct = f" {distinct}" if distinct else "" @@ -890,36 +919,36 @@ class Generator: ) return self.prepend_ctes(expression, sql) - def schema_sql(self, expression): + def schema_sql(self, expression: exp.Schema) -> str: this = self.sql(expression, "this") this = f"{this} " if this else "" sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" return f"{this}{sql}" - def star_sql(self, expression): + def star_sql(self, expression: exp.Star) -> str: except_ = self.expressions(expression, key="except", flat=True) except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else "" replace = self.expressions(expression, key="replace", flat=True) replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" return f"*{except_}{replace}" - def structkwarg_sql(self, expression): + def structkwarg_sql(self, expression: exp.StructKwarg) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" - def parameter_sql(self, expression): + def parameter_sql(self, expression: exp.Parameter) -> str: return f"@{self.sql(expression, 'this')}" - def sessionparameter_sql(self, expression): + def sessionparameter_sql(self, expression: exp.SessionParameter) -> str: this = self.sql(expression, "this") kind = expression.text("kind") if kind: kind = f"{kind}." return f"@@{kind}{this}" - def placeholder_sql(self, expression): + def placeholder_sql(self, expression: exp.Placeholder) -> str: return f":{expression.name}" if expression.name else "?" - def subquery_sql(self, expression): + def subquery_sql(self, expression: exp.Subquery) -> str: alias = self.sql(expression, "alias") sql = self.query_modifiers( @@ -931,22 +960,22 @@ class Generator: return self.prepend_ctes(expression, sql) - def qualify_sql(self, expression): + def qualify_sql(self, expression: exp.Qualify) -> str: this = self.indent(self.sql(expression, "this")) return f"{self.seg('QUALIFY')}{self.sep()}{this}" - def union_sql(self, expression): + def union_sql(self, expression: exp.Union) -> str: return self.prepend_ctes( expression, self.set_operation(expression, self.union_op(expression)), ) - def union_op(self, expression): + def union_op(self, expression: exp.Union) -> str: kind = " DISTINCT" if self.EXPLICIT_UNION else "" kind = kind if expression.args.get("distinct") else " ALL" return f"UNION{kind}" - def unnest_sql(self, expression): + def unnest_sql(self, expression: exp.Unnest) -> str: args = self.expressions(expression, flat=True) alias = expression.args.get("alias") if alias and self.unnest_column_only: @@ -958,11 +987,11 @@ class Generator: ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else "" return f"UNNEST({args}){ordinality}{alias}" - def where_sql(self, expression): + def where_sql(self, expression: exp.Where) -> str: this = self.indent(self.sql(expression, "this")) return f"{self.seg('WHERE')}{self.sep()}{this}" - def window_sql(self, expression): + def window_sql(self, expression: exp.Window) -> str: this = self.sql(expression, "this") partition = self.expressions(expression, key="partition_by", flat=True) @@ -988,7 +1017,7 @@ class Generator: return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})" - def window_spec_sql(self, expression): + def window_spec_sql(self, expression: exp.WindowSpec) -> str: kind = self.sql(expression, "kind") start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") end = ( @@ -997,33 +1026,33 @@ class Generator: ) return f"{kind} BETWEEN {start} AND {end}" - def withingroup_sql(self, expression): + def withingroup_sql(self, expression: exp.WithinGroup) -> str: this = self.sql(expression, "this") - expression = self.sql(expression, "expression")[1:] # order has a leading space - return f"{this} WITHIN GROUP ({expression})" + expression_sql = self.sql(expression, "expression")[1:] # order has a leading space + return f"{this} WITHIN GROUP ({expression_sql})" - def between_sql(self, expression): + def between_sql(self, expression: exp.Between) -> str: this = self.sql(expression, "this") low = self.sql(expression, "low") high = self.sql(expression, "high") return f"{this} BETWEEN {low} AND {high}" - def bracket_sql(self, expression): + def bracket_sql(self, expression: exp.Bracket) -> str: expressions = apply_index_offset(expression.expressions, self.index_offset) - expressions = ", ".join(self.sql(e) for e in expressions) + expressions_sql = ", ".join(self.sql(e) for e in expressions) - return f"{self.sql(expression, 'this')}[{expressions}]" + return f"{self.sql(expression, 'this')}[{expressions_sql}]" - def all_sql(self, expression): + def all_sql(self, expression: exp.All) -> str: return f"ALL {self.wrap(expression)}" - def any_sql(self, expression): + def any_sql(self, expression: exp.Any) -> str: return f"ANY {self.wrap(expression)}" - def exists_sql(self, expression): + def exists_sql(self, expression: exp.Exists) -> str: return f"EXISTS{self.wrap(expression)}" - def case_sql(self, expression): + def case_sql(self, expression: exp.Case) -> str: this = self.sql(expression, "this") statements = [f"CASE {this}" if this else "CASE"] @@ -1043,17 +1072,17 @@ class Generator: return " ".join(statements) - def constraint_sql(self, expression): + def constraint_sql(self, expression: exp.Constraint) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"CONSTRAINT {this} {expressions}" - def extract_sql(self, expression): + def extract_sql(self, expression: exp.Extract) -> str: this = self.sql(expression, "this") expression_sql = self.sql(expression, "expression") return f"EXTRACT({this} FROM {expression_sql})" - def trim_sql(self, expression): + def trim_sql(self, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") @@ -1064,16 +1093,16 @@ class Generator: else: return f"TRIM({target})" - def concat_sql(self, expression): + def concat_sql(self, expression: exp.Concat) -> str: if len(expression.expressions) == 1: return self.sql(expression.expressions[0]) return self.function_fallback_sql(expression) - def check_sql(self, expression): + def check_sql(self, expression: exp.Check) -> str: this = self.sql(expression, key="this") return f"CHECK ({this})" - def foreignkey_sql(self, expression): + def foreignkey_sql(self, expression: exp.ForeignKey) -> str: expressions = self.expressions(expression, flat=True) reference = self.sql(expression, "reference") reference = f" {reference}" if reference else "" @@ -1083,16 +1112,16 @@ class Generator: update = f" ON UPDATE {update}" if update else "" return f"FOREIGN KEY ({expressions}){reference}{delete}{update}" - def unique_sql(self, expression): + def unique_sql(self, expression: exp.Unique) -> str: columns = self.expressions(expression, key="expressions") return f"UNIQUE ({columns})" - def if_sql(self, expression): + def if_sql(self, expression: exp.If) -> str: return self.case_sql( exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) ) - def in_sql(self, expression): + def in_sql(self, expression: exp.In) -> str: query = expression.args.get("query") unnest = expression.args.get("unnest") field = expression.args.get("field") @@ -1106,24 +1135,24 @@ class Generator: in_sql = f"({self.expressions(expression, flat=True)})" return f"{self.sql(expression, 'this')} IN {in_sql}" - def in_unnest_op(self, unnest): + def in_unnest_op(self, unnest: exp.Unnest) -> str: return f"(SELECT {self.sql(unnest)})" - def interval_sql(self, expression): + def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") unit = f" {unit}" if unit else "" return f"INTERVAL {self.sql(expression, 'this')}{unit}" - def reference_sql(self, expression): + def reference_sql(self, expression: exp.Reference) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"REFERENCES {this}({expressions})" - def anonymous_sql(self, expression): + def anonymous_sql(self, expression: exp.Anonymous) -> str: args = self.format_args(*expression.expressions) return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" - def paren_sql(self, expression): + def paren_sql(self, expression: exp.Paren) -> str: if isinstance(expression.unnest(), exp.Select): sql = self.wrap(expression) else: @@ -1132,35 +1161,35 @@ class Generator: return self.prepend_ctes(expression, sql) - def neg_sql(self, expression): + def neg_sql(self, expression: exp.Neg) -> str: # This makes sure we don't convert "- - 5" to "--5", which is a comment this_sql = self.sql(expression, "this") sep = " " if this_sql[0] == "-" else "" return f"-{sep}{this_sql}" - def not_sql(self, expression): + def not_sql(self, expression: exp.Not) -> str: return f"NOT {self.sql(expression, 'this')}" - def alias_sql(self, expression): + def alias_sql(self, expression: exp.Alias) -> str: to_sql = self.sql(expression, "alias") to_sql = f" AS {to_sql}" if to_sql else "" return f"{self.sql(expression, 'this')}{to_sql}" - def aliases_sql(self, expression): + def aliases_sql(self, expression: exp.Aliases) -> str: return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" - def attimezone_sql(self, expression): + def attimezone_sql(self, expression: exp.AtTimeZone) -> str: this = self.sql(expression, "this") zone = self.sql(expression, "zone") return f"{this} AT TIME ZONE {zone}" - def add_sql(self, expression): + def add_sql(self, expression: exp.Add) -> str: return self.binary(expression, "+") - def and_sql(self, expression): + def and_sql(self, expression: exp.And) -> str: return self.connector_sql(expression, "AND") - def connector_sql(self, expression, op): + def connector_sql(self, expression: exp.Connector, op: str) -> str: if not self.pretty: return self.binary(expression, op) @@ -1168,53 +1197,53 @@ class Generator: sep = "\n" if self.text_width(sqls) > self._max_text_width else " " return f"{sep}{op} ".join(sqls) - def bitwiseand_sql(self, expression): + def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str: return self.binary(expression, "&") - def bitwiseleftshift_sql(self, expression): + def bitwiseleftshift_sql(self, expression: exp.BitwiseLeftShift) -> str: return self.binary(expression, "<<") - def bitwisenot_sql(self, expression): + def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str: return f"~{self.sql(expression, 'this')}" - def bitwiseor_sql(self, expression): + def bitwiseor_sql(self, expression: exp.BitwiseOr) -> str: return self.binary(expression, "|") - def bitwiserightshift_sql(self, expression): + def bitwiserightshift_sql(self, expression: exp.BitwiseRightShift) -> str: return self.binary(expression, ">>") - def bitwisexor_sql(self, expression): + def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str: return self.binary(expression, "^") - def cast_sql(self, expression): + def cast_sql(self, expression: exp.Cast) -> str: return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" - def currentdate_sql(self, expression): + def currentdate_sql(self, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" - def collate_sql(self, expression): + def collate_sql(self, expression: exp.Collate) -> str: return self.binary(expression, "COLLATE") - def command_sql(self, expression): + def command_sql(self, expression: exp.Command) -> str: return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" - def transaction_sql(self, *_): + def transaction_sql(self, *_) -> str: return "BEGIN" - def commit_sql(self, expression): + def commit_sql(self, expression: exp.Commit) -> str: chain = expression.args.get("chain") if chain is not None: chain = " AND CHAIN" if chain else " AND NO CHAIN" return f"COMMIT{chain or ''}" - def rollback_sql(self, expression): + def rollback_sql(self, expression: exp.Rollback) -> str: savepoint = expression.args.get("savepoint") savepoint = f" TO {savepoint}" if savepoint else "" return f"ROLLBACK{savepoint}" - def distinct_sql(self, expression): + def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) this = f" {this}" if this else "" @@ -1222,13 +1251,13 @@ class Generator: on = f" ON {on}" if on else "" return f"DISTINCT{this}{on}" - def ignorenulls_sql(self, expression): + def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str: return f"{self.sql(expression, 'this')} IGNORE NULLS" - def respectnulls_sql(self, expression): + def respectnulls_sql(self, expression: exp.RespectNulls) -> str: return f"{self.sql(expression, 'this')} RESPECT NULLS" - def intdiv_sql(self, expression): + def intdiv_sql(self, expression: exp.IntDiv) -> str: return self.sql( exp.Cast( this=exp.Div(this=expression.this, expression=expression.expression), @@ -1236,79 +1265,79 @@ class Generator: ) ) - def dpipe_sql(self, expression): + def dpipe_sql(self, expression: exp.DPipe) -> str: return self.binary(expression, "||") - def div_sql(self, expression): + def div_sql(self, expression: exp.Div) -> str: return self.binary(expression, "/") - def distance_sql(self, expression): + def distance_sql(self, expression: exp.Distance) -> str: return self.binary(expression, "<->") - def dot_sql(self, expression): + def dot_sql(self, expression: exp.Dot) -> str: return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" - def eq_sql(self, expression): + def eq_sql(self, expression: exp.EQ) -> str: return self.binary(expression, "=") - def escape_sql(self, expression): + def escape_sql(self, expression: exp.Escape) -> str: return self.binary(expression, "ESCAPE") - def gt_sql(self, expression): + def gt_sql(self, expression: exp.GT) -> str: return self.binary(expression, ">") - def gte_sql(self, expression): + def gte_sql(self, expression: exp.GTE) -> str: return self.binary(expression, ">=") - def ilike_sql(self, expression): + def ilike_sql(self, expression: exp.ILike) -> str: return self.binary(expression, "ILIKE") - def is_sql(self, expression): + def is_sql(self, expression: exp.Is) -> str: return self.binary(expression, "IS") - def like_sql(self, expression): + def like_sql(self, expression: exp.Like) -> str: return self.binary(expression, "LIKE") - def similarto_sql(self, expression): + def similarto_sql(self, expression: exp.SimilarTo) -> str: return self.binary(expression, "SIMILAR TO") - def lt_sql(self, expression): + def lt_sql(self, expression: exp.LT) -> str: return self.binary(expression, "<") - def lte_sql(self, expression): + def lte_sql(self, expression: exp.LTE) -> str: return self.binary(expression, "<=") - def mod_sql(self, expression): + def mod_sql(self, expression: exp.Mod) -> str: return self.binary(expression, "%") - def mul_sql(self, expression): + def mul_sql(self, expression: exp.Mul) -> str: return self.binary(expression, "*") - def neq_sql(self, expression): + def neq_sql(self, expression: exp.NEQ) -> str: return self.binary(expression, "<>") - def nullsafeeq_sql(self, expression): + def nullsafeeq_sql(self, expression: exp.NullSafeEQ) -> str: return self.binary(expression, "IS NOT DISTINCT FROM") - def nullsafeneq_sql(self, expression): + def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str: return self.binary(expression, "IS DISTINCT FROM") - def or_sql(self, expression): + def or_sql(self, expression: exp.Or) -> str: return self.connector_sql(expression, "OR") - def sub_sql(self, expression): + def sub_sql(self, expression: exp.Sub) -> str: return self.binary(expression, "-") - def trycast_sql(self, expression): + def trycast_sql(self, expression: exp.TryCast) -> str: return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" - def use_sql(self, expression): + def use_sql(self, expression: exp.Use) -> str: return f"USE {self.sql(expression, 'this')}" - def binary(self, expression, op): + def binary(self, expression: exp.Binary, op: str) -> str: return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" - def function_fallback_sql(self, expression): + def function_fallback_sql(self, expression: exp.Func) -> str: args = [] for arg_value in expression.args.values(): if isinstance(arg_value, list): @@ -1319,19 +1348,26 @@ class Generator: return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})" - def format_args(self, *args): - args = tuple(self.sql(arg) for arg in args if arg is not None) - if self.pretty and self.text_width(args) > self._max_text_width: - return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True) - return ", ".join(args) + def format_args(self, *args: t.Optional[str | exp.Expression]) -> str: + arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None) + if self.pretty and self.text_width(arg_sqls) > self._max_text_width: + return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True) + return ", ".join(arg_sqls) - def text_width(self, args): + def text_width(self, args: t.Iterable) -> int: return sum(len(arg) for arg in args) - def format_time(self, expression): + def format_time(self, expression: exp.Expression) -> t.Optional[str]: return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) - def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): + def expressions( + self, + expression: exp.Expression, + key: t.Optional[str] = None, + flat: bool = False, + indent: bool = True, + sep: str = ", ", + ) -> str: expressions = expression.args.get(key or "expressions") if not expressions: @@ -1359,45 +1395,67 @@ class Generator: else: result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") - result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) - return self.indent(result_sqls, skip_first=False) if indent else result_sqls + result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) + return self.indent(result_sql, skip_first=False) if indent else result_sql - def op_expressions(self, op, expression, flat=False): + def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str: expressions_sql = self.expressions(expression, flat=flat) if flat: return f"{op} {expressions_sql}" return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" - def naked_property(self, expression): + def naked_property(self, expression: exp.Property) -> str: property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__) if not property_name: self.unsupported(f"Unsupported property {expression.__class__.__name__}") return f"{property_name} {self.sql(expression, 'this')}" - def set_operation(self, expression, op): + def set_operation(self, expression: exp.Expression, op: str) -> str: this = self.sql(expression, "this") op = self.seg(op) return self.query_modifiers( expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" ) - def token_sql(self, token_type): + def token_sql(self, token_type: TokenType) -> str: return self.TOKEN_MAPPING.get(token_type, token_type.name) - def userdefinedfunction_sql(self, expression): + def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: this = self.sql(expression, "this") expressions = self.no_identify(lambda: self.expressions(expression)) return f"{this}({expressions})" - def userdefinedfunctionkwarg_sql(self, expression): + def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str: this = self.sql(expression, "this") kind = self.sql(expression, "kind") return f"{this} {kind}" - def joinhint_sql(self, expression): + def joinhint_sql(self, expression: exp.JoinHint) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"{this}({expressions})" - def kwarg_sql(self, expression): + def kwarg_sql(self, expression: exp.Kwarg) -> str: return self.binary(expression, "=>") + + def when_sql(self, expression: exp.When) -> str: + this = self.sql(expression, "this") + then_expression = expression.args.get("then") + if isinstance(then_expression, exp.Insert): + then = f"INSERT {self.sql(then_expression, 'this')}" + if "expression" in then_expression.args: + then += f" VALUES {self.sql(then_expression, 'expression')}" + elif isinstance(then_expression, exp.Update): + if isinstance(then_expression.args.get("expressions"), exp.Star): + then = f"UPDATE {self.sql(then_expression, 'expressions')}" + else: + then = f"UPDATE SET {self.expressions(then_expression, flat=True)}" + else: + then = self.sql(then_expression) + return f"WHEN {this} THEN {then}" + + def merge_sql(self, expression: exp.Merge) -> str: + this = self.sql(expression, "this") + using = f"USING {self.sql(expression, 'using')}" + on = f"ON {self.sql(expression, 'on')}" + return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}" |