From 3d48060515ba25b4c49d975a520ee0682327d1b7 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 16 Feb 2024 06:45:52 +0100 Subject: Merging upstream version 21.1.1. Signed-off-by: Daniel Baumann --- sqlglot/transforms.py | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) (limited to 'sqlglot/transforms.py') diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index f13569f..4777609 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -213,6 +213,19 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp is_posexplode = isinstance(explode, exp.Posexplode) explode_arg = explode.this + if isinstance(explode, exp.ExplodeOuter): + bracket = explode_arg[0] + bracket.set("safe", True) + bracket.set("offset", True) + explode_arg = exp.func( + "IF", + exp.func( + "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array()) + ).eq(0), + exp.array(bracket, copy=False), + explode_arg, + ) + # This ensures that we won't use [POS]EXPLODE's argument as a new selection if isinstance(explode_arg, exp.Column): taken_select_names.add(explode_arg.output_name) @@ -466,6 +479,87 @@ def unqualify_columns(expression: exp.Expression) -> exp.Expression: return expression +def remove_unique_constraints(expression: exp.Expression) -> exp.Expression: + assert isinstance(expression, exp.Create) + for constraint in expression.find_all(exp.UniqueColumnConstraint): + if constraint.parent: + constraint.parent.pop() + + return expression + + +def ctas_with_tmp_tables_to_create_tmp_view( + expression: exp.Expression, + tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e, +) -> exp.Expression: + assert isinstance(expression, exp.Create) + properties = expression.args.get("properties") + temporary = any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ) + + # CTAS with temp tables map to CREATE TEMPORARY VIEW + if expression.kind == "TABLE" and temporary: + if expression.expression: + return exp.Create( + kind="TEMPORARY VIEW", + this=expression.this, + expression=expression.expression, + ) + return tmp_storage_provider(expression) + + return expression + + +def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression: + """ + In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the + PARTITIONED BY value is an array of column names, they are transformed into a schema. + The corresponding columns are removed from the create statement. + """ + assert isinstance(expression, exp.Create) + has_schema = isinstance(expression.this, exp.Schema) + is_partitionable = expression.kind in {"TABLE", "VIEW"} + + if has_schema and is_partitionable: + prop = expression.find(exp.PartitionedByProperty) + if prop and prop.this and not isinstance(prop.this, exp.Schema): + schema = expression.this + columns = {v.name.upper() for v in prop.this.expressions} + partitions = [col for col in schema.expressions if col.name.upper() in columns] + schema.set("expressions", [e for e in schema.expressions if e not in partitions]) + prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) + expression.set("this", schema) + + return expression + + +def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression: + """ + Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. + + Currently, SQLGlot uses the DATASOURCE format for Spark 3. + """ + assert isinstance(expression, exp.Create) + prop = expression.find(exp.PartitionedByProperty) + if ( + prop + and prop.this + and isinstance(prop.this, exp.Schema) + and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions) + ): + prop_this = exp.Tuple( + expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] + ) + schema = expression.this + for e in prop.this.expressions: + schema.append("expressions", e) + prop.set("this", prop_this) + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: -- cgit v1.2.3