diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 2 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/dataframe.py | 151 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 15 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/readwriter.py | 12 |
4 files changed, 114 insertions, 66 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index 609b2a4..a8b89d1 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -176,7 +176,7 @@ class Column: return isinstance(self.expression, exp.Column) @property - def column_expression(self) -> exp.Column: + def column_expression(self) -> t.Union[exp.Column, exp.Literal]: return self.expression.unalias() @property diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 93bdf75..f3a6f6f 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -16,7 +16,7 @@ from sqlglot.dataframe.sql.readwriter import DataFrameWriter from sqlglot.dataframe.sql.transforms import replace_id_value from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join from sqlglot.dataframe.sql.window import Window -from sqlglot.helper import ensure_list, object_to_dict +from sqlglot.helper import ensure_list, object_to_dict, seq_get from sqlglot.optimizer import optimize as optimize_func if t.TYPE_CHECKING: @@ -146,9 +146,9 @@ class DataFrame: def _ensure_list_of_columns(self, cols): return Column.ensure_cols(ensure_list(cols)) - def _ensure_and_normalize_cols(self, cols): + def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None): cols = self._ensure_list_of_columns(cols) - normalize(self.spark, self.expression, cols) + normalize(self.spark, expression or self.expression, cols) return cols def _ensure_and_normalize_col(self, col): @@ -355,12 +355,20 @@ class DataFrame: cols = self._ensure_and_normalize_cols(cols) kwargs["append"] = kwargs.get("append", False) if self.expression.args.get("joins"): - ambiguous_cols = [col for col in cols if not col.column_expression.table] + ambiguous_cols = [ + col + for col in cols + if isinstance(col.column_expression, exp.Column) and not col.column_expression.table + ] if ambiguous_cols: join_table_identifiers = [ x.this for x in get_tables_from_expression_with_join(self.expression) ] cte_names_in_join = [x.this for x in join_table_identifiers] + # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right + # and therefore we allow multiple columns with the same name in the result. This matches the behavior + # of Spark. + resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols} for ambiguous_col in ambiguous_cols: ctes_with_column = [ cte @@ -368,13 +376,14 @@ class DataFrame: if cte.alias_or_name in cte_names_in_join and ambiguous_col.alias_or_name in cte.this.named_selects ] - # If the select column does not specify a table and there is a join - # then we assume they are referring to the left table - if len(ctes_with_column) > 1: - table_identifier = self.expression.args["from"].args["expressions"][0].this + # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise, + # use the same CTE we used before + cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1) + if cte: + resolved_column_position[ambiguous_col] += 1 else: - table_identifier = ctes_with_column[0].args["alias"].this - ambiguous_col.expression.set("table", table_identifier) + cte = ctes_with_column[resolved_column_position[ambiguous_col]] + ambiguous_col.expression.set("table", cte.alias_or_name) return self.copy( expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs ) @@ -416,59 +425,87 @@ class DataFrame: **kwargs, ) -> DataFrame: other_df = other_df._convert_leaf_to_cte() - pre_join_self_latest_cte_name = self.latest_cte_name - columns = self._ensure_and_normalize_cols(on) - join_type = how.replace("_", " ") - if isinstance(columns[0].expression, exp.Column): - join_columns = [ - Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns + join_columns = self._ensure_list_of_columns(on) + # We will determine actual "join on" expression later so we don't provide it at first + join_expression = self.expression.join( + other_df.latest_cte_name, join_type=how.replace("_", " ") + ) + join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes) + self_columns = self._get_outer_select_columns(join_expression) + other_columns = self._get_outer_select_columns(other_df) + # Determines the join clause and select columns to be used passed on what type of columns were provided for + # the join. The columns returned changes based on how the on expression is provided. + if isinstance(join_columns[0].expression, exp.Column): + """ + Unique characteristics of join on column names only: + * The column names are put at the front of the select list + * The column names are deduplicated across the entire select list and only the column names (other dups are allowed) + """ + table_names = [ + table.alias_or_name + for table in get_tables_from_expression_with_join(join_expression) ] + potential_ctes = [ + cte + for cte in join_expression.ctes + if cte.alias_or_name in table_names + and cte.alias_or_name != other_df.latest_cte_name + ] + # Determine the table to reference for the left side of the join by checking each of the left side + # tables and see if they have the column being referenced. + join_column_pairs = [] + for join_column in join_columns: + num_matching_ctes = 0 + for cte in potential_ctes: + if join_column.alias_or_name in cte.this.named_selects: + left_column = join_column.copy().set_table_name(cte.alias_or_name) + right_column = join_column.copy().set_table_name(other_df.latest_cte_name) + join_column_pairs.append((left_column, right_column)) + num_matching_ctes += 1 + if num_matching_ctes > 1: + raise ValueError( + f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name." + ) + elif num_matching_ctes == 0: + raise ValueError( + f"Column {join_column.alias_or_name} does not exist in any of the tables." + ) join_clause = functools.reduce( lambda x, y: x & y, - [ - col.copy().set_table_name(pre_join_self_latest_cte_name) - == col.copy().set_table_name(other_df.latest_cte_name) - for col in columns - ], + [left_column == right_column for left_column, right_column in join_column_pairs], ) - else: - if len(columns) > 1: - columns = [functools.reduce(lambda x, y: x & y, columns)] - join_clause = columns[0] - join_columns = [ - Column(x).set_table_name(pre_join_self_latest_cte_name) - if i % 2 == 0 - else Column(x).set_table_name(other_df.latest_cte_name) - for i, x in enumerate(join_clause.expression.find_all(exp.Column)) + join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] + # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list + select_column_names = [ + column.alias_or_name + if not isinstance(column.expression.this, exp.Star) + else column.sql() + for column in self_columns + other_columns ] - self_columns = [ - column.set_table_name(pre_join_self_latest_cte_name, copy=True) - for column in self._get_outer_select_columns(self) - ] - other_columns = [ - column.set_table_name(other_df.latest_cte_name, copy=True) - for column in self._get_outer_select_columns(other_df) - ] - column_value_mapping = { - column.alias_or_name - if not isinstance(column.expression.this, exp.Star) - else column.sql(): column - for column in other_columns + self_columns + join_columns - } - all_columns = [ - column_value_mapping[name] - for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns} - ] - new_df = self.copy( - expression=self.expression.join( - other_df.latest_cte_name, on=join_clause.expression, join_type=join_type - ) - ) - new_df.expression = new_df._add_ctes_to_expression( - new_df.expression, other_df.expression.ctes - ) + select_column_names = [ + column_name + for column_name in select_column_names + if column_name not in join_column_names + ] + select_column_names = join_column_names + select_column_names + else: + """ + Unique characteristics of join on expressions: + * There is no deduplication of the results. + * The left join dataframe columns go first and right come after. No sort preference is given to join columns + """ + join_columns = self._ensure_and_normalize_cols(join_columns, join_expression) + if len(join_columns) > 1: + join_columns = [functools.reduce(lambda x, y: x & y, join_columns)] + join_clause = join_columns[0] + select_column_names = [column.alias_or_name for column in self_columns + other_columns] + + # Update the on expression with the actual join clause to replace the dummy one from before + join_expression.args["joins"][-1].set("on", join_clause.expression) + new_df = self.copy(expression=join_expression) + new_df.pending_join_hints.extend(self.pending_join_hints) new_df.pending_hints.extend(other_df.pending_hints) - new_df = new_df.select.__wrapped__(new_df, *all_columns) + new_df = new_df.select.__wrapped__(new_df, *select_column_names) return new_df @operation(Operation.ORDER_BY) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index f77b4f8..993d869 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -577,11 +577,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days) + return Column.invoke_expression_over_column( + col, expression.DateAdd, expression=days, unit=expression.Var(this="day") + ) def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, expression.DateSub, expression=days) + return Column.invoke_expression_over_column( + col, expression.DateSub, expression=days, unit=expression.Var(this="day") + ) def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: @@ -695,18 +699,17 @@ def crc32(col: ColumnOrName) -> Column: def md5(col: ColumnOrName) -> Column: column = col if isinstance(col, Column) else lit(col) - return Column.invoke_anonymous_function(column, "MD5") + return Column.invoke_expression_over_column(column, expression.MD5) def sha1(col: ColumnOrName) -> Column: column = col if isinstance(col, Column) else lit(col) - return Column.invoke_anonymous_function(column, "SHA1") + return Column.invoke_expression_over_column(column, expression.SHA) def sha2(col: ColumnOrName, numBits: int) -> Column: column = col if isinstance(col, Column) else lit(col) - num_bits = lit(numBits) - return Column.invoke_anonymous_function(column, "SHA2", num_bits) + return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits)) def hash(*cols: ColumnOrName) -> Column: diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py index febc664..cc2f181 100644 --- a/sqlglot/dataframe/sql/readwriter.py +++ b/sqlglot/dataframe/sql/readwriter.py @@ -4,7 +4,7 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot.helper import object_to_dict +from sqlglot.helper import object_to_dict, should_identify if t.TYPE_CHECKING: from sqlglot.dataframe.sql.dataframe import DataFrame @@ -19,9 +19,17 @@ class DataFrameReader: from sqlglot.dataframe.sql.dataframe import DataFrame sqlglot.schema.add_table(tableName) + return DataFrame( self.spark, - exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)), + exp.Select() + .from_(tableName) + .select( + *( + column if should_identify(column, "safe") else f'"{column}"' + for column in sqlglot.schema.column_names(tableName) + ) + ), ) |