From 20739a12c39121a9e7ad3c9a2469ec5a6876199d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 3 Jun 2023 01:59:40 +0200 Subject: Merging upstream version 15.0.0. Signed-off-by: Daniel Baumann --- sqlglot/dataframe/sql/_typing.py | 18 ++++++++++++++++++ sqlglot/dataframe/sql/_typing.pyi | 16 ---------------- sqlglot/dataframe/sql/dataframe.py | 12 ++++++------ sqlglot/dataframe/sql/operations.py | 2 +- sqlglot/dataframe/sql/session.py | 20 +++++++------------- sqlglot/dataframe/sql/util.py | 2 +- 6 files changed, 33 insertions(+), 37 deletions(-) create mode 100644 sqlglot/dataframe/sql/_typing.py delete mode 100644 sqlglot/dataframe/sql/_typing.pyi (limited to 'sqlglot/dataframe') diff --git a/sqlglot/dataframe/sql/_typing.py b/sqlglot/dataframe/sql/_typing.py new file mode 100644 index 0000000..fb46026 --- /dev/null +++ b/sqlglot/dataframe/sql/_typing.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import datetime +import typing as t + +from sqlglot import expressions as exp + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.column import Column + from sqlglot.dataframe.sql.types import StructType + +ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] +ColumnOrName = t.Union[Column, str] +ColumnOrLiteral = t.Union[ + Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime +] +SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]] +OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert] diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi deleted file mode 100644 index 1682ec1..0000000 --- a/sqlglot/dataframe/sql/_typing.pyi +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -import datetime -import typing as t - -from sqlglot import expressions as exp - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.column import Column - from sqlglot.dataframe.sql.types import StructType - -ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] -ColumnOrName = t.Union[Column, str] -ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] -SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]] -OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert] diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index f3a6f6f..3fc9232 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -127,7 +127,7 @@ class DataFrame: sequence_id: t.Optional[str] = None, **kwargs, ) -> t.Tuple[exp.CTE, str]: - name = self.spark._random_name + name = self._create_hash_from_expression(expression) expression_to_cte = expression.copy() expression_to_cte.set("with", None) cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0] @@ -263,7 +263,7 @@ class DataFrame: return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions] @classmethod - def _create_hash_from_expression(cls, expression: exp.Select): + def _create_hash_from_expression(cls, expression: exp.Expression) -> str: value = expression.sql(dialect="spark").encode("utf-8") return f"t{zlib.crc32(value)}"[:6] @@ -299,7 +299,7 @@ class DataFrame: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = optimize_func(select_expression, identify="always") + select_expression = t.cast(exp.Select, optimize_func(select_expression)) select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: @@ -570,9 +570,9 @@ class DataFrame: r_expressions.append(l_column) r_columns_unused.remove(l_column) else: - r_expressions.append(exp.alias_(exp.Null(), l_column)) + r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False)) for r_column in r_columns_unused: - l_expressions.append(exp.alias_(exp.Null(), r_column)) + l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False)) r_expressions.append(r_column) r_df = ( other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) @@ -761,7 +761,7 @@ class DataFrame: raise ValueError("Tried to rename a column that doesn't exist") for existing_column in existing_columns: if isinstance(existing_column, exp.Column): - existing_column.replace(exp.alias_(existing_column.copy(), new)) + existing_column.replace(exp.alias_(existing_column, new)) else: existing_column.set("alias", exp.to_identifier(new)) return self.copy(expression=expression) diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py index d51335c..e4c106b 100644 --- a/sqlglot/dataframe/sql/operations.py +++ b/sqlglot/dataframe/sql/operations.py @@ -41,7 +41,7 @@ def operation(op: Operation): self.last_op = Operation.NO_OP last_op = self.last_op new_op = op if op != Operation.NO_OP else last_op - if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT): + if new_op < last_op or (last_op == new_op == Operation.SELECT): self = self._convert_leaf_to_cte() df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs) df.last_op = new_op # type: ignore diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index af589b0..b883359 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -87,15 +87,13 @@ class SparkSession: select_kwargs = { "expressions": sel_columns, "from": exp.From( - expressions=[ - exp.Values( - expressions=data_expressions, - alias=exp.TableAlias( - this=exp.to_identifier(self._auto_incrementing_name), - columns=[exp.to_identifier(col_name) for col_name in column_mapping], - ), + this=exp.Values( + expressions=data_expressions, + alias=exp.TableAlias( + this=exp.to_identifier(self._auto_incrementing_name), + columns=[exp.to_identifier(col_name) for col_name in column_mapping], ), - ], + ), ), } @@ -127,10 +125,6 @@ class SparkSession: self.incrementing_id += 1 return name - @property - def _random_name(self) -> str: - return "r" + uuid.uuid4().hex - @property def _random_branch_id(self) -> str: id = self._random_id @@ -145,7 +139,7 @@ class SparkSession: @property def _random_id(self) -> str: - id = self._random_name + id = "r" + uuid.uuid4().hex self.known_ids.add(id) return id diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py index 575d18a..4b9fbb1 100644 --- a/sqlglot/dataframe/sql/util.py +++ b/sqlglot/dataframe/sql/util.py @@ -27,6 +27,6 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.T if not expression.args.get("joins"): return [] - left_table = expression.args["from"].args["expressions"][0] + left_table = expression.args["from"].this other_tables = [join.this for join in expression.args["joins"]] return [left_table] + other_tables -- cgit v1.2.3