diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 79 |
1 files changed, 67 insertions, 12 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 857eff1..40ba88e 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -164,6 +164,11 @@ class Generator: # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") + # Whether or not VALUES statements can be used as derived tables. + # MySQL 5 and Redshift do not allow this, so when False, it will convert + # SELECT * VALUES into SELECT UNION + VALUES_AS_TABLE = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -260,8 +265,9 @@ class Generator: # Expressions whose comments are separated from them for better formatting WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( - exp.Select, + exp.Drop, exp.From, + exp.Select, exp.Where, exp.With, ) @@ -818,7 +824,11 @@ class Generator: def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this - type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) + type_sql = ( + self.TYPE_MAPPING.get(type_value, type_value.value) + if isinstance(type_value, exp.DataType.Type) + else type_value + ) nested = "" interior = self.expressions(expression, flat=True) values = "" @@ -1307,15 +1317,45 @@ class Generator: return self.prepend_ctes(expression, sql) def values_sql(self, expression: exp.Values) -> str: - args = self.expressions(expression) - alias = self.sql(expression, "alias") - values = f"VALUES{self.seg('')}{args}" - values = ( - f"({values})" - if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From)) - else values - ) - return f"{values} AS {alias}" if alias else values + # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example + if self.VALUES_AS_TABLE or not expression.find_ancestor(exp.From, exp.Join): + args = self.expressions(expression) + alias = self.sql(expression, "alias") + values = f"VALUES{self.seg('')}{args}" + values = ( + f"({values})" + if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From)) + else values + ) + return f"{values} AS {alias}" if alias else values + + # Converts `VALUES...` expression into a series of select unions. + # Note: If you have a lot of unions then this will result in a large number of recursive statements to + # evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be + # very slow. + expression = expression.copy() + column_names = expression.alias and expression.args["alias"].columns + + selects = [] + + for i, tup in enumerate(expression.expressions): + row = tup.expressions + + if i == 0 and column_names: + row = [ + exp.alias_(value, column_name) for value, column_name in zip(row, column_names) + ] + + selects.append(exp.Select(expressions=row)) + + subquery_expression: exp.Select | exp.Union = selects[0] + if len(selects) > 1: + for select in selects[1:]: + subquery_expression = exp.union( + subquery_expression, select, distinct=False, copy=False + ) + + return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False)) def var_sql(self, expression: exp.Var) -> str: return self.sql(expression, "this") @@ -2043,7 +2083,7 @@ class Generator: def and_sql(self, expression: exp.And) -> str: return self.connector_sql(expression, "AND") - def xor_sql(self, expression: exp.And) -> str: + def xor_sql(self, expression: exp.Xor) -> str: return self.connector_sql(expression, "XOR") def connector_sql(self, expression: exp.Connector, op: str) -> str: @@ -2507,6 +2547,21 @@ class Generator: return self.func("ANY_VALUE", this) + def querytransform_sql(self, expression: exp.QueryTransform) -> str: + transform = self.func("TRANSFORM", *expression.expressions) + row_format_before = self.sql(expression, "row_format_before") + row_format_before = f" {row_format_before}" if row_format_before else "" + record_writer = self.sql(expression, "record_writer") + record_writer = f" RECORDWRITER {record_writer}" if record_writer else "" + using = f" USING {self.sql(expression, 'command_script')}" + schema = self.sql(expression, "schema") + schema = f" AS {schema}" if schema else "" + row_format_after = self.sql(expression, "row_format_after") + row_format_after = f" {row_format_after}" if row_format_after else "" + record_reader = self.sql(expression, "record_reader") + record_reader = f" RECORDREADER {record_reader}" if record_reader else "" + return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}" + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None |