diff options
Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r-- | sqlglot/dataframe/sql/_typing.py (renamed from sqlglot/dataframe/sql/_typing.pyi) | 4 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/dataframe.py | 12 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/operations.py | 2 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/session.py | 20 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/util.py | 2 |
5 files changed, 18 insertions, 22 deletions
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.py index 1682ec1..fb46026 100644 --- a/sqlglot/dataframe/sql/_typing.pyi +++ b/sqlglot/dataframe/sql/_typing.py @@ -11,6 +11,8 @@ if t.TYPE_CHECKING: ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] ColumnOrName = t.Union[Column, str] -ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] +ColumnOrLiteral = t.Union[ + Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime +] SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]] OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert] diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index f3a6f6f..3fc9232 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -127,7 +127,7 @@ class DataFrame: sequence_id: t.Optional[str] = None, **kwargs, ) -> t.Tuple[exp.CTE, str]: - name = self.spark._random_name + name = self._create_hash_from_expression(expression) expression_to_cte = expression.copy() expression_to_cte.set("with", None) cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0] @@ -263,7 +263,7 @@ class DataFrame: return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions] @classmethod - def _create_hash_from_expression(cls, expression: exp.Select): + def _create_hash_from_expression(cls, expression: exp.Expression) -> str: value = expression.sql(dialect="spark").encode("utf-8") return f"t{zlib.crc32(value)}"[:6] @@ -299,7 +299,7 @@ class DataFrame: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = optimize_func(select_expression, identify="always") + select_expression = t.cast(exp.Select, optimize_func(select_expression)) select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: @@ -570,9 +570,9 @@ class DataFrame: r_expressions.append(l_column) r_columns_unused.remove(l_column) else: - r_expressions.append(exp.alias_(exp.Null(), l_column)) + r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False)) for r_column in r_columns_unused: - l_expressions.append(exp.alias_(exp.Null(), r_column)) + l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False)) r_expressions.append(r_column) r_df = ( other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) @@ -761,7 +761,7 @@ class DataFrame: raise ValueError("Tried to rename a column that doesn't exist") for existing_column in existing_columns: if isinstance(existing_column, exp.Column): - existing_column.replace(exp.alias_(existing_column.copy(), new)) + existing_column.replace(exp.alias_(existing_column, new)) else: existing_column.set("alias", exp.to_identifier(new)) return self.copy(expression=expression) diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py index d51335c..e4c106b 100644 --- a/sqlglot/dataframe/sql/operations.py +++ b/sqlglot/dataframe/sql/operations.py @@ -41,7 +41,7 @@ def operation(op: Operation): self.last_op = Operation.NO_OP last_op = self.last_op new_op = op if op != Operation.NO_OP else last_op - if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT): + if new_op < last_op or (last_op == new_op == Operation.SELECT): self = self._convert_leaf_to_cte() df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs) df.last_op = new_op # type: ignore diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index af589b0..b883359 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -87,15 +87,13 @@ class SparkSession: select_kwargs = { "expressions": sel_columns, "from": exp.From( - expressions=[ - exp.Values( - expressions=data_expressions, - alias=exp.TableAlias( - this=exp.to_identifier(self._auto_incrementing_name), - columns=[exp.to_identifier(col_name) for col_name in column_mapping], - ), + this=exp.Values( + expressions=data_expressions, + alias=exp.TableAlias( + this=exp.to_identifier(self._auto_incrementing_name), + columns=[exp.to_identifier(col_name) for col_name in column_mapping], ), - ], + ), ), } @@ -128,10 +126,6 @@ class SparkSession: return name @property - def _random_name(self) -> str: - return "r" + uuid.uuid4().hex - - @property def _random_branch_id(self) -> str: id = self._random_id self.known_branch_ids.add(id) @@ -145,7 +139,7 @@ class SparkSession: @property def _random_id(self) -> str: - id = self._random_name + id = "r" + uuid.uuid4().hex self.known_ids.add(id) return id diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py index 575d18a..4b9fbb1 100644 --- a/sqlglot/dataframe/sql/util.py +++ b/sqlglot/dataframe/sql/util.py @@ -27,6 +27,6 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.T if not expression.args.get("joins"): return [] - left_table = expression.args["from"].args["expressions"][0] + left_table = expression.args["from"].this other_tables = [join.this for join in expression.args["joins"]] return [left_table] + other_tables |