diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 368 |
1 files changed, 245 insertions, 123 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index d7dcea0..f1ec398 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -31,6 +31,8 @@ class Generator: hex_end (str): specifies which ending character to use to delimit hex literals. Default: None. byte_start (str): specifies which starting character to use to delimit byte literals. Default: None. byte_end (str): specifies which ending character to use to delimit byte literals. Default: None. + raw_start (str): specifies which starting character to use to delimit raw literals. Default: None. + raw_end (str): specifies which ending character to use to delimit raw literals. Default: None. identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always. normalize (bool): if set to True all identifiers will lower cased string_escape (str): specifies a string escape character. Default: '. @@ -76,11 +78,12 @@ class Generator: exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", exp.MaterializedProperty: lambda self, e: "MATERIALIZED", exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX", - exp.OnCommitProperty: lambda self, e: "ON COMMIT PRESERVE ROWS", + exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", + exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", - exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY", + exp.TemporaryProperty: lambda self, e: f"TEMPORARY", exp.TransientProperty: lambda self, e: "TRANSIENT", exp.StabilityProperty: lambda self, e: e.name, exp.VolatileProperty: lambda self, e: "VOLATILE", @@ -133,6 +136,15 @@ class Generator: # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" + # Whether a table is allowed to be renamed with a db + RENAME_TABLE_WITH_DB = True + + # The separator for grouping sets and rollups + GROUPINGS_SEP = "," + + # The string used for creating index on a table + INDEX_ON = "ON" + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -167,7 +179,6 @@ class Generator: PARAMETER_TOKEN = "@" PROPERTIES_LOCATION = { - exp.AfterJournalProperty: exp.Properties.Location.POST_NAME, exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA, exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME, @@ -196,7 +207,9 @@ class Generator: exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME, exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION, exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION, + exp.Order: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, + exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA, exp.Property: exp.Properties.Location.POST_WITH, exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA, @@ -204,13 +217,15 @@ class Generator: exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA, + exp.Set: exp.Properties.Location.POST_SCHEMA, + exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA, exp.SetProperty: exp.Properties.Location.POST_CREATE, exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA, - exp.TableFormatProperty: exp.Properties.Location.POST_WITH, exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, exp.TransientProperty: exp.Properties.Location.POST_CREATE, + exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.POST_CREATE, exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, @@ -221,7 +236,7 @@ class Generator: RESERVED_KEYWORDS: t.Set[str] = set() WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With) - UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column) + UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren) SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" @@ -239,6 +254,8 @@ class Generator: "hex_end", "byte_start", "byte_end", + "raw_start", + "raw_end", "identify", "normalize", "string_escape", @@ -276,6 +293,8 @@ class Generator: hex_end=None, byte_start=None, byte_end=None, + raw_start=None, + raw_end=None, identify=False, normalize=False, string_escape=None, @@ -308,6 +327,8 @@ class Generator: self.hex_end = hex_end self.byte_start = byte_start self.byte_end = byte_end + self.raw_start = raw_start + self.raw_end = raw_end self.identify = identify self.normalize = normalize self.string_escape = string_escape or "'" @@ -399,7 +420,11 @@ class Generator: return sql if isinstance(expression, self.WITH_SEPARATED_COMMENTS): - return f"{comments_sql}{self.sep()}{sql}" + return ( + f"{self.sep()}{comments_sql}{sql}" + if sql[0].isspace() + else f"{comments_sql}{self.sep()}{sql}" + ) return f"{sql} {comments_sql}" @@ -567,7 +592,9 @@ class Generator: ) -> str: this = "" if expression.this is not None: - this = " ALWAYS " if expression.this else " BY DEFAULT " + on_null = "ON NULL " if expression.args.get("on_null") else "" + this = " ALWAYS " if expression.this else f" BY DEFAULT {on_null}" + start = expression.args.get("start") start = f"START WITH {start}" if start else "" increment = expression.args.get("increment") @@ -578,14 +605,20 @@ class Generator: maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" cycle = expression.args.get("cycle") cycle_sql = "" + if cycle is not None: cycle_sql = f"{' NO' if not cycle else ''} CYCLE" cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql + sequence_opts = "" if start or increment or cycle_sql: sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}" sequence_opts = f" ({sequence_opts.strip()})" - return f"GENERATED{this}AS IDENTITY{sequence_opts}" + + expr = self.sql(expression, "expression") + expr = f"({expr})" if expr else "IDENTITY" + + return f"GENERATED{this}AS {expr}{sequence_opts}" def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" @@ -596,8 +629,10 @@ class Generator: return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" return f"PRIMARY KEY" - def uniquecolumnconstraint_sql(self, _) -> str: - return "UNIQUE" + def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + return f"UNIQUE{this}" def create_sql(self, expression: exp.Create) -> str: kind = self.sql(expression, "kind").upper() @@ -653,33 +688,9 @@ class Generator: prefix=" ", ) - indexes = expression.args.get("indexes") - if indexes: - indexes_sql: t.List[str] = [] - for index in indexes: - ind_unique = " UNIQUE" if index.args.get("unique") else "" - ind_primary = " PRIMARY" if index.args.get("primary") else "" - ind_amp = " AMP" if index.args.get("amp") else "" - ind_name = f" {index.name}" if index.name else "" - ind_columns = ( - f' ({self.expressions(index, key="columns", flat=True)})' - if index.args.get("columns") - else "" - ) - ind_sql = f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}" - - if indexes_sql: - indexes_sql.append(ind_sql) - else: - indexes_sql.append( - f"{ind_sql}{postindex_props_sql}" - if index.args.get("primary") - else f"{postindex_props_sql}{ind_sql}" - ) - - index_sql = "".join(indexes_sql) - else: - index_sql = postindex_props_sql + indexes = self.expressions(expression, key="indexes", indent=False, sep=" ") + indexes = f" {indexes}" if indexes else "" + index_sql = indexes + postindex_props_sql replace = " OR REPLACE" if expression.args.get("replace") else "" unique = " UNIQUE" if expression.args.get("unique") else "" @@ -711,9 +722,23 @@ class Generator: " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else "" ) - expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}" + clone = self.sql(expression, "clone") + clone = f" {clone}" if clone else "" + + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}" return self.prepend_ctes(expression, expression_sql) + def clone_sql(self, expression: exp.Clone) -> str: + this = self.sql(expression, "this") + when = self.sql(expression, "when") + + if when: + kind = self.sql(expression, "kind") + expr = self.sql(expression, "expression") + return f"CLONE {this} {when} ({kind} => {expr})" + + return f"CLONE {this}" + def describe_sql(self, expression: exp.Describe) -> str: return f"DESCRIBE {self.sql(expression, 'this')}" @@ -757,6 +782,17 @@ class Generator: return f"{self.byte_start}{this}{self.byte_end}" return this + def rawstring_sql(self, expression: exp.RawString) -> str: + if self.raw_start: + return f"{self.raw_start}{expression.name}{self.raw_end}" + return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\"))) + + def datatypesize_sql(self, expression: exp.DataTypeSize) -> str: + this = self.sql(expression, "this") + specifier = self.sql(expression, "expression") + specifier = f" {specifier}" if specifier else "" + return f"{this}{specifier}" + def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) @@ -768,7 +804,8 @@ class Generator: nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" if expression.args.get("values") is not None: delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")") - values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}" + values = self.expressions(expression, key="values", flat=True) + values = f"{delimiters[0]}{values}{delimiters[1]}" else: nested = f"({interior})" @@ -836,10 +873,17 @@ class Generator: return "" def index_sql(self, expression: exp.Index) -> str: - this = self.sql(expression, "this") + unique = "UNIQUE " if expression.args.get("unique") else "" + primary = "PRIMARY " if expression.args.get("primary") else "" + amp = "AMP " if expression.args.get("amp") else "" + name = f"{expression.name} " if expression.name else "" table = self.sql(expression, "table") - columns = self.sql(expression, "columns") - return f"{this} ON {table} {columns}" + table = f"{self.INDEX_ON} {table} " if table else "" + index = "INDEX " if not table else "" + columns = self.expressions(expression, key="columns", flat=True) + partition_by = self.expressions(expression, key="partition_by", flat=True) + partition_by = f" PARTITION BY {partition_by}" if partition_by else "" + return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}" def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name @@ -861,8 +905,9 @@ class Generator: output_format = f"OUTPUTFORMAT {output_format}" if output_format else "" return self.sep().join((input_format, output_format)) - def national_sql(self, expression: exp.National) -> str: - return f"N{self.sql(expression, 'this')}" + def national_sql(self, expression: exp.National, prefix: str = "N") -> str: + string = self.sql(exp.Literal.string(expression.name)) + return f"{prefix}{string}" def partition_sql(self, expression: exp.Partition) -> str: return f"PARTITION({self.expressions(expression)})" @@ -955,23 +1000,18 @@ class Generator: def journalproperty_sql(self, expression: exp.JournalProperty) -> str: no = "NO " if expression.args.get("no") else "" + local = expression.args.get("local") + local = f"{local} " if local else "" dual = "DUAL " if expression.args.get("dual") else "" before = "BEFORE " if expression.args.get("before") else "" - return f"{no}{dual}{before}JOURNAL" + after = "AFTER " if expression.args.get("after") else "" + return f"{no}{local}{dual}{before}{after}JOURNAL" def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str: freespace = self.sql(expression, "this") percent = " PERCENT" if expression.args.get("percent") else "" return f"FREESPACE={freespace}{percent}" - def afterjournalproperty_sql(self, expression: exp.AfterJournalProperty) -> str: - no = "NO " if expression.args.get("no") else "" - dual = "DUAL " if expression.args.get("dual") else "" - local = "" - if expression.args.get("local") is not None: - local = "LOCAL " if expression.args.get("local") else "NOT LOCAL " - return f"{no}{dual}{local}AFTER JOURNAL" - def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str: if expression.args.get("default"): property = "DEFAULT" @@ -992,19 +1032,19 @@ class Generator: def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str: default = expression.args.get("default") - min = expression.args.get("min") - if default is not None or min is not None: + minimum = expression.args.get("minimum") + maximum = expression.args.get("maximum") + if default or minimum or maximum: if default: - property = "DEFAULT" - elif min: - property = "MINIMUM" + prop = "DEFAULT" + elif minimum: + prop = "MINIMUM" else: - property = "MAXIMUM" - return f"{property} DATABLOCKSIZE" - else: - units = expression.args.get("units") - units = f" {units}" if units else "" - return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}" + prop = "MAXIMUM" + return f"{prop} DATABLOCKSIZE" + units = expression.args.get("units") + units = f" {units}" if units else "" + return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}" def blockcompressionproperty_sql(self, expression: exp.BlockCompressionProperty) -> str: autotemp = expression.args.get("autotemp") @@ -1014,16 +1054,16 @@ class Generator: never = expression.args.get("never") if autotemp is not None: - property = f"AUTOTEMP({self.expressions(autotemp)})" + prop = f"AUTOTEMP({self.expressions(autotemp)})" elif always: - property = "ALWAYS" + prop = "ALWAYS" elif default: - property = "DEFAULT" + prop = "DEFAULT" elif manual: - property = "MANUAL" + prop = "MANUAL" elif never: - property = "NEVER" - return f"BLOCKCOMPRESSION={property}" + prop = "NEVER" + return f"BLOCKCOMPRESSION={prop}" def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -> str: no = expression.args.get("no") @@ -1138,21 +1178,24 @@ class Generator: alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" - hints = self.expressions(expression, key="hints", sep=", ", flat=True) + hints = self.expressions(expression, key="hints", flat=True) hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else "" - laterals = self.expressions(expression, key="laterals", sep="") + pivots = self.expressions(expression, key="pivots", sep=" ", flat=True) + pivots = f" {pivots}" if pivots else "" joins = self.expressions(expression, key="joins", sep="") - pivots = self.expressions(expression, key="pivots", sep="") + laterals = self.expressions(expression, key="laterals", sep="") system_time = expression.args.get("system_time") system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" - return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}" + return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " ) -> str: if self.alias_post_tablesample and expression.this.alias: - this = self.sql(expression.this, "this") + table = expression.this.copy() + table.set("alias", None) + this = self.sql(table) alias = f"{sep}{self.sql(expression.this, 'alias')}" else: this = self.sql(expression, "this") @@ -1177,14 +1220,22 @@ class Generator: return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}" def pivot_sql(self, expression: exp.Pivot) -> str: - this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + + if expression.this: + this = self.sql(expression, "this") + on = f"{self.seg('ON')} {expressions}" + using = self.expressions(expression, key="using", flat=True) + using = f"{self.seg('USING')} {using}" if using else "" + group = self.sql(expression, "group") + return f"PIVOT {this}{on}{using}{group}" + alias = self.sql(expression, "alias") alias = f" AS {alias}" if alias else "" unpivot = expression.args.get("unpivot") direction = "UNPIVOT" if unpivot else "PIVOT" - expressions = self.expressions(expression, key="expressions") field = self.sql(expression, "field") - return f"{this} {direction}({expressions} FOR {field}){alias}" + return f"{direction}({expressions} FOR {field}){alias}" def tuple_sql(self, expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" @@ -1218,8 +1269,7 @@ class Generator: return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}" def from_sql(self, expression: exp.From) -> str: - expressions = self.expressions(expression, flat=True) - return f"{self.seg('FROM')} {expressions}" + return f"{self.seg('FROM')} {self.sql(expression, 'this')}" def group_sql(self, expression: exp.Group) -> str: group_by = self.op_expressions("GROUP BY", expression) @@ -1242,10 +1292,16 @@ class Generator: rollup_sql = self.expressions(expression, key="rollup", indent=False) rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else "" - groupings = csv(grouping_sets, cube_sql, rollup_sql, sep=",") + groupings = csv( + grouping_sets, + cube_sql, + rollup_sql, + self.seg("WITH TOTALS") if expression.args.get("totals") else "", + sep=self.GROUPINGS_SEP, + ) if expression.args.get("expressions") and groupings: - group_by = f"{group_by}," + group_by = f"{group_by}{self.GROUPINGS_SEP}" return f"{group_by}{groupings}" @@ -1254,18 +1310,16 @@ class Generator: return f"{self.seg('HAVING')}{self.sep()}{this}" def join_sql(self, expression: exp.Join) -> str: - op_sql = self.seg( - " ".join( - op - for op in ( - "NATURAL" if expression.args.get("natural") else None, - expression.side, - expression.kind, - expression.hint if self.JOIN_HINTS else None, - "JOIN", - ) - if op + op_sql = " ".join( + op + for op in ( + "NATURAL" if expression.args.get("natural") else None, + "GLOBAL" if expression.args.get("global") else None, + expression.side, + expression.kind, + expression.hint if self.JOIN_HINTS else None, ) + if op ) on_sql = self.sql(expression, "on") using = expression.args.get("using") @@ -1273,6 +1327,8 @@ class Generator: if not on_sql and using: on_sql = csv(*(self.sql(column) for column in using)) + this_sql = self.sql(expression, "this") + if on_sql: on_sql = self.indent(on_sql, skip_first=True) space = self.seg(" " * self.pad) if self.pretty else " " @@ -1280,10 +1336,11 @@ class Generator: on_sql = f"{space}USING ({on_sql})" else: on_sql = f"{space}ON {on_sql}" + elif not op_sql: + return f", {this_sql}" - expression_sql = self.sql(expression, "expression") - this_sql = self.sql(expression, "this") - return f"{expression_sql}{op_sql} {this_sql}{on_sql}" + op_sql = f"{op_sql} JOIN" if op_sql else "JOIN" + return f"{self.seg(op_sql)} {this_sql}{on_sql}" def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: args = self.expressions(expression, flat=True) @@ -1336,12 +1393,22 @@ class Generator: return f"PRAGMA {self.sql(expression, 'this')}" def lock_sql(self, expression: exp.Lock) -> str: - if self.LOCKING_READS_SUPPORTED: - lock_type = "UPDATE" if expression.args["update"] else "SHARE" - return self.seg(f"FOR {lock_type}") + if not self.LOCKING_READS_SUPPORTED: + self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported") + return "" - self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported") - return "" + lock_type = "FOR UPDATE" if expression.args["update"] else "FOR SHARE" + expressions = self.expressions(expression, flat=True) + expressions = f" OF {expressions}" if expressions else "" + wait = expression.args.get("wait") + + if wait is not None: + if isinstance(wait, exp.Literal): + wait = f" WAIT {self.sql(wait)}" + else: + wait = " NOWAIT" if wait else " SKIP LOCKED" + + return f"{lock_type}{expressions}{wait or ''}" def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" @@ -1460,27 +1527,33 @@ class Generator: return csv( *sqls, - *[self.sql(sql) for sql in expression.args.get("joins") or []], + *[self.sql(join) for join in expression.args.get("joins") or []], self.sql(expression, "match"), - *[self.sql(sql) for sql in expression.args.get("laterals") or []], + *[self.sql(lateral) for lateral in expression.args.get("laterals") or []], self.sql(expression, "where"), self.sql(expression, "group"), self.sql(expression, "having"), - self.sql(expression, "qualify"), - self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) - if expression.args.get("windows") - else "", - self.sql(expression, "distribute"), - self.sql(expression, "sort"), - self.sql(expression, "cluster"), + *self.after_having_modifiers(expression), self.sql(expression, "order"), self.sql(expression, "offset") if fetch else self.sql(limit), self.sql(limit) if fetch else self.sql(expression, "offset"), - self.sql(expression, "lock"), - self.sql(expression, "sample"), + *self.after_limit_modifiers(expression), sep="", ) + def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]: + return [ + self.sql(expression, "qualify"), + self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) + if expression.args.get("windows") + else "", + ] + + def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: + locks = self.expressions(expression, key="locks", sep=" ") + locks = f" {locks}" if locks else "" + return [locks, self.sql(expression, "sample")] + def select_sql(self, expression: exp.Select) -> str: hint = self.sql(expression, "hint") distinct = self.sql(expression, "distinct") @@ -1529,13 +1602,10 @@ class Generator: alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" - sql = self.query_modifiers( - expression, - self.wrap(expression), - alias, - self.expressions(expression, key="pivots", sep=" "), - ) + pivots = self.expressions(expression, key="pivots", sep=" ", flat=True) + pivots = f" {pivots}" if pivots else "" + sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots) return self.prepend_ctes(expression, sql) def qualify_sql(self, expression: exp.Qualify) -> str: @@ -1712,10 +1782,6 @@ class Generator: options = f" {options}" if options else "" return f"PRIMARY KEY ({expressions}){options}" - def unique_sql(self, expression: exp.Unique) -> str: - columns = self.expressions(expression, key="expressions") - return f"UNIQUE ({columns})" - def if_sql(self, expression: exp.If) -> str: return self.case_sql( exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) @@ -1745,6 +1811,26 @@ class Generator: encoding = f" ENCODING {encoding}" if encoding else "" return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})" + def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str: + this = self.sql(expression, "this") + kind = self.sql(expression, "kind") + path = self.sql(expression, "path") + path = f" {path}" if path else "" + as_json = " AS JSON" if expression.args.get("as_json") else "" + return f"{this} {kind}{path}{as_json}" + + def openjson_sql(self, expression: exp.OpenJSON) -> str: + this = self.sql(expression, "this") + path = self.sql(expression, "path") + path = f", {path}" if path else "" + expressions = self.expressions(expression) + with_ = ( + f" WITH ({self.seg(self.indent(expressions), sep='')}{self.seg(')', sep='')}" + if expressions + else "" + ) + return f"OPENJSON({this}{path}){with_}" + def in_sql(self, expression: exp.In) -> str: query = expression.args.get("query") unnest = expression.args.get("unnest") @@ -1773,7 +1859,7 @@ class Generator: if self.SINGLE_STRING_INTERVAL: this = expression.this.name if expression.this else "" - return f"INTERVAL '{this}{unit}'" + return f"INTERVAL '{this}{unit}'" if this else f"INTERVAL{unit}" this = self.sql(expression, "this") if this: @@ -1883,6 +1969,28 @@ class Generator: expression_sql = self.sql(expression, "expression") return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}" + def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str: + this = self.sql(expression, "this") + delete = " DELETE" if expression.args.get("delete") else "" + recompress = self.sql(expression, "recompress") + recompress = f" RECOMPRESS {recompress}" if recompress else "" + to_disk = self.sql(expression, "to_disk") + to_disk = f" TO DISK {to_disk}" if to_disk else "" + to_volume = self.sql(expression, "to_volume") + to_volume = f" TO VOLUME {to_volume}" if to_volume else "" + return f"{this}{delete}{recompress}{to_disk}{to_volume}" + + def mergetreettl_sql(self, expression: exp.MergeTreeTTL) -> str: + where = self.sql(expression, "where") + group = self.sql(expression, "group") + aggregates = self.expressions(expression, key="aggregates") + aggregates = self.seg("SET") + self.seg(aggregates) if aggregates else "" + + if not (where or group or aggregates) and len(expression.expressions) == 1: + return f"TTL {self.expressions(expression, flat=True)}" + + return f"TTL{self.seg(self.expressions(expression))}{where}{group}{aggregates}" + def transaction_sql(self, expression: exp.Transaction) -> str: return "BEGIN" @@ -1919,6 +2027,11 @@ class Generator: return f"ALTER COLUMN {this} DROP DEFAULT" def renametable_sql(self, expression: exp.RenameTable) -> str: + if not self.RENAME_TABLE_WITH_DB: + # Remove db from tables + expression = expression.transform( + lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n + ) this = self.sql(expression, "this") return f"RENAME TO {this}" @@ -2208,3 +2321,12 @@ class Generator: self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function") return self.sql(exp.cast(expression.this, "text")) + + +def cached_generator( + cache: t.Optional[t.Dict[int, str]] = None +) -> t.Callable[[exp.Expression], str]: + """Returns a cached generator.""" + cache = {} if cache is None else cache + generator = Generator(normalize=True, identify="safe") + return lambda e: generator.generate(e, cache) |