summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/dataframe/sql/column.py2
-rw-r--r--sqlglot/dataframe/sql/dataframe.py151
-rw-r--r--sqlglot/dataframe/sql/functions.py15
-rw-r--r--sqlglot/dataframe/sql/readwriter.py12
4 files changed, 114 insertions, 66 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index 609b2a4..a8b89d1 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -176,7 +176,7 @@ class Column:
return isinstance(self.expression, exp.Column)
@property
- def column_expression(self) -> exp.Column:
+ def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
return self.expression.unalias()
@property
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 93bdf75..f3a6f6f 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -16,7 +16,7 @@ from sqlglot.dataframe.sql.readwriter import DataFrameWriter
from sqlglot.dataframe.sql.transforms import replace_id_value
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
-from sqlglot.helper import ensure_list, object_to_dict
+from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer import optimize as optimize_func
if t.TYPE_CHECKING:
@@ -146,9 +146,9 @@ class DataFrame:
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
- def _ensure_and_normalize_cols(self, cols):
+ def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
cols = self._ensure_list_of_columns(cols)
- normalize(self.spark, self.expression, cols)
+ normalize(self.spark, expression or self.expression, cols)
return cols
def _ensure_and_normalize_col(self, col):
@@ -355,12 +355,20 @@ class DataFrame:
cols = self._ensure_and_normalize_cols(cols)
kwargs["append"] = kwargs.get("append", False)
if self.expression.args.get("joins"):
- ambiguous_cols = [col for col in cols if not col.column_expression.table]
+ ambiguous_cols = [
+ col
+ for col in cols
+ if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
+ ]
if ambiguous_cols:
join_table_identifiers = [
x.this for x in get_tables_from_expression_with_join(self.expression)
]
cte_names_in_join = [x.this for x in join_table_identifiers]
+ # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
+ # and therefore we allow multiple columns with the same name in the result. This matches the behavior
+ # of Spark.
+ resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
for ambiguous_col in ambiguous_cols:
ctes_with_column = [
cte
@@ -368,13 +376,14 @@ class DataFrame:
if cte.alias_or_name in cte_names_in_join
and ambiguous_col.alias_or_name in cte.this.named_selects
]
- # If the select column does not specify a table and there is a join
- # then we assume they are referring to the left table
- if len(ctes_with_column) > 1:
- table_identifier = self.expression.args["from"].args["expressions"][0].this
+ # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
+ # use the same CTE we used before
+ cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
+ if cte:
+ resolved_column_position[ambiguous_col] += 1
else:
- table_identifier = ctes_with_column[0].args["alias"].this
- ambiguous_col.expression.set("table", table_identifier)
+ cte = ctes_with_column[resolved_column_position[ambiguous_col]]
+ ambiguous_col.expression.set("table", cte.alias_or_name)
return self.copy(
expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
)
@@ -416,59 +425,87 @@ class DataFrame:
**kwargs,
) -> DataFrame:
other_df = other_df._convert_leaf_to_cte()
- pre_join_self_latest_cte_name = self.latest_cte_name
- columns = self._ensure_and_normalize_cols(on)
- join_type = how.replace("_", " ")
- if isinstance(columns[0].expression, exp.Column):
- join_columns = [
- Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+ join_columns = self._ensure_list_of_columns(on)
+ # We will determine actual "join on" expression later so we don't provide it at first
+ join_expression = self.expression.join(
+ other_df.latest_cte_name, join_type=how.replace("_", " ")
+ )
+ join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
+ self_columns = self._get_outer_select_columns(join_expression)
+ other_columns = self._get_outer_select_columns(other_df)
+ # Determines the join clause and select columns to be used passed on what type of columns were provided for
+ # the join. The columns returned changes based on how the on expression is provided.
+ if isinstance(join_columns[0].expression, exp.Column):
+ """
+ Unique characteristics of join on column names only:
+ * The column names are put at the front of the select list
+ * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
+ """
+ table_names = [
+ table.alias_or_name
+ for table in get_tables_from_expression_with_join(join_expression)
]
+ potential_ctes = [
+ cte
+ for cte in join_expression.ctes
+ if cte.alias_or_name in table_names
+ and cte.alias_or_name != other_df.latest_cte_name
+ ]
+ # Determine the table to reference for the left side of the join by checking each of the left side
+ # tables and see if they have the column being referenced.
+ join_column_pairs = []
+ for join_column in join_columns:
+ num_matching_ctes = 0
+ for cte in potential_ctes:
+ if join_column.alias_or_name in cte.this.named_selects:
+ left_column = join_column.copy().set_table_name(cte.alias_or_name)
+ right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
+ join_column_pairs.append((left_column, right_column))
+ num_matching_ctes += 1
+ if num_matching_ctes > 1:
+ raise ValueError(
+ f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
+ )
+ elif num_matching_ctes == 0:
+ raise ValueError(
+ f"Column {join_column.alias_or_name} does not exist in any of the tables."
+ )
join_clause = functools.reduce(
lambda x, y: x & y,
- [
- col.copy().set_table_name(pre_join_self_latest_cte_name)
- == col.copy().set_table_name(other_df.latest_cte_name)
- for col in columns
- ],
+ [left_column == right_column for left_column, right_column in join_column_pairs],
)
- else:
- if len(columns) > 1:
- columns = [functools.reduce(lambda x, y: x & y, columns)]
- join_clause = columns[0]
- join_columns = [
- Column(x).set_table_name(pre_join_self_latest_cte_name)
- if i % 2 == 0
- else Column(x).set_table_name(other_df.latest_cte_name)
- for i, x in enumerate(join_clause.expression.find_all(exp.Column))
+ join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
+ # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
+ select_column_names = [
+ column.alias_or_name
+ if not isinstance(column.expression.this, exp.Star)
+ else column.sql()
+ for column in self_columns + other_columns
]
- self_columns = [
- column.set_table_name(pre_join_self_latest_cte_name, copy=True)
- for column in self._get_outer_select_columns(self)
- ]
- other_columns = [
- column.set_table_name(other_df.latest_cte_name, copy=True)
- for column in self._get_outer_select_columns(other_df)
- ]
- column_value_mapping = {
- column.alias_or_name
- if not isinstance(column.expression.this, exp.Star)
- else column.sql(): column
- for column in other_columns + self_columns + join_columns
- }
- all_columns = [
- column_value_mapping[name]
- for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
- ]
- new_df = self.copy(
- expression=self.expression.join(
- other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
- )
- )
- new_df.expression = new_df._add_ctes_to_expression(
- new_df.expression, other_df.expression.ctes
- )
+ select_column_names = [
+ column_name
+ for column_name in select_column_names
+ if column_name not in join_column_names
+ ]
+ select_column_names = join_column_names + select_column_names
+ else:
+ """
+ Unique characteristics of join on expressions:
+ * There is no deduplication of the results.
+ * The left join dataframe columns go first and right come after. No sort preference is given to join columns
+ """
+ join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
+ if len(join_columns) > 1:
+ join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
+ join_clause = join_columns[0]
+ select_column_names = [column.alias_or_name for column in self_columns + other_columns]
+
+ # Update the on expression with the actual join clause to replace the dummy one from before
+ join_expression.args["joins"][-1].set("on", join_clause.expression)
+ new_df = self.copy(expression=join_expression)
+ new_df.pending_join_hints.extend(self.pending_join_hints)
new_df.pending_hints.extend(other_df.pending_hints)
- new_df = new_df.select.__wrapped__(new_df, *all_columns)
+ new_df = new_df.select.__wrapped__(new_df, *select_column_names)
return new_df
@operation(Operation.ORDER_BY)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index f77b4f8..993d869 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -577,11 +577,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days)
+ return Column.invoke_expression_over_column(
+ col, expression.DateAdd, expression=days, unit=expression.Var(this="day")
+ )
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, expression.DateSub, expression=days)
+ return Column.invoke_expression_over_column(
+ col, expression.DateSub, expression=days, unit=expression.Var(this="day")
+ )
def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
@@ -695,18 +699,17 @@ def crc32(col: ColumnOrName) -> Column:
def md5(col: ColumnOrName) -> Column:
column = col if isinstance(col, Column) else lit(col)
- return Column.invoke_anonymous_function(column, "MD5")
+ return Column.invoke_expression_over_column(column, expression.MD5)
def sha1(col: ColumnOrName) -> Column:
column = col if isinstance(col, Column) else lit(col)
- return Column.invoke_anonymous_function(column, "SHA1")
+ return Column.invoke_expression_over_column(column, expression.SHA)
def sha2(col: ColumnOrName, numBits: int) -> Column:
column = col if isinstance(col, Column) else lit(col)
- num_bits = lit(numBits)
- return Column.invoke_anonymous_function(column, "SHA2", num_bits)
+ return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits))
def hash(*cols: ColumnOrName) -> Column:
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index febc664..cc2f181 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -4,7 +4,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot.helper import object_to_dict
+from sqlglot.helper import object_to_dict, should_identify
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
@@ -19,9 +19,17 @@ class DataFrameReader:
from sqlglot.dataframe.sql.dataframe import DataFrame
sqlglot.schema.add_table(tableName)
+
return DataFrame(
self.spark,
- exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+ exp.Select()
+ .from_(tableName)
+ .select(
+ *(
+ column if should_identify(column, "safe") else f'"{column}"'
+ for column in sqlglot.schema.column_names(tableName)
+ )
+ ),
)