summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-16 09:41:18 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-16 09:41:18 +0000
commit67578a7602a5be7eb51f324086c8d49bcf8b7498 (patch)
tree0b7515c922d1c383cea24af5175379cfc8edfd15 /sqlglot/generator.py
parentReleasing debian version 15.2.0-1. (diff)
downloadsqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.tar.xz
sqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.zip
Merging upstream version 16.2.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py383
1 files changed, 200 insertions, 183 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 97cbe15..d3cf9f0 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -14,47 +14,32 @@ logger = logging.getLogger("sqlglot")
class Generator:
"""
- Generator interprets the given syntax tree and produces a SQL string as an output.
+ Generator converts a given syntax tree to the corresponding SQL string.
Args:
- time_mapping (dict): the dictionary of custom time mappings in which the key
- represents a python time format and the output the target time format
- time_trie (trie): a trie of the time_mapping keys
- pretty (bool): if set to True the returned string will be formatted. Default: False.
- quote_start (str): specifies which starting character to use to delimit quotes. Default: '.
- quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
- identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
- identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
- bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
- bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
- hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
- hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
- byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
- byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
- raw_start (str): specifies which starting character to use to delimit raw literals. Default: None.
- raw_end (str): specifies which ending character to use to delimit raw literals. Default: None.
- identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always.
- normalize (bool): if set to True all identifiers will lower cased
- string_escape (str): specifies a string escape character. Default: '.
- identifier_escape (str): specifies an identifier escape character. Default: ".
- pad (int): determines padding in a formatted string. Default: 2.
- indent (int): determines the size of indentation in a formatted string. Default: 4.
- unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
- normalize_functions (str): normalize function names, "upper", "lower", or None
- Default: "upper"
- alias_post_tablesample (bool): if the table alias comes after tablesample
- Default: False
- identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit
- Default: False
- unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
- unsupported expressions. Default ErrorLevel.WARN.
- null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
- Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
- Default: "nulls_are_small"
- max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
+ pretty: Whether or not to format the produced SQL string.
+ Default: False.
+ identify: Determines when an identifier should be quoted. Possible values are:
+ False (default): Never quote, except in cases where it's mandatory by the dialect.
+ True or 'always': Always quote.
+ 'safe': Only quote identifiers that are case insensitive.
+ normalize: Whether or not to normalize identifiers to lowercase.
+ Default: False.
+ pad: Determines the pad size in a formatted string.
+ Default: 2.
+ indent: Determines the indentation size in a formatted string.
+ Default: 2.
+ normalize_functions: Whether or not to normalize all function names. Possible values are:
+ "upper" or True (default): Convert names to uppercase.
+ "lower": Convert names to lowercase.
+ False: Disables function name normalization.
+ unsupported_level: Determines the generator's behavior when it encounters unsupported expressions.
+ Default ErrorLevel.WARN.
+ max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
- leading_comma (bool): if the the comma is leading or trailing in select statements
+ leading_comma: Determines whether or not the comma is leading or trailing in select expressions.
+ This is only relevant when generating in pretty mode.
Default: False
max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
The default is on the smaller end because the length only represents a segment and not the true
@@ -86,6 +71,7 @@ class Generator:
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
+ exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.StabilityProperty: lambda self, e: e.name,
exp.VolatileProperty: lambda self, e: "VOLATILE",
@@ -138,15 +124,24 @@ class Generator:
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
- # Whether a table is allowed to be renamed with a db
+ # Whether or not a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
# The separator for grouping sets and rollups
GROUPINGS_SEP = ","
- # The string used for creating index on a table
+ # The string used for creating an index on a table
INDEX_ON = "ON"
+ # Whether or not join hints should be generated
+ JOIN_HINTS = True
+
+ # Whether or not table hints should be generated
+ TABLE_HINTS = True
+
+ # Whether or not comparing against booleans (e.g. x IS TRUE) is supported
+ IS_BOOL_ALLOWED = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -228,6 +223,7 @@ class Generator:
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
+ exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
@@ -235,128 +231,110 @@ class Generator:
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
- JOIN_HINTS = True
- TABLE_HINTS = True
- IS_BOOL = True
-
+ # Keywords that can't be used as unquoted identifier names
RESERVED_KEYWORDS: t.Set[str] = set()
- WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
- UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren)
+
+ # Expressions whose comments are separated from them for better formatting
+ WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
+ exp.Select,
+ exp.From,
+ exp.Where,
+ exp.With,
+ )
+
+ # Expressions that can remain unwrapped when appearing in the context of an INTERVAL
+ UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = (
+ exp.Column,
+ exp.Literal,
+ exp.Neg,
+ exp.Paren,
+ )
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
+ # Autofilled
+ INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
+ INVERSE_TIME_TRIE: t.Dict = {}
+ INDEX_OFFSET = 0
+ UNNEST_COLUMN_ONLY = False
+ ALIAS_POST_TABLESAMPLE = False
+ IDENTIFIERS_CAN_START_WITH_DIGIT = False
+ STRICT_STRING_CONCAT = False
+ NORMALIZE_FUNCTIONS: bool | str = "upper"
+ NULL_ORDERING = "nulls_are_small"
+
+ # Delimiters for quotes, identifiers and the corresponding escape characters
+ QUOTE_START = "'"
+ QUOTE_END = "'"
+ IDENTIFIER_START = '"'
+ IDENTIFIER_END = '"'
+ STRING_ESCAPE = "'"
+ IDENTIFIER_ESCAPE = '"'
+
+ # Delimiters for bit, hex, byte and raw literals
+ BIT_START: t.Optional[str] = None
+ BIT_END: t.Optional[str] = None
+ HEX_START: t.Optional[str] = None
+ HEX_END: t.Optional[str] = None
+ BYTE_START: t.Optional[str] = None
+ BYTE_END: t.Optional[str] = None
+ RAW_START: t.Optional[str] = None
+ RAW_END: t.Optional[str] = None
+
__slots__ = (
- "time_mapping",
- "time_trie",
"pretty",
- "quote_start",
- "quote_end",
- "identifier_start",
- "identifier_end",
- "bit_start",
- "bit_end",
- "hex_start",
- "hex_end",
- "byte_start",
- "byte_end",
- "raw_start",
- "raw_end",
"identify",
"normalize",
- "string_escape",
- "identifier_escape",
"pad",
- "index_offset",
- "unnest_column_only",
- "alias_post_tablesample",
- "identifiers_can_start_with_digit",
+ "_indent",
"normalize_functions",
"unsupported_level",
- "unsupported_messages",
- "null_ordering",
"max_unsupported",
- "_indent",
+ "leading_comma",
+ "max_text_width",
+ "comments",
+ "unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
- "_leading_comma",
- "_max_text_width",
- "_comments",
"_cache",
)
def __init__(
self,
- time_mapping=None,
- time_trie=None,
- pretty=None,
- quote_start=None,
- quote_end=None,
- identifier_start=None,
- identifier_end=None,
- bit_start=None,
- bit_end=None,
- hex_start=None,
- hex_end=None,
- byte_start=None,
- byte_end=None,
- raw_start=None,
- raw_end=None,
- identify=False,
- normalize=False,
- string_escape=None,
- identifier_escape=None,
- pad=2,
- indent=2,
- index_offset=0,
- unnest_column_only=False,
- alias_post_tablesample=False,
- identifiers_can_start_with_digit=False,
- normalize_functions="upper",
- unsupported_level=ErrorLevel.WARN,
- null_ordering=None,
- max_unsupported=3,
- leading_comma=False,
- max_text_width=80,
- comments=True,
+ pretty: t.Optional[bool] = None,
+ identify: str | bool = False,
+ normalize: bool = False,
+ pad: int = 2,
+ indent: int = 2,
+ normalize_functions: t.Optional[str | bool] = None,
+ unsupported_level: ErrorLevel = ErrorLevel.WARN,
+ max_unsupported: int = 3,
+ leading_comma: bool = False,
+ max_text_width: int = 80,
+ comments: bool = True,
):
import sqlglot
- self.time_mapping = time_mapping or {}
- self.time_trie = time_trie
self.pretty = pretty if pretty is not None else sqlglot.pretty
- self.quote_start = quote_start or "'"
- self.quote_end = quote_end or "'"
- self.identifier_start = identifier_start or '"'
- self.identifier_end = identifier_end or '"'
- self.bit_start = bit_start
- self.bit_end = bit_end
- self.hex_start = hex_start
- self.hex_end = hex_end
- self.byte_start = byte_start
- self.byte_end = byte_end
- self.raw_start = raw_start
- self.raw_end = raw_end
self.identify = identify
self.normalize = normalize
- self.string_escape = string_escape or "'"
- self.identifier_escape = identifier_escape or '"'
self.pad = pad
- self.index_offset = index_offset
- self.unnest_column_only = unnest_column_only
- self.alias_post_tablesample = alias_post_tablesample
- self.identifiers_can_start_with_digit = identifiers_can_start_with_digit
- self.normalize_functions = normalize_functions
+ self._indent = indent
self.unsupported_level = unsupported_level
- self.unsupported_messages = []
self.max_unsupported = max_unsupported
- self.null_ordering = null_ordering
- self._indent = indent
- self._escaped_quote_end = self.string_escape + self.quote_end
- self._escaped_identifier_end = self.identifier_escape + self.identifier_end
- self._leading_comma = leading_comma
- self._max_text_width = max_text_width
- self._comments = comments
- self._cache = None
+ self.leading_comma = leading_comma
+ self.max_text_width = max_text_width
+ self.comments = comments
+
+ # This is both a Dialect property and a Generator argument, so we prioritize the latter
+ self.normalize_functions = (
+ self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
+ )
+
+ self.unsupported_messages: t.List[str] = []
+ self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END
+ self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END
+ self._cache: t.Optional[t.Dict[int, str]] = None
def generate(
self,
@@ -364,17 +342,19 @@ class Generator:
cache: t.Optional[t.Dict[int, str]] = None,
) -> str:
"""
- Generates a SQL string by interpreting the given syntax tree.
+ Generates the SQL string corresponding to the given syntax tree.
- Args
- expression: the syntax tree.
- cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node.
+ Args:
+ expression: The syntax tree.
+ cache: An optional sql string cache. This leverages the hash of an Expression
+ which can be slow to compute, so only use it if you set _hash on each node.
- Returns
- the SQL string.
+ Returns:
+ The SQL string corresponding to `expression`.
"""
if cache is not None:
self._cache = cache
+
self.unsupported_messages = []
sql = self.sql(expression).strip()
self._cache = None
@@ -414,7 +394,11 @@ class Generator:
expression: t.Optional[exp.Expression] = None,
comments: t.Optional[t.List[str]] = None,
) -> str:
- comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None # type: ignore
+ comments = (
+ ((expression and expression.comments) if comments is None else comments) # type: ignore
+ if self.comments
+ else None
+ )
if not comments or isinstance(expression, exp.Binary):
return sql
@@ -454,7 +438,7 @@ class Generator:
return result
def normalize_func(self, name: str) -> str:
- if self.normalize_functions == "upper":
+ if self.normalize_functions == "upper" or self.normalize_functions is True:
return name.upper()
if self.normalize_functions == "lower":
return name.lower()
@@ -522,7 +506,7 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
- sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
+ sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
if self._cache is not None:
self._cache[expression_id] = sql
@@ -770,25 +754,25 @@ class Generator:
def bitstring_sql(self, expression: exp.BitString) -> str:
this = self.sql(expression, "this")
- if self.bit_start:
- return f"{self.bit_start}{this}{self.bit_end}"
+ if self.BIT_START:
+ return f"{self.BIT_START}{this}{self.BIT_END}"
return f"{int(this, 2)}"
def hexstring_sql(self, expression: exp.HexString) -> str:
this = self.sql(expression, "this")
- if self.hex_start:
- return f"{self.hex_start}{this}{self.hex_end}"
+ if self.HEX_START:
+ return f"{self.HEX_START}{this}{self.HEX_END}"
return f"{int(this, 16)}"
def bytestring_sql(self, expression: exp.ByteString) -> str:
this = self.sql(expression, "this")
- if self.byte_start:
- return f"{self.byte_start}{this}{self.byte_end}"
+ if self.BYTE_START:
+ return f"{self.BYTE_START}{this}{self.BYTE_END}"
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
- if self.raw_start:
- return f"{self.raw_start}{expression.name}{self.raw_end}"
+ if self.RAW_START:
+ return f"{self.RAW_START}{expression.name}{self.RAW_END}"
return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
@@ -883,24 +867,27 @@ class Generator:
name = f"{expression.name} " if expression.name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table} " if table else ""
+ using = self.sql(expression, "using")
+ using = f"USING {using} " if using else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
+ columns = f"({columns})" if columns else ""
partition_by = self.expressions(expression, key="partition_by", flat=True)
partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
- return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}"
+ return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
lower = text.lower()
text = lower if self.normalize and not expression.quoted else text
- text = text.replace(self.identifier_end, self._escaped_identifier_end)
+ text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
if (
expression.quoted
or should_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
- or (not self.identifiers_can_start_with_digit and text[:1].isdigit())
+ or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
):
- text = f"{self.identifier_start}{text}{self.identifier_end}"
+ text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}"
return text
def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
@@ -1197,7 +1184,7 @@ class Generator:
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
- if self.alias_post_tablesample and expression.this.alias:
+ if self.ALIAS_POST_TABLESAMPLE and expression.this.alias:
table = expression.this.copy()
table.set("alias", None)
this = self.sql(table)
@@ -1372,7 +1359,15 @@ class Generator:
def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
- return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}"
+ args = ", ".join(
+ sql
+ for sql in (
+ self.sql(expression, "offset"),
+ self.sql(expression, "expression"),
+ )
+ if sql
+ )
+ return f"{this}{self.seg('LIMIT')} {args}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
@@ -1418,10 +1413,10 @@ class Generator:
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
- text = text.replace(self.quote_end, self._escaped_quote_end)
+ text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
- text = f"{self.quote_start}{text}{self.quote_end}"
+ text = f"{self.QUOTE_START}{text}{self.QUOTE_END}"
return text
def loaddata_sql(self, expression: exp.LoadData) -> str:
@@ -1463,9 +1458,9 @@ class Generator:
nulls_first = expression.args.get("nulls_first")
nulls_last = not nulls_first
- nulls_are_large = self.null_ordering == "nulls_are_large"
- nulls_are_small = self.null_ordering == "nulls_are_small"
- nulls_are_last = self.null_ordering == "nulls_are_last"
+ nulls_are_large = self.NULL_ORDERING == "nulls_are_large"
+ nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
+ nulls_are_last = self.NULL_ORDERING == "nulls_are_last"
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
@@ -1521,7 +1516,7 @@ class Generator:
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
- limit = expression.args.get("limit")
+ limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
limit = exp.Limit(expression=limit.args.get("count"))
@@ -1540,12 +1535,19 @@ class Generator:
self.sql(expression, "having"),
*self.after_having_modifiers(expression),
self.sql(expression, "order"),
- self.sql(expression, "offset") if fetch else self.sql(limit),
- self.sql(limit) if fetch else self.sql(expression, "offset"),
+ *self.offset_limit_modifiers(expression, fetch, limit),
*self.after_limit_modifiers(expression),
sep="",
)
+ def offset_limit_modifiers(
+ self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
+ ) -> t.List[str]:
+ return [
+ self.sql(expression, "offset") if fetch else self.sql(limit),
+ self.sql(limit) if fetch else self.sql(expression, "offset"),
+ ]
+
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return [
self.sql(expression, "qualify"),
@@ -1634,7 +1636,7 @@ class Generator:
def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
alias = expression.args.get("alias")
- if alias and self.unnest_column_only:
+ if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
alias = self.sql(columns[0]) if columns else ""
else:
@@ -1697,7 +1699,7 @@ class Generator:
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
- expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset)
+ expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET)
expressions_sql = ", ".join(self.sql(e) for e in expressions)
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
@@ -1729,7 +1731,7 @@ class Generator:
statements.append("END")
- if self.pretty and self.text_width(statements) > self._max_text_width:
+ if self.pretty and self.text_width(statements) > self.max_text_width:
return self.indent("\n".join(statements), skip_first=True, skip_last=True)
return " ".join(statements)
@@ -1759,10 +1761,11 @@ class Generator:
else:
return self.func("TRIM", expression.this, expression.expression)
- def concat_sql(self, expression: exp.Concat) -> str:
- if len(expression.expressions) == 1:
- return self.sql(expression.expressions[0])
- return self.function_fallback_sql(expression)
+ def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
+ expressions = expression.expressions
+ if self.STRICT_STRING_CONCAT:
+ expressions = (exp.cast(e, "text") for e in expressions)
+ return self.func("CONCAT", *expressions)
def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
@@ -1785,9 +1788,7 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
- return self.case_sql(
- exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
- )
+ return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
modifier = expression.args.get("modifier")
@@ -1798,7 +1799,6 @@ class Generator:
return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
def jsonobject_sql(self, expression: exp.JSONObject) -> str:
- expressions = self.expressions(expression)
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
unique_keys = expression.args.get("unique_keys")
@@ -1811,7 +1811,11 @@ class Generator:
format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
encoding = self.sql(expression, "encoding")
encoding = f" ENCODING {encoding}" if encoding else ""
- return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
+ return self.func(
+ "JSON_OBJECT",
+ *expression.expressions,
+ suffix=f"{null_handling}{unique_keys}{return_type}{format_json}{encoding})",
+ )
def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
this = self.sql(expression, "this")
@@ -1930,7 +1934,7 @@ class Generator:
for i, e in enumerate(expression.flatten(unnest=False))
)
- sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
+ sep = "\n" if self.text_width(sqls) > self.max_text_width else " "
return f"{sep}{op} ".join(sqls)
def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
@@ -2093,6 +2097,11 @@ class Generator:
def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.binary(expression, "||")
+ def safedpipe_sql(self, expression: exp.SafeDPipe) -> str:
+ if self.STRICT_STRING_CONCAT:
+ return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
+ return self.dpipe_sql(expression)
+
def div_sql(self, expression: exp.Div) -> str:
return self.binary(expression, "/")
@@ -2127,7 +2136,7 @@ class Generator:
return self.binary(expression, "ILIKE ANY")
def is_sql(self, expression: exp.Is) -> str:
- if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean):
+ if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean):
return self.sql(
expression.this if expression.expression.this else exp.not_(expression.this)
)
@@ -2197,12 +2206,18 @@ class Generator:
return self.func(expression.sql_name(), *args)
- def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str:
- return f"{self.normalize_func(name)}({self.format_args(*args)})"
+ def func(
+ self,
+ name: str,
+ *args: t.Optional[exp.Expression | str],
+ prefix: str = "(",
+ suffix: str = ")",
+ ) -> str:
+ return f"{self.normalize_func(name)}{prefix}{self.format_args(*args)}{suffix}"
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
- if self.pretty and self.text_width(arg_sqls) > self._max_text_width:
+ if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)
@@ -2210,7 +2225,9 @@ class Generator:
return sum(len(arg) for arg in args)
def format_time(self, expression: exp.Expression) -> t.Optional[str]:
- return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
+ return format_time(
+ self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE
+ )
def expressions(
self,
@@ -2242,7 +2259,7 @@ class Generator:
comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty:
- if self._leading_comma:
+ if self.leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
else:
result_sqls.append(