Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r--  sqlglot/dataframe/sql/_typing.py (renamed from sqlglot/dataframe/sql/_typing.pyi) |  4
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py                                                | 12
-rw-r--r--  sqlglot/dataframe/sql/operations.py                                               |  2
-rw-r--r--  sqlglot/dataframe/sql/session.py                                                  | 20
-rw-r--r--  sqlglot/dataframe/sql/util.py                                                     |  2
5 files changed, 18 insertions(+), 22 deletions(-)
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.py
index 1682ec1..fb46026 100644
--- a/sqlglot/dataframe/sql/_typing.pyi
+++ b/sqlglot/dataframe/sql/_typing.py
@@ -11,6 +11,8 @@ if t.TYPE_CHECKING:
ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
ColumnOrName = t.Union[Column, str]
-ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ColumnOrLiteral = t.Union[
+ Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime
+]
SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index f3a6f6f..3fc9232 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -127,7 +127,7 @@ class DataFrame:
sequence_id: t.Optional[str] = None,
**kwargs,
) -> t.Tuple[exp.CTE, str]:
- name = self.spark._random_name
+ name = self._create_hash_from_expression(expression)
expression_to_cte = expression.copy()
expression_to_cte.set("with", None)
cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
@@ -263,7 +263,7 @@ class DataFrame:
return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
@classmethod
- def _create_hash_from_expression(cls, expression: exp.Select):
+ def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
value = expression.sql(dialect="spark").encode("utf-8")
return f"t{zlib.crc32(value)}"[:6]
@@ -299,7 +299,7 @@ class DataFrame:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
- select_expression = optimize_func(select_expression, identify="always")
+ select_expression = t.cast(exp.Select, optimize_func(select_expression))
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
@@ -570,9 +570,9 @@ class DataFrame:
r_expressions.append(l_column)
r_columns_unused.remove(l_column)
else:
- r_expressions.append(exp.alias_(exp.Null(), l_column))
+ r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
for r_column in r_columns_unused:
- l_expressions.append(exp.alias_(exp.Null(), r_column))
+ l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
r_expressions.append(r_column)
r_df = (
other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
@@ -761,7 +761,7 @@ class DataFrame:
raise ValueError("Tried to rename a column that doesn't exist")
for existing_column in existing_columns:
if isinstance(existing_column, exp.Column):
- existing_column.replace(exp.alias_(existing_column.copy(), new))
+ existing_column.replace(exp.alias_(existing_column, new))
else:
existing_column.set("alias", exp.to_identifier(new))
return self.copy(expression=expression)
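
Note: the CTE-naming change above swaps the uuid-based _random_name for a content-derived hash, so the same expression always yields the same CTE name. A minimal sketch of that scheme using only the standard library (the real method hashes expression.sql(dialect="spark") rather than a raw string):

import zlib

def hash_name(sql: str) -> str:
    # Deterministic short id: identical SQL text maps to the same name,
    # unlike the removed uuid-based _random_name.
    return f"t{zlib.crc32(sql.encode('utf-8'))}"[:6]

print(hash_name("SELECT a FROM t"))  # stable across runs
print(hash_name("SELECT a FROM t"))  # same value again
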
diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py
index d51335c..e4c106b 100644
--- a/sqlglot/dataframe/sql/operations.py
+++ b/sqlglot/dataframe/sql/operations.py
@@ -41,7 +41,7 @@ def operation(op: Operation):
self.last_op = Operation.NO_OP
last_op = self.last_op
new_op = op if op != Operation.NO_OP else last_op
- if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT):
+ if new_op < last_op or (last_op == new_op == Operation.SELECT):
self = self._convert_leaf_to_cte()
df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
df.last_op = new_op # type: ignore
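
Note: the operations.py tweak is purely stylistic; Python chains comparisons, so the new form is equivalent to the explicit `and`. A quick sanity check, with Operation stubbed as an IntEnum (values here are illustrative, only equality/ordering matter):

from enum import IntEnum

class Operation(IntEnum):  # stand-in for sqlglot's Operation enum
    NO_OP = -1
    SELECT = 3

last_op = new_op = Operation.SELECT
assert (last_op == new_op == Operation.SELECT) == (
    last_op == new_op and new_op == Operation.SELECT
)
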
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index af589b0..b883359 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -87,15 +87,13 @@ class SparkSession:
select_kwargs = {
"expressions": sel_columns,
"from": exp.From(
- expressions=[
- exp.Values(
- expressions=data_expressions,
- alias=exp.TableAlias(
- this=exp.to_identifier(self._auto_incrementing_name),
- columns=[exp.to_identifier(col_name) for col_name in column_mapping],
- ),
+ this=exp.Values(
+ expressions=data_expressions,
+ alias=exp.TableAlias(
+ this=exp.to_identifier(self._auto_incrementing_name),
+ columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
- ],
+ ),
),
}
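
Note: the reshaped select_kwargs reflects exp.From holding its single source under `this` instead of an "expressions" list. A hedged sketch of the new shape (column and alias names are illustrative, not from the diff):

from sqlglot import exp

select_kwargs = {
    "expressions": [exp.column("id"), exp.column("name")],
    "from": exp.From(
        # Single FROM source hangs off `this`; no surrounding list anymore.
        this=exp.Values(
            expressions=[exp.Tuple(expressions=[exp.Literal.number(1), exp.Literal.string("a")])],
            alias=exp.TableAlias(
                this=exp.to_identifier("a0"),
                columns=[exp.to_identifier("id"), exp.to_identifier("name")],
            ),
        ),
    ),
}
print(exp.Select(**select_kwargs).sql(dialect="spark"))
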
@@ -128,10 +126,6 @@ class SparkSession:
return name
@property
- def _random_name(self) -> str:
- return "r" + uuid.uuid4().hex
-
- @property
def _random_branch_id(self) -> str:
id = self._random_id
self.known_branch_ids.add(id)
@@ -145,7 +139,7 @@ class SparkSession:
@property
def _random_id(self) -> str:
- id = self._random_name
+ id = "r" + uuid.uuid4().hex
self.known_ids.add(id)
return id
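
Note: with _random_name removed (it had no remaining callers after the CTE-naming change), the uuid logic survives only inside _random_id. A minimal sketch of what that property still produces:

import uuid

def random_id() -> str:
    # One-off identifier, prefixed with "r" so it starts with a letter.
    return "r" + uuid.uuid4().hex

print(random_id())  # unique per call, unlike the CRC32-derived CTE names
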
diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py
index 575d18a..4b9fbb1 100644
--- a/sqlglot/dataframe/sql/util.py
+++ b/sqlglot/dataframe/sql/util.py
@@ -27,6 +27,6 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
if not expression.args.get("joins"):
return []
- left_table = expression.args["from"].args["expressions"][0]
+ left_table = expression.args["from"].this
other_tables = [join.this for join in expression.args["joins"]]
return [left_table] + other_tables
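
Note: the util.py change mirrors the same exp.From reshaping; the left table is read from `this` rather than by indexing an "expressions" list. A quick check of that access pattern, assuming a sqlglot version with the single-source From:

import sqlglot
from sqlglot import exp

select = sqlglot.parse_one("SELECT * FROM a JOIN b ON a.id = b.id")
assert isinstance(select, exp.Select)
left_table = select.args["from"].this  # exp.Table for `a`
other_tables = [join.this for join in select.args.get("joins", [])]
print([t.sql() for t in [left_table, *other_tables]])  # ['a', 'b']
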