diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
commit | f73e9af131151f1e058446361c35b05c4c90bf10 (patch) | |
tree | ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/generator.py | |
parent | Releasing debian version 17.12.0-1. (diff) | |
download | sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip |
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 149 |
1 files changed, 120 insertions, 29 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index f8d7d68..306df81 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -8,7 +8,7 @@ from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages from sqlglot.helper import apply_index_offset, csv, seq_get from sqlglot.time import format_time -from sqlglot.tokens import TokenType +from sqlglot.tokens import Tokenizer, TokenType logger = logging.getLogger("sqlglot") @@ -61,6 +61,7 @@ class Generator: exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})", + exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})", exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS", exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", @@ -78,7 +79,10 @@ class Generator: exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", exp.MaterializedProperty: lambda self, e: "MATERIALIZED", exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX", + exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})", + exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION", exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", + exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}", exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}", exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", exp.ReturnsProperty: lambda self, e: self.naked_property(e), @@ -171,6 +175,9 @@ class Generator: # Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax TZ_TO_WITH_TIME_ZONE = False + # Whether or not the NVL2 function is supported + NVL2_SUPPORTED = True + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") @@ -179,6 +186,9 @@ class Generator: # SELECT * VALUES into SELECT UNION VALUES_AS_TABLE = True + # Whether or not the word COLUMN is included when adding a column with ALTER TABLE + ALTER_TABLE_ADD_COLUMN_KEYWORD = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -245,6 +255,7 @@ class Generator: exp.MaterializedProperty: exp.Properties.Location.POST_CREATE, exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME, exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION, + exp.OnProperty: exp.Properties.Location.POST_SCHEMA, exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION, exp.Order: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, @@ -317,8 +328,7 @@ class Generator: QUOTE_END = "'" IDENTIFIER_START = '"' IDENTIFIER_END = '"' - STRING_ESCAPE = "'" - IDENTIFIER_ESCAPE = '"' + TOKENIZER_CLASS = Tokenizer # Delimiters for bit, hex, byte and raw literals BIT_START: t.Optional[str] = None @@ -379,8 +389,10 @@ class Generator: ) self.unsupported_messages: t.List[str] = [] - self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END - self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END + self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END + self._escaped_identifier_end: str = ( + self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END + ) self._cache: t.Optional[t.Dict[int, str]] = None def generate( @@ -626,6 +638,16 @@ class Generator: kind_sql = self.sql(expression, "kind").strip() return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql + def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: + this = self.sql(expression, "this") + if expression.args.get("not_null"): + persisted = " PERSISTED NOT NULL" + elif expression.args.get("persisted"): + persisted = " PERSISTED" + else: + persisted = "" + return f"AS {this}{persisted}" + def autoincrementcolumnconstraint_sql(self, _) -> str: return self.token_sql(TokenType.AUTO_INCREMENT) @@ -642,8 +664,8 @@ class Generator: ) -> str: this = "" if expression.this is not None: - on_null = "ON NULL " if expression.args.get("on_null") else "" - this = " ALWAYS " if expression.this else f" BY DEFAULT {on_null}" + on_null = " ON NULL" if expression.args.get("on_null") else "" + this = " ALWAYS" if expression.this else f" BY DEFAULT{on_null}" start = expression.args.get("start") start = f"START WITH {start}" if start else "" @@ -668,7 +690,7 @@ class Generator: expr = self.sql(expression, "expression") expr = f"({expr})" if expr else "IDENTITY" - return f"GENERATED{this}AS {expr}{sequence_opts}" + return f"GENERATED{this} AS {expr}{sequence_opts}" def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" @@ -774,14 +796,16 @@ class Generator: def clone_sql(self, expression: exp.Clone) -> str: this = self.sql(expression, "this") + shallow = "SHALLOW " if expression.args.get("shallow") else "" + this = f"{shallow}CLONE {this}" when = self.sql(expression, "when") if when: kind = self.sql(expression, "kind") expr = self.sql(expression, "expression") - return f"CLONE {this} {when} ({kind} => {expr})" + return f"{this} {when} ({kind} => {expr})" - return f"CLONE {this}" + return this def describe_sql(self, expression: exp.Describe) -> str: return f"DESCRIBE {self.sql(expression, 'this')}" @@ -830,7 +854,7 @@ class Generator: string = self.escape_str(expression.this.replace("\\", "\\\\")) return f"{self.QUOTE_START}{string}{self.QUOTE_END}" - def datatypesize_sql(self, expression: exp.DataTypeSize) -> str: + def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str: this = self.sql(expression, "this") specifier = self.sql(expression, "expression") specifier = f" {specifier}" if specifier else "" @@ -839,11 +863,14 @@ class Generator: def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this - type_sql = ( - self.TYPE_MAPPING.get(type_value, type_value.value) - if isinstance(type_value, exp.DataType.Type) - else type_value - ) + if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"): + type_sql = self.sql(expression, "kind") + else: + type_sql = ( + self.TYPE_MAPPING.get(type_value, type_value.value) + if isinstance(type_value, exp.DataType.Type) + else type_value + ) nested = "" interior = self.expressions(expression, flat=True) @@ -943,9 +970,9 @@ class Generator: name = self.sql(expression, "this") name = f"{name} " if name else "" table = self.sql(expression, "table") - table = f"{self.INDEX_ON} {table} " if table else "" + table = f"{self.INDEX_ON} {table}" if table else "" using = self.sql(expression, "using") - using = f"USING {using} " if using else "" + using = f" USING {using} " if using else "" index = "INDEX " if not table else "" columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" @@ -1171,6 +1198,7 @@ class Generator: where = f"{self.sep()}REPLACE WHERE {where}" if where else "" expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}" conflict = self.sql(expression, "conflict") + by_name = " BY NAME" if expression.args.get("by_name") else "" returning = self.sql(expression, "returning") if self.RETURNING_END: @@ -1178,7 +1206,7 @@ class Generator: else: expression_sql = f"{returning}{expression_sql}{conflict}" - sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}" + sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: @@ -1196,6 +1224,9 @@ class Generator: def pseudotype_sql(self, expression: exp.PseudoType) -> str: return expression.name.upper() + def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str: + return expression.name.upper() + def onconflict_sql(self, expression: exp.OnConflict) -> str: conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT" constraint = self.sql(expression, "constraint") @@ -1248,6 +1279,8 @@ class Generator: if part ) + version = self.sql(expression, "version") + version = f" {version}" if version else "" alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" hints = self.expressions(expression, key="hints", sep=" ") @@ -1256,10 +1289,8 @@ class Generator: pivots = f" {pivots}" if pivots else "" joins = self.expressions(expression, key="joins", sep="", skip_first=True) laterals = self.expressions(expression, key="laterals", sep="") - system_time = expression.args.get("system_time") - system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" - return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}" + return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " @@ -1314,6 +1345,12 @@ class Generator: nulls = "" return f"{direction}{nulls}({expressions} FOR {field}){alias}" + def version_sql(self, expression: exp.Version) -> str: + this = f"FOR {expression.name}" + kind = expression.text("kind") + expr = self.sql(expression, "expression") + return f"{this} {kind} {expr}" + def tuple_sql(self, expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" @@ -1323,12 +1360,13 @@ class Generator: from_sql = self.sql(expression, "from") where_sql = self.sql(expression, "where") returning = self.sql(expression, "returning") + order = self.sql(expression, "order") limit = self.sql(expression, "limit") if self.RETURNING_END: - expression_sql = f"{from_sql}{where_sql}{returning}{limit}" + expression_sql = f"{from_sql}{where_sql}{returning}" else: - expression_sql = f"{returning}{from_sql}{where_sql}{limit}" - sql = f"UPDATE {this} SET {set_sql}{expression_sql}" + expression_sql = f"{returning}{from_sql}{where_sql}" + sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}" return self.prepend_ctes(expression, sql) def values_sql(self, expression: exp.Values) -> str: @@ -1425,6 +1463,16 @@ class Generator: this = self.indent(self.sql(expression, "this")) return f"{self.seg('HAVING')}{self.sep()}{this}" + def connect_sql(self, expression: exp.Connect) -> str: + start = self.sql(expression, "start") + start = self.seg(f"START WITH {start}") if start else "" + connect = self.sql(expression, "connect") + connect = self.seg(f"CONNECT BY {connect}") + return start + connect + + def prior_sql(self, expression: exp.Prior) -> str: + return f"PRIOR {self.sql(expression, 'this')}" + def join_sql(self, expression: exp.Join) -> str: op_sql = " ".join( op @@ -1667,6 +1715,7 @@ class Generator: return csv( *sqls, *[self.sql(join) for join in expression.args.get("joins") or []], + self.sql(expression, "connect"), self.sql(expression, "match"), *[self.sql(lateral) for lateral in expression.args.get("laterals") or []], self.sql(expression, "where"), @@ -1801,7 +1850,8 @@ class Generator: def union_op(self, expression: exp.Union) -> str: kind = " DISTINCT" if self.EXPLICIT_UNION else "" kind = kind if expression.args.get("distinct") else " ALL" - return f"UNION{kind}" + by_name = " BY NAME" if expression.args.get("by_name") else "" + return f"UNION{kind}{by_name}" def unnest_sql(self, expression: exp.Unnest) -> str: args = self.expressions(expression, flat=True) @@ -2224,7 +2274,14 @@ class Generator: actions = expression.args["actions"] if isinstance(actions[0], exp.ColumnDef): - actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ") + if self.ALTER_TABLE_ADD_COLUMN_KEYWORD: + actions = self.expressions( + expression, + key="actions", + prefix="ADD COLUMN ", + ) + else: + actions = f"ADD {self.expressions(expression, key='actions')}" elif isinstance(actions[0], exp.Schema): actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ") elif isinstance(actions[0], exp.Delete): @@ -2525,10 +2582,21 @@ class Generator: return f"WHEN {matched}{source}{condition} THEN {then}" def merge_sql(self, expression: exp.Merge) -> str: - this = self.sql(expression, "this") + table = expression.this + table_alias = "" + + hints = table.args.get("hints") + if hints and table.alias and isinstance(hints[0], exp.WithTableHint): + # T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias] + table = table.copy() + table_alias = f" AS {self.sql(table.args['alias'].pop())}" + + this = self.sql(table) using = f"USING {self.sql(expression, 'using')}" on = f"ON {self.sql(expression, 'on')}" - return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}" + expressions = self.expressions(expression, sep=" ") + + return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}" def tochar_sql(self, expression: exp.ToChar) -> str: if expression.args.get("format"): @@ -2631,6 +2699,29 @@ class Generator: options = f" {options}" if options else "" return f"{kind}{this}{type_}{schema}{options}" + def nvl2_sql(self, expression: exp.Nvl2) -> str: + if self.NVL2_SUPPORTED: + return self.function_fallback_sql(expression) + + case = exp.Case().when( + expression.this.is_(exp.null()).not_(copy=False), + expression.args["true"].copy(), + copy=False, + ) + else_cond = expression.args.get("false") + if else_cond: + case.else_(else_cond.copy(), copy=False) + + return self.sql(case) + + def comprehension_sql(self, expression: exp.Comprehension) -> str: + this = self.sql(expression, "this") + expr = self.sql(expression, "expression") + iterator = self.sql(expression, "iterator") + condition = self.sql(expression, "condition") + condition = f" IF {condition}" if condition else "" + return f"{this} FOR {expr} IN {iterator}{condition}" + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None |