From 67578a7602a5be7eb51f324086c8d49bcf8b7498 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 16 Jun 2023 11:41:18 +0200 Subject: Merging upstream version 16.2.1. Signed-off-by: Daniel Baumann --- sqlglot/generator.py | 383 +++++++++++++++++++++++++++------------------------ 1 file changed, 200 insertions(+), 183 deletions(-) (limited to 'sqlglot/generator.py') diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 97cbe15..d3cf9f0 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -14,47 +14,32 @@ logger = logging.getLogger("sqlglot") class Generator: """ - Generator interprets the given syntax tree and produces a SQL string as an output. + Generator converts a given syntax tree to the corresponding SQL string. Args: - time_mapping (dict): the dictionary of custom time mappings in which the key - represents a python time format and the output the target time format - time_trie (trie): a trie of the time_mapping keys - pretty (bool): if set to True the returned string will be formatted. Default: False. - quote_start (str): specifies which starting character to use to delimit quotes. Default: '. - quote_end (str): specifies which ending character to use to delimit quotes. Default: '. - identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ". - identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ". - bit_start (str): specifies which starting character to use to delimit bit literals. Default: None. - bit_end (str): specifies which ending character to use to delimit bit literals. Default: None. - hex_start (str): specifies which starting character to use to delimit hex literals. Default: None. - hex_end (str): specifies which ending character to use to delimit hex literals. Default: None. - byte_start (str): specifies which starting character to use to delimit byte literals. Default: None. - byte_end (str): specifies which ending character to use to delimit byte literals. Default: None. - raw_start (str): specifies which starting character to use to delimit raw literals. Default: None. - raw_end (str): specifies which ending character to use to delimit raw literals. Default: None. - identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always. - normalize (bool): if set to True all identifiers will lower cased - string_escape (str): specifies a string escape character. Default: '. - identifier_escape (str): specifies an identifier escape character. Default: ". - pad (int): determines padding in a formatted string. Default: 2. - indent (int): determines the size of indentation in a formatted string. Default: 4. - unnest_column_only (bool): if true unnest table aliases are considered only as column aliases - normalize_functions (str): normalize function names, "upper", "lower", or None - Default: "upper" - alias_post_tablesample (bool): if the table alias comes after tablesample - Default: False - identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit - Default: False - unsupported_level (ErrorLevel): determines the generator's behavior when it encounters - unsupported expressions. Default ErrorLevel.WARN. - null_ordering (str): Indicates the default null ordering method to use if not explicitly set. - Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". - Default: "nulls_are_small" - max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError. + pretty: Whether or not to format the produced SQL string. + Default: False. + identify: Determines when an identifier should be quoted. Possible values are: + False (default): Never quote, except in cases where it's mandatory by the dialect. + True or 'always': Always quote. + 'safe': Only quote identifiers that are case insensitive. + normalize: Whether or not to normalize identifiers to lowercase. + Default: False. + pad: Determines the pad size in a formatted string. + Default: 2. + indent: Determines the indentation size in a formatted string. + Default: 2. + normalize_functions: Whether or not to normalize all function names. Possible values are: + "upper" or True (default): Convert names to uppercase. + "lower": Convert names to lowercase. + False: Disables function name normalization. + unsupported_level: Determines the generator's behavior when it encounters unsupported expressions. + Default ErrorLevel.WARN. + max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError. This is only relevant if unsupported_level is ErrorLevel.RAISE. Default: 3 - leading_comma (bool): if the the comma is leading or trailing in select statements + leading_comma: Determines whether or not the comma is leading or trailing in select expressions. + This is only relevant when generating in pretty mode. Default: False max_text_width: The max number of characters in a segment before creating new lines in pretty mode. The default is on the smaller end because the length only represents a segment and not the true @@ -86,6 +71,7 @@ class Generator: exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", exp.TemporaryProperty: lambda self, e: f"TEMPORARY", + exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}", exp.TransientProperty: lambda self, e: "TRANSIENT", exp.StabilityProperty: lambda self, e: e.name, exp.VolatileProperty: lambda self, e: "VOLATILE", @@ -138,15 +124,24 @@ class Generator: # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" - # Whether a table is allowed to be renamed with a db + # Whether or not a table is allowed to be renamed with a db RENAME_TABLE_WITH_DB = True # The separator for grouping sets and rollups GROUPINGS_SEP = "," - # The string used for creating index on a table + # The string used for creating an index on a table INDEX_ON = "ON" + # Whether or not join hints should be generated + JOIN_HINTS = True + + # Whether or not table hints should be generated + TABLE_HINTS = True + + # Whether or not comparing against booleans (e.g. x IS TRUE) is supported + IS_BOOL_ALLOWED = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -228,6 +223,7 @@ class Generator: exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA, exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, + exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA, exp.TransientProperty: exp.Properties.Location.POST_CREATE, exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.POST_CREATE, @@ -235,128 +231,110 @@ class Generator: exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, } - JOIN_HINTS = True - TABLE_HINTS = True - IS_BOOL = True - + # Keywords that can't be used as unquoted identifier names RESERVED_KEYWORDS: t.Set[str] = set() - WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With) - UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren) + + # Expressions whose comments are separated from them for better formatting + WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Select, + exp.From, + exp.Where, + exp.With, + ) + + # Expressions that can remain unwrapped when appearing in the context of an INTERVAL + UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Column, + exp.Literal, + exp.Neg, + exp.Paren, + ) SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" + # Autofilled + INVERSE_TIME_MAPPING: t.Dict[str, str] = {} + INVERSE_TIME_TRIE: t.Dict = {} + INDEX_OFFSET = 0 + UNNEST_COLUMN_ONLY = False + ALIAS_POST_TABLESAMPLE = False + IDENTIFIERS_CAN_START_WITH_DIGIT = False + STRICT_STRING_CONCAT = False + NORMALIZE_FUNCTIONS: bool | str = "upper" + NULL_ORDERING = "nulls_are_small" + + # Delimiters for quotes, identifiers and the corresponding escape characters + QUOTE_START = "'" + QUOTE_END = "'" + IDENTIFIER_START = '"' + IDENTIFIER_END = '"' + STRING_ESCAPE = "'" + IDENTIFIER_ESCAPE = '"' + + # Delimiters for bit, hex, byte and raw literals + BIT_START: t.Optional[str] = None + BIT_END: t.Optional[str] = None + HEX_START: t.Optional[str] = None + HEX_END: t.Optional[str] = None + BYTE_START: t.Optional[str] = None + BYTE_END: t.Optional[str] = None + RAW_START: t.Optional[str] = None + RAW_END: t.Optional[str] = None + __slots__ = ( - "time_mapping", - "time_trie", "pretty", - "quote_start", - "quote_end", - "identifier_start", - "identifier_end", - "bit_start", - "bit_end", - "hex_start", - "hex_end", - "byte_start", - "byte_end", - "raw_start", - "raw_end", "identify", "normalize", - "string_escape", - "identifier_escape", "pad", - "index_offset", - "unnest_column_only", - "alias_post_tablesample", - "identifiers_can_start_with_digit", + "_indent", "normalize_functions", "unsupported_level", - "unsupported_messages", - "null_ordering", "max_unsupported", - "_indent", + "leading_comma", + "max_text_width", + "comments", + "unsupported_messages", "_escaped_quote_end", "_escaped_identifier_end", - "_leading_comma", - "_max_text_width", - "_comments", "_cache", ) def __init__( self, - time_mapping=None, - time_trie=None, - pretty=None, - quote_start=None, - quote_end=None, - identifier_start=None, - identifier_end=None, - bit_start=None, - bit_end=None, - hex_start=None, - hex_end=None, - byte_start=None, - byte_end=None, - raw_start=None, - raw_end=None, - identify=False, - normalize=False, - string_escape=None, - identifier_escape=None, - pad=2, - indent=2, - index_offset=0, - unnest_column_only=False, - alias_post_tablesample=False, - identifiers_can_start_with_digit=False, - normalize_functions="upper", - unsupported_level=ErrorLevel.WARN, - null_ordering=None, - max_unsupported=3, - leading_comma=False, - max_text_width=80, - comments=True, + pretty: t.Optional[bool] = None, + identify: str | bool = False, + normalize: bool = False, + pad: int = 2, + indent: int = 2, + normalize_functions: t.Optional[str | bool] = None, + unsupported_level: ErrorLevel = ErrorLevel.WARN, + max_unsupported: int = 3, + leading_comma: bool = False, + max_text_width: int = 80, + comments: bool = True, ): import sqlglot - self.time_mapping = time_mapping or {} - self.time_trie = time_trie self.pretty = pretty if pretty is not None else sqlglot.pretty - self.quote_start = quote_start or "'" - self.quote_end = quote_end or "'" - self.identifier_start = identifier_start or '"' - self.identifier_end = identifier_end or '"' - self.bit_start = bit_start - self.bit_end = bit_end - self.hex_start = hex_start - self.hex_end = hex_end - self.byte_start = byte_start - self.byte_end = byte_end - self.raw_start = raw_start - self.raw_end = raw_end self.identify = identify self.normalize = normalize - self.string_escape = string_escape or "'" - self.identifier_escape = identifier_escape or '"' self.pad = pad - self.index_offset = index_offset - self.unnest_column_only = unnest_column_only - self.alias_post_tablesample = alias_post_tablesample - self.identifiers_can_start_with_digit = identifiers_can_start_with_digit - self.normalize_functions = normalize_functions + self._indent = indent self.unsupported_level = unsupported_level - self.unsupported_messages = [] self.max_unsupported = max_unsupported - self.null_ordering = null_ordering - self._indent = indent - self._escaped_quote_end = self.string_escape + self.quote_end - self._escaped_identifier_end = self.identifier_escape + self.identifier_end - self._leading_comma = leading_comma - self._max_text_width = max_text_width - self._comments = comments - self._cache = None + self.leading_comma = leading_comma + self.max_text_width = max_text_width + self.comments = comments + + # This is both a Dialect property and a Generator argument, so we prioritize the latter + self.normalize_functions = ( + self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions + ) + + self.unsupported_messages: t.List[str] = [] + self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END + self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END + self._cache: t.Optional[t.Dict[int, str]] = None def generate( self, @@ -364,17 +342,19 @@ class Generator: cache: t.Optional[t.Dict[int, str]] = None, ) -> str: """ - Generates a SQL string by interpreting the given syntax tree. + Generates the SQL string corresponding to the given syntax tree. - Args - expression: the syntax tree. - cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node. + Args: + expression: The syntax tree. + cache: An optional sql string cache. This leverages the hash of an Expression + which can be slow to compute, so only use it if you set _hash on each node. - Returns - the SQL string. + Returns: + The SQL string corresponding to `expression`. """ if cache is not None: self._cache = cache + self.unsupported_messages = [] sql = self.sql(expression).strip() self._cache = None @@ -414,7 +394,11 @@ class Generator: expression: t.Optional[exp.Expression] = None, comments: t.Optional[t.List[str]] = None, ) -> str: - comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None # type: ignore + comments = ( + ((expression and expression.comments) if comments is None else comments) # type: ignore + if self.comments + else None + ) if not comments or isinstance(expression, exp.Binary): return sql @@ -454,7 +438,7 @@ class Generator: return result def normalize_func(self, name: str) -> str: - if self.normalize_functions == "upper": + if self.normalize_functions == "upper" or self.normalize_functions is True: return name.upper() if self.normalize_functions == "lower": return name.lower() @@ -522,7 +506,7 @@ class Generator: else: raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") - sql = self.maybe_comment(sql, expression) if self._comments and comment else sql + sql = self.maybe_comment(sql, expression) if self.comments and comment else sql if self._cache is not None: self._cache[expression_id] = sql @@ -770,25 +754,25 @@ class Generator: def bitstring_sql(self, expression: exp.BitString) -> str: this = self.sql(expression, "this") - if self.bit_start: - return f"{self.bit_start}{this}{self.bit_end}" + if self.BIT_START: + return f"{self.BIT_START}{this}{self.BIT_END}" return f"{int(this, 2)}" def hexstring_sql(self, expression: exp.HexString) -> str: this = self.sql(expression, "this") - if self.hex_start: - return f"{self.hex_start}{this}{self.hex_end}" + if self.HEX_START: + return f"{self.HEX_START}{this}{self.HEX_END}" return f"{int(this, 16)}" def bytestring_sql(self, expression: exp.ByteString) -> str: this = self.sql(expression, "this") - if self.byte_start: - return f"{self.byte_start}{this}{self.byte_end}" + if self.BYTE_START: + return f"{self.BYTE_START}{this}{self.BYTE_END}" return this def rawstring_sql(self, expression: exp.RawString) -> str: - if self.raw_start: - return f"{self.raw_start}{expression.name}{self.raw_end}" + if self.RAW_START: + return f"{self.RAW_START}{expression.name}{self.RAW_END}" return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\"))) def datatypesize_sql(self, expression: exp.DataTypeSize) -> str: @@ -883,24 +867,27 @@ class Generator: name = f"{expression.name} " if expression.name else "" table = self.sql(expression, "table") table = f"{self.INDEX_ON} {table} " if table else "" + using = self.sql(expression, "using") + using = f"USING {using} " if using else "" index = "INDEX " if not table else "" columns = self.expressions(expression, key="columns", flat=True) + columns = f"({columns})" if columns else "" partition_by = self.expressions(expression, key="partition_by", flat=True) partition_by = f" PARTITION BY {partition_by}" if partition_by else "" - return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}" + return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}" def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name lower = text.lower() text = lower if self.normalize and not expression.quoted else text - text = text.replace(self.identifier_end, self._escaped_identifier_end) + text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end) if ( expression.quoted or should_identify(text, self.identify) or lower in self.RESERVED_KEYWORDS - or (not self.identifiers_can_start_with_digit and text[:1].isdigit()) + or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit()) ): - text = f"{self.identifier_start}{text}{self.identifier_end}" + text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}" return text def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str: @@ -1197,7 +1184,7 @@ class Generator: def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " ) -> str: - if self.alias_post_tablesample and expression.this.alias: + if self.ALIAS_POST_TABLESAMPLE and expression.this.alias: table = expression.this.copy() table.set("alias", None) this = self.sql(table) @@ -1372,7 +1359,15 @@ class Generator: def limit_sql(self, expression: exp.Limit) -> str: this = self.sql(expression, "this") - return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}" + args = ", ".join( + sql + for sql in ( + self.sql(expression, "offset"), + self.sql(expression, "expression"), + ) + if sql + ) + return f"{this}{self.seg('LIMIT')} {args}" def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") @@ -1418,10 +1413,10 @@ class Generator: def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: - text = text.replace(self.quote_end, self._escaped_quote_end) + text = text.replace(self.QUOTE_END, self._escaped_quote_end) if self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) - text = f"{self.quote_start}{text}{self.quote_end}" + text = f"{self.QUOTE_START}{text}{self.QUOTE_END}" return text def loaddata_sql(self, expression: exp.LoadData) -> str: @@ -1463,9 +1458,9 @@ class Generator: nulls_first = expression.args.get("nulls_first") nulls_last = not nulls_first - nulls_are_large = self.null_ordering == "nulls_are_large" - nulls_are_small = self.null_ordering == "nulls_are_small" - nulls_are_last = self.null_ordering == "nulls_are_last" + nulls_are_large = self.NULL_ORDERING == "nulls_are_large" + nulls_are_small = self.NULL_ORDERING == "nulls_are_small" + nulls_are_last = self.NULL_ORDERING == "nulls_are_last" sort_order = " DESC" if desc else "" nulls_sort_change = "" @@ -1521,7 +1516,7 @@ class Generator: return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}" def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: - limit = expression.args.get("limit") + limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit") if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch): limit = exp.Limit(expression=limit.args.get("count")) @@ -1540,12 +1535,19 @@ class Generator: self.sql(expression, "having"), *self.after_having_modifiers(expression), self.sql(expression, "order"), - self.sql(expression, "offset") if fetch else self.sql(limit), - self.sql(limit) if fetch else self.sql(expression, "offset"), + *self.offset_limit_modifiers(expression, fetch, limit), *self.after_limit_modifiers(expression), sep="", ) + def offset_limit_modifiers( + self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit] + ) -> t.List[str]: + return [ + self.sql(expression, "offset") if fetch else self.sql(limit), + self.sql(limit) if fetch else self.sql(expression, "offset"), + ] + def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]: return [ self.sql(expression, "qualify"), @@ -1634,7 +1636,7 @@ class Generator: def unnest_sql(self, expression: exp.Unnest) -> str: args = self.expressions(expression, flat=True) alias = expression.args.get("alias") - if alias and self.unnest_column_only: + if alias and self.UNNEST_COLUMN_ONLY: columns = alias.columns alias = self.sql(columns[0]) if columns else "" else: @@ -1697,7 +1699,7 @@ class Generator: return f"{this} BETWEEN {low} AND {high}" def bracket_sql(self, expression: exp.Bracket) -> str: - expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset) + expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET) expressions_sql = ", ".join(self.sql(e) for e in expressions) return f"{self.sql(expression, 'this')}[{expressions_sql}]" @@ -1729,7 +1731,7 @@ class Generator: statements.append("END") - if self.pretty and self.text_width(statements) > self._max_text_width: + if self.pretty and self.text_width(statements) > self.max_text_width: return self.indent("\n".join(statements), skip_first=True, skip_last=True) return " ".join(statements) @@ -1759,10 +1761,11 @@ class Generator: else: return self.func("TRIM", expression.this, expression.expression) - def concat_sql(self, expression: exp.Concat) -> str: - if len(expression.expressions) == 1: - return self.sql(expression.expressions[0]) - return self.function_fallback_sql(expression) + def safeconcat_sql(self, expression: exp.SafeConcat) -> str: + expressions = expression.expressions + if self.STRICT_STRING_CONCAT: + expressions = (exp.cast(e, "text") for e in expressions) + return self.func("CONCAT", *expressions) def check_sql(self, expression: exp.Check) -> str: this = self.sql(expression, key="this") @@ -1785,9 +1788,7 @@ class Generator: return f"PRIMARY KEY ({expressions}){options}" def if_sql(self, expression: exp.If) -> str: - return self.case_sql( - exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) - ) + return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: modifier = expression.args.get("modifier") @@ -1798,7 +1799,6 @@ class Generator: return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}" def jsonobject_sql(self, expression: exp.JSONObject) -> str: - expressions = self.expressions(expression) null_handling = expression.args.get("null_handling") null_handling = f" {null_handling}" if null_handling else "" unique_keys = expression.args.get("unique_keys") @@ -1811,7 +1811,11 @@ class Generator: format_json = " FORMAT JSON" if expression.args.get("format_json") else "" encoding = self.sql(expression, "encoding") encoding = f" ENCODING {encoding}" if encoding else "" - return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})" + return self.func( + "JSON_OBJECT", + *expression.expressions, + suffix=f"{null_handling}{unique_keys}{return_type}{format_json}{encoding})", + ) def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str: this = self.sql(expression, "this") @@ -1930,7 +1934,7 @@ class Generator: for i, e in enumerate(expression.flatten(unnest=False)) ) - sep = "\n" if self.text_width(sqls) > self._max_text_width else " " + sep = "\n" if self.text_width(sqls) > self.max_text_width else " " return f"{sep}{op} ".join(sqls) def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str: @@ -2093,6 +2097,11 @@ class Generator: def dpipe_sql(self, expression: exp.DPipe) -> str: return self.binary(expression, "||") + def safedpipe_sql(self, expression: exp.SafeDPipe) -> str: + if self.STRICT_STRING_CONCAT: + return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten())) + return self.dpipe_sql(expression) + def div_sql(self, expression: exp.Div) -> str: return self.binary(expression, "/") @@ -2127,7 +2136,7 @@ class Generator: return self.binary(expression, "ILIKE ANY") def is_sql(self, expression: exp.Is) -> str: - if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean): + if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean): return self.sql( expression.this if expression.expression.this else exp.not_(expression.this) ) @@ -2197,12 +2206,18 @@ class Generator: return self.func(expression.sql_name(), *args) - def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str: - return f"{self.normalize_func(name)}({self.format_args(*args)})" + def func( + self, + name: str, + *args: t.Optional[exp.Expression | str], + prefix: str = "(", + suffix: str = ")", + ) -> str: + return f"{self.normalize_func(name)}{prefix}{self.format_args(*args)}{suffix}" def format_args(self, *args: t.Optional[str | exp.Expression]) -> str: arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None) - if self.pretty and self.text_width(arg_sqls) > self._max_text_width: + if self.pretty and self.text_width(arg_sqls) > self.max_text_width: return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True) return ", ".join(arg_sqls) @@ -2210,7 +2225,9 @@ class Generator: return sum(len(arg) for arg in args) def format_time(self, expression: exp.Expression) -> t.Optional[str]: - return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) + return format_time( + self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE + ) def expressions( self, @@ -2242,7 +2259,7 @@ class Generator: comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else "" if self.pretty: - if self._leading_comma: + if self.leading_comma: result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}") else: result_sqls.append( -- cgit v1.2.3