path: root/sqlglot/dataframe
Diffstat (limited to 'sqlglot/dataframe')
 sqlglot/dataframe/sql/_typing.pyi   |  14
 sqlglot/dataframe/sql/column.py     |  46
 sqlglot/dataframe/sql/dataframe.py  | 158
 sqlglot/dataframe/sql/functions.py  | 100
 sqlglot/dataframe/sql/group.py      |  10
 sqlglot/dataframe/sql/normalize.py  |  13
 sqlglot/dataframe/sql/readwriter.py |  16
 sqlglot/dataframe/sql/session.py    |  17
 sqlglot/dataframe/sql/types.py      |   6
 sqlglot/dataframe/sql/window.py     |  27
10 files changed, 305 insertions(+), 102 deletions(-)
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
index f1a03ea..67c8c09 100644
--- a/sqlglot/dataframe/sql/_typing.pyi
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -10,11 +10,17 @@ if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType
ColumnLiterals = t.TypeVar(
- "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ "ColumnLiterals",
+ bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
ColumnOrLiteral = t.TypeVar(
- "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ "ColumnOrLiteral",
+ bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
+)
+SchemaInput = t.TypeVar(
+ "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
+)
+OutputExpressionContainer = t.TypeVar(
+ "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
)
-SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
-OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
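
Note: the SchemaInput bound above is what createDataFrame validates its schema argument against. A minimal sketch of the four accepted shapes; the employee data and column names are hypothetical, and the string form follows PySpark's DDL-style schema strings:

    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql.types import StructType, StructField, IntegerType, StringType

    spark = SparkSession()
    data = [(1, "Jack")]

    spark.createDataFrame(data, schema="employee_id INT, fname STRING")            # str
    spark.createDataFrame(data, schema=["employee_id", "fname"])                   # t.List[str]
    spark.createDataFrame(data, schema={"employee_id": "INT", "fname": "STRING"})  # t.Dict[str, str]
    spark.createDataFrame(data, schema=StructType([
        StructField("employee_id", IntegerType()),
        StructField("fname", StringType()),
    ]))                                                                            # StructType
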
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index e66aaa8..f9e1c5b 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -18,7 +18,11 @@ class Column:
expression = expression.expression # type: ignore
elif expression is None or not isinstance(expression, (str, exp.Expression)):
expression = self._lit(expression).expression # type: ignore
- self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+ parsed_expression = sqlglot.maybe_parse(expression, dialect="spark")
+ if parsed_expression is None:
+ raise ValueError(f"Could not parse {expression}")  # report the raw input, not None
+ self.expression: exp.Expression = parsed_expression
def __repr__(self):
return repr(self.expression)
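
Note: the guard added above turns a silent None from maybe_parse into an explicit error at construction time. A rough sketch of the behavior (column names are hypothetical):

    from sqlglot.dataframe.sql.column import Column

    Column("fname").sql()   # string input is parsed as a Spark expression -> 'fname'
    Column(None).sql()      # non-expression input falls back to a literal -> 'NULL'
    # Anything maybe_parse cannot turn into an expression now raises ValueError
    # at construction instead of storing None and failing later.
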
@@ -135,21 +139,29 @@ class Column:
) -> Column:
ensured_column = None if column is None else cls.ensure_col(column)
ensure_expression_values = {
- k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
+ k: [Column.ensure_col(x).expression for x in v]
+ if is_iterable(v)
+ else Column.ensure_col(v).expression
for k, v in kwargs.items()
}
new_expression = (
callable_expression(**ensure_expression_values)
if ensured_column is None
- else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
+ else callable_expression(
+ this=ensured_column.column_expression, **ensure_expression_values
+ )
)
return Column(new_expression)
def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
- return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+ return Column(
+ klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
+ )
def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
- return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+ return Column(
+ klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
+ )
def unary_op(self, klass: t.Callable, **kwargs) -> Column:
return Column(klass(this=self.column_expression, **kwargs))
@@ -188,7 +200,7 @@ class Column:
expression.set("table", exp.to_identifier(table_name))
return Column(expression)
- def sql(self, **kwargs) -> Column:
+ def sql(self, **kwargs) -> str:
return self.expression.sql(**{"dialect": "spark", **kwargs})
def alias(self, name: str) -> Column:
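
Note: the annotation fix above matters for type checkers; Column.sql always returned the rendered string, never a Column. A quick sketch:

    from sqlglot.dataframe.sql import functions as F

    rendered = (F.col("age") + 1).sql()
    assert isinstance(rendered, str)   # e.g. 'age + 1' in the spark dialect
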
@@ -265,10 +277,14 @@ class Column:
)
def like(self, other: str):
- return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+ return self.invoke_expression_over_column(
+ self, exp.Like, expression=self._lit(other).expression
+ )
def ilike(self, other: str):
- return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+ return self.invoke_expression_over_column(
+ self, exp.ILike, expression=self._lit(other).expression
+ )
def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
@@ -287,10 +303,18 @@ class Column:
lowerBound: t.Union[ColumnOrLiteral],
upperBound: t.Union[ColumnOrLiteral],
) -> Column:
- lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
- upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+ lower_bound_exp = (
+ self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+ )
+ upper_bound_exp = (
+ self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+ )
return Column(
- exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+ exp.Between(
+ this=self.column_expression,
+ low=lower_bound_exp.expression,
+ high=upper_bound_exp.expression,
+ )
)
def over(self, window: WindowSpec) -> Column:
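
Note: between builds an exp.Between node, lifting plain Python values to literals first. A usage sketch with a hypothetical column:

    cond = F.col("age").between(18, 65)
    cond.sql()   # roughly: 'age BETWEEN 18 AND 65'
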
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 322dcf2..40cd6c9 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func
from sqlglot.optimizer.qualify_columns import qualify_columns
if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+ from sqlglot.dataframe.sql._typing import (
+ ColumnLiterals,
+ ColumnOrLiteral,
+ ColumnOrName,
+ OutputExpressionContainer,
+ )
from sqlglot.dataframe.sql.session import SparkSession
@@ -83,7 +88,9 @@ class DataFrame:
return from_exp.alias_or_name
table_alias = from_exp.find(exp.TableAlias)
if not table_alias:
- raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+ raise RuntimeError(
+ f"Could not find an alias name for this expression: {self.expression}"
+ )
return table_alias.alias_or_name
return self.expression.ctes[-1].alias
@@ -132,12 +139,16 @@ class DataFrame:
cte.set("sequence_id", sequence_id or self.sequence_id)
return cte, name
- def _ensure_list_of_columns(
- self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
- ) -> t.List[Column]:
- columns = ensure_list(cols)
- columns = Column.ensure_cols(columns)
- return columns
+ @t.overload
+ def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+ ...
+
+ @t.overload
+ def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+ ...
+
+ def _ensure_list_of_columns(self, cols):
+ return Column.ensure_cols(ensure_list(cols))
def _ensure_and_normalize_cols(self, cols):
cols = self._ensure_list_of_columns(cols)
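
Note: the @t.overload pair above gives type checkers a precise signature for both call shapes while the runtime body stays a single line. The same idiom in isolation (names here are illustrative, not from the library):

    import typing as t

    @t.overload
    def as_list(value: t.Collection[int]) -> t.List[int]: ...
    @t.overload
    def as_list(value: int) -> t.List[int]: ...

    def as_list(value):
        # One runtime implementation behind the two declared signatures.
        return list(value) if isinstance(value, (list, tuple, set)) else [value]
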
@@ -153,10 +164,16 @@ class DataFrame:
df = self._resolve_pending_hints()
sequence_id = sequence_id or df.sequence_id
expression = df.expression.copy()
- cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
- new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+ cte_expression, cte_name = df._create_cte_from_expression(
+ expression=expression, sequence_id=sequence_id
+ )
+ new_expression = df._add_ctes_to_expression(
+ exp.Select(), expression.ctes + [cte_expression]
+ )
sel_columns = df._get_outer_select_columns(cte_expression)
- new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+ new_expression = new_expression.from_(cte_name).select(
+ *[x.alias_or_name for x in sel_columns]
+ )
return df.copy(expression=new_expression, sequence_id=sequence_id)
def _resolve_pending_hints(self) -> DataFrame:
@@ -169,16 +186,23 @@ class DataFrame:
hint_expression.args.get("expressions").append(hint)
df.pending_hints.remove(hint)
- join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+ join_aliases = {
+ join_table.alias_or_name
+ for join_table in get_tables_from_expression_with_join(expression)
+ }
if join_aliases:
for hint in df.pending_join_hints:
for sequence_id_expression in hint.expressions:
sequence_id_or_name = sequence_id_expression.alias_or_name
sequence_ids_to_match = [sequence_id_or_name]
if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
- sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+ sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
+ sequence_id_or_name
+ ]
matching_ctes = [
- cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+ cte
+ for cte in reversed(expression.ctes)
+ if cte.args["sequence_id"] in sequence_ids_to_match
]
for matching_cte in matching_ctes:
if matching_cte.alias_or_name in join_aliases:
@@ -193,9 +217,14 @@ class DataFrame:
def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
hint_name = hint_name.upper()
hint_expression = (
- exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+ exp.JoinHint(
+ this=hint_name,
+ expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
+ )
if hint_name in JOIN_HINTS
- else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+ else exp.Anonymous(
+ this=hint_name, expressions=[parameter.expression for parameter in args]
+ )
)
new_df = self.copy()
new_df.pending_hints.append(hint_expression)
@@ -245,7 +274,9 @@ class DataFrame:
def _get_select_expressions(
self,
) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
- select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+ select_expressions: t.List[
+ t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
+ ] = []
main_select_ctes: t.List[exp.CTE] = []
for cte in self.expression.ctes:
cache_storage_level = cte.args.get("cache_storage_level")
@@ -279,14 +310,19 @@ class DataFrame:
cache_table_name = df._create_hash_from_expression(select_expression)
cache_table = exp.to_table(cache_table_name)
original_alias_name = select_expression.args["cte_alias_name"]
- replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+
+ replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
+ cache_table_name
+ )
sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
exp.Literal.string(cache_storage_level),
]
- expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+ expression = exp.Cache(
+ this=cache_table, expression=select_expression, lazy=True, options=options
+ )
# We will drop the "view" if it exists before running the cache table
output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
elif expression_type == exp.Create:
@@ -305,7 +341,9 @@ class DataFrame:
raise ValueError(f"Invalid expression type: {expression_type}")
output_expressions.append(expression)
- return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+ return [
+ expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+ ]
def copy(self, **kwargs) -> DataFrame:
return DataFrame(**object_to_dict(self, **kwargs))
@@ -317,7 +355,9 @@ class DataFrame:
if self.expression.args.get("joins"):
ambiguous_cols = [col for col in cols if not col.column_expression.table]
if ambiguous_cols:
- join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+ join_table_identifiers = [
+ x.this for x in get_tables_from_expression_with_join(self.expression)
+ ]
cte_names_in_join = [x.this for x in join_table_identifiers]
for ambiguous_col in ambiguous_cols:
ctes_with_column = [
@@ -367,14 +407,20 @@ class DataFrame:
@operation(Operation.FROM)
def join(
- self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+ self,
+ other_df: DataFrame,
+ on: t.Union[str, t.List[str], Column, t.List[Column]],
+ how: str = "inner",
+ **kwargs,
) -> DataFrame:
other_df = other_df._convert_leaf_to_cte()
pre_join_self_latest_cte_name = self.latest_cte_name
columns = self._ensure_and_normalize_cols(on)
join_type = how.replace("_", " ")
if isinstance(columns[0].expression, exp.Column):
- join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+ join_columns = [
+ Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+ ]
join_clause = functools.reduce(
lambda x, y: x & y,
[
@@ -402,7 +448,9 @@ class DataFrame:
for column in self._get_outer_select_columns(other_df)
]
column_value_mapping = {
- column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+ column.alias_or_name
+ if not isinstance(column.expression.this, exp.Star)
+ else column.sql(): column
for column in other_columns + self_columns + join_columns
}
all_columns = [
@@ -410,16 +458,22 @@ class DataFrame:
for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
]
new_df = self.copy(
- expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+ expression=self.expression.join(
+ other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
+ )
+ )
+ new_df.expression = new_df._add_ctes_to_expression(
+ new_df.expression, other_df.expression.ctes
)
- new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
new_df.pending_hints.extend(other_df.pending_hints)
new_df = new_df.select.__wrapped__(new_df, *all_columns)
return new_df
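
Note: usage of the reflowed join is unchanged; joining on string column names still rewrites them against each side's latest CTE. A sketch against hypothetical registered tables:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"employee_id": "INT", "store_id": "INT"})
    sqlglot.schema.add_table("store", {"store_id": "INT", "store_name": "STRING"})

    spark = SparkSession()
    df = spark.read.table("employee").join(
        spark.read.table("store"), on=["store_id"], how="left"
    )
    print(df.sql(pretty=True)[0])   # a SELECT with CTEs and a LEFT JOIN ... ON clause
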
@operation(Operation.ORDER_BY)
def orderBy(
- self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+ self,
+ *cols: t.Union[str, Column],
+ ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
) -> DataFrame:
"""
This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
@@ -429,7 +483,10 @@ class DataFrame:
columns = self._ensure_and_normalize_cols(cols)
pre_ordered_col_indexes = [
x
- for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+ for x in [
+ i if isinstance(col.expression, exp.Ordered) else None
+ for i, col in enumerate(columns)
+ ]
if x is not None
]
if ascending is None:
@@ -478,7 +535,9 @@ class DataFrame:
for r_column in r_columns_unused:
l_expressions.append(exp.alias_(exp.Null(), r_column))
r_expressions.append(r_column)
- r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+ r_df = (
+ other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+ )
l_df = self.copy()
if allowMissingColumns:
l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
@@ -536,7 +595,9 @@ class DataFrame:
f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
)
- if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+ if_null_checks = [
+ F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
+ ]
nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
num_nulls = nulls_added_together.alias("num_nulls")
new_df = new_df.select(num_nulls, append=True)
@@ -576,11 +637,15 @@ class DataFrame:
value_columns = [lit(value) for value in values]
null_replacement_mapping = {
- column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+ column.alias_or_name: (
+ F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
+ )
for column, value in zip(columns, value_columns)
}
null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
- null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+ null_replacement_columns = [
+ null_replacement_mapping[column.alias_or_name] for column in all_columns
+ ]
new_df = new_df.select(*null_replacement_columns)
return new_df
@@ -589,12 +654,11 @@ class DataFrame:
self,
to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
- subset: t.Optional[t.Union[str, t.List[str]]] = None,
+ subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
) -> DataFrame:
from sqlglot.dataframe.sql.functions import lit
old_values = None
- subset = ensure_list(subset)
new_df = self.copy()
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
@@ -605,7 +669,9 @@ class DataFrame:
new_values = list(to_replace.values())
elif not old_values and isinstance(to_replace, list):
assert isinstance(value, list), "value must be a list since the replacements are a list"
- assert len(to_replace) == len(value), "the replacements and values must be the same length"
+ assert len(to_replace) == len(
+ value
+ ), "the replacements and values must be the same length"
old_values = to_replace
new_values = value
else:
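
Note: subset widens above from plain strings to anything column-like, and the stray ensure_list(subset) is gone because normalization now happens downstream. A usage sketch, assuming df from the earlier join sketch:

    df = df.replace(0, -1, subset=["store_id"])   # scalar form, restricted to one column
    df = df.replace({0: -1})                      # dict form pairs old values with new ones
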
@@ -635,7 +701,9 @@ class DataFrame:
def withColumn(self, colName: str, col: Column) -> DataFrame:
col = self._ensure_and_normalize_col(col)
existing_col_names = self.expression.named_selects
- existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+ existing_col_index = (
+ existing_col_names.index(colName) if colName in existing_col_names else None
+ )
if existing_col_index:
expression = self.expression.copy()
expression.expressions[existing_col_index] = col.expression
@@ -645,7 +713,11 @@ class DataFrame:
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
expression = self.expression.copy()
- existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+ existing_columns = [
+ expression
+ for expression in expression.expressions
+ if expression.alias_or_name == existing
+ ]
if not existing_columns:
raise ValueError("Tried to rename a column that doesn't exist")
for existing_column in existing_columns:
@@ -674,15 +746,19 @@ class DataFrame:
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
parameter_list = ensure_list(parameters)
parameter_columns = (
- self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+ self._ensure_list_of_columns(parameter_list)
+ if parameters
+ else Column.ensure_cols([self.sequence_id])
)
return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
- def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
- num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+ def repartition(
+ self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
+ ) -> DataFrame:
+ num_partition_cols = self._ensure_list_of_columns(numPartitions)
columns = self._ensure_and_normalize_cols(cols)
- args = num_partitions + columns
+ args = num_partition_cols + columns
return self._hint("repartition", args)
@operation(Operation.NO_OP)
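
Note: repartition now accepts column expressions, not just ints and strings, and routes everything through _ensure_list_of_columns. Since this backend has no physical partitions, the call only records a hint in the generated SQL:

    df = df.repartition(4, "store_id")
    # Emits a REPARTITION(4, store_id) hint in the generated Spark SQL rather
    # than actually moving any data.
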
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index bc002e5..dbfb06f 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
def when(condition: Column, value: t.Any) -> Column:
true_value = value if isinstance(value, Column) else lit(value)
- return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
+ return Column(
+ glotexp.Case(
+ ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+ )
+ )
def asc(col: ColumnOrName) -> Column:
@@ -407,7 +411,9 @@ def percentile_approx(
return Column.invoke_expression_over_column(
col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
)
- return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
+ return Column.invoke_expression_over_column(
+ col, glotexp.ApproxQuantile, quantile=lit(percentage)
+ )
def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
@@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "FACTORIAL")
-def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
+def lag(
+ col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
+) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LAG", offset, default)
if offset != 1:
@@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu
return Column.invoke_anonymous_function(col, "LAG")
-def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+def lead(
+ col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
+) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LEAD", offset, default)
if offset != 1:
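
Note: lag and lead keep PySpark's defaults; with offset=1 and no default they emit the bare one-argument form. A sketch with hypothetical columns:

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.window import Window

    w = Window.partitionBy("store_id").orderBy("sale_date")
    prev_price = F.lag("price").over(w)          # LAG(price) OVER (...)
    next_price = F.lead("price", 2, 0).over(w)   # LEAD(price, 2, 0) OVER (...)
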
@@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A
return Column.invoke_anonymous_function(col, "LEAD")
-def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+def nth_value(
+ col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
+) -> Column:
if ignoreNulls is not None:
raise NotImplementedError("There is currently no support for the `ignoreNulls` parameter")
if offset != 1:
@@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum
return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
-def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+def months_between(
+ date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
+) -> Column:
if roundOff is None:
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
@@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
-def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
+def unix_timestamp(
+ timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
+) -> Column:
if format is not None:
- return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
+ return Column.invoke_expression_over_column(
+ timestamp, glotexp.StrToUnix, format=lit(format)
+ )
return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
@@ -642,7 +660,9 @@ def window(
timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
)
if slideDuration is not None:
- return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
+ return Column.invoke_anonymous_function(
+ timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
+ )
if startTime is not None:
return Column.invoke_anonymous_function(
timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
@@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column:
def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
+ return Column.invoke_expression_over_column(
+ None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+ )
def decode(col: ColumnOrName, charset: str) -> Column:
@@ -768,7 +790,9 @@ def overlay(
def sentences(
- string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
+ string: ColumnOrName,
+ language: t.Optional[ColumnOrName] = None,
+ country: t.Optional[ColumnOrName] = None,
) -> Column:
if language is not None and country is not None:
return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
@@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
substr_col = lit(substr)
if pos is not None:
- return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
+ return Column.invoke_expression_over_column(
+ str, glotexp.StrPosition, substr=substr_col, position=pos
+ )
return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
@@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
return Column.invoke_expression_over_column(
- None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
+ None,
+ glotexp.VarMap,
+ keys=array(*cols[::2]).expression,
+ values=array(*cols[1::2]).expression,
)
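
Note: create_map alternates keys and values, which is why the reflowed call splits cols into even and odd positions. A sketch with hypothetical columns:

    m = F.create_map(F.lit("fname"), F.col("fname"), F.lit("lname"), F.col("lname"))
    # Builds a VarMap with keys ARRAY('fname', 'lname') and values
    # ARRAY(fname, lname), rendered with the spark dialect's MAP syntax.
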
@@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
value_col = value if isinstance(value, Column) else lit(value)
- return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
+ return Column.invoke_expression_over_column(
+ col, glotexp.ArrayContains, expression=value_col.expression
+ )
def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))
-def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
+def slice(
+ x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
+) -> Column:
start_col = start if isinstance(start, Column) else lit(start)
length_col = length if isinstance(length, Column) else lit(length)
return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
-def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
+def array_join(
+ col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
+) -> Column:
if null_replacement is not None:
- return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
+ return Column.invoke_anonymous_function(
+ col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+ )
return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
def concat(*cols: ColumnOrName) -> Column:
if len(cols) == 1:
return Column.invoke_anonymous_function(cols[0], "CONCAT")
- return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
+ return Column.invoke_anonymous_function(
+ cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
+ )
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
-def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
+def sequence(
+ start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
+) -> Column:
if step is not None:
return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
@@ -1103,12 +1144,15 @@ def aggregate(
merge_exp = _get_lambda_from_func(merge)
if finish is not None:
finish_exp = _get_lambda_from_func(finish)
- return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
+ return Column.invoke_anonymous_function(
+ col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
+ )
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
def transform(
- col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
@@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
-def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
+def filter(
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
-def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
+def zip_with(
+ left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
+) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
@@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:
def _get_lambda_from_func(lambda_expression: t.Callable):
- variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
+ variables = [
+ glotexp.to_identifier(x, quoted=_lambda_quoted(x))
+ for x in lambda_expression.__code__.co_varnames
+ ]
return glotexp.Lambda(
this=lambda_expression(*[Column(x) for x in variables]).expression,
expressions=variables,
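
Note: _get_lambda_from_func reads the Python lambda's argument names from __code__.co_varnames, wraps each as an identifier, and calls the lambda with Column stand-ins to capture its body as an expression tree. A sketch of the visible effect, with hypothetical columns:

    F.transform("scores", lambda x: x + 1).sql()
    # roughly: TRANSFORM(scores, x -> x + 1)
    F.filter("scores", lambda x, i: x > i).sql()
    # two-argument lambdas become (x, i) -> x > i
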
diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py
index 947aace..ba27c17 100644
--- a/sqlglot/dataframe/sql/group.py
+++ b/sqlglot/dataframe/sql/group.py
@@ -17,7 +17,9 @@ class GroupedData:
self.last_op = last_op
self.group_by_cols = group_by_cols
- def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
+ def _get_function_applied_columns(
+ self, func_name: str, cols: t.Tuple[str, ...]
+ ) -> t.List[Column]:
func_name = func_name.lower()
return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
@@ -30,9 +32,9 @@ class GroupedData:
)
cols = self._df._ensure_and_normalize_cols(columns)
- expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
- *[x.expression for x in self.group_by_cols + cols], append=False
- )
+ expression = self._df.expression.group_by(
+ *[x.expression for x in self.group_by_cols]
+ ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
return self._df.copy(expression=expression)
def count(self) -> DataFrame:
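
Note: the regrouped agg call above still produces a single GROUP BY select. A usage sketch, assuming the spark session and employee table registered in the earlier join sketch:

    df = (
        spark.read.table("employee")
        .groupBy("store_id")
        .agg(F.countDistinct("employee_id").alias("num_employees"))
    )
    # -> SELECT store_id, COUNT(DISTINCT employee_id) AS num_employees
    #    ... GROUP BY store_id
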
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 1513946..75feba7 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
-def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+def replace_alias_name_with_cte_name(
+ spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
if id.alias_or_name in spark.name_to_sequence_id_mapping:
for cte in reversed(expression_context.ctes):
if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
@@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name(
# id then it keeps that reference. This handles the weird edge case in spark that shouldn't
# be common in practice
if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
- join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
- ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+ join_table_aliases = [
+ x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
+ ]
+ ctes_in_join = [
+ cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
+ ]
if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
assert len(ctes_in_join) == 2
_set_alias_name(id, ctes_in_join[0].alias_or_name)
@@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str):
def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
- values = ensure_list(values)
results = []
for value in values:
if isinstance(value, str):
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index 4830035..febc664 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -19,12 +19,19 @@ class DataFrameReader:
from sqlglot.dataframe.sql.dataframe import DataFrame
sqlglot.schema.add_table(tableName)
- return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+ return DataFrame(
+ self.spark,
+ exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+ )
class DataFrameWriter:
def __init__(
- self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+ self,
+ df: DataFrame,
+ spark: t.Optional[SparkSession] = None,
+ mode: t.Optional[str] = None,
+ by_name: bool = False,
):
self._df = df
self._spark = spark or df.spark
@@ -33,7 +40,10 @@ class DataFrameWriter:
def copy(self, **kwargs) -> DataFrameWriter:
return DataFrameWriter(
- **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+ **{
+ k[1:] if k.startswith("_") else k: v
+ for k, v in object_to_dict(self, **kwargs).items()
+ }
)
def sql(self, **kwargs) -> t.List[str]:
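
Note: a writer round-trip for context, assuming df and the schema registrations from the earlier sketches; the target table name is hypothetical:

    sqlglot.schema.add_table("employee_copy", {"employee_id": "INT", "store_id": "INT"})
    writer = df.write.insertInto("employee_copy")
    print(writer.sql()[0])   # an INSERT INTO employee_copy SELECT ... statement,
                             # returned as strings like DataFrame.sql()
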
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 1ea86d1..8cb16ef 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -67,13 +67,20 @@ class SparkSession:
data_expressions = [
exp.Tuple(
- expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
+ expressions=list(
+ map(
+ lambda x: F.lit(x).expression,
+ row if not isinstance(row, dict) else row.values(),
+ )
+ )
)
for row in data
]
sel_columns = [
- F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
+ F.col(name).cast(data_type).alias(name).expression
+ if data_type is not None
+ else F.col(name).expression
for name, data_type in column_mapping.items()
]
@@ -106,10 +113,12 @@ class SparkSession:
select_expression.set("with", expression.args.get("with"))
expression.set("with", None)
del expression.args["expression"]
- df = DataFrame(self, select_expression, output_expression_container=expression)
+ df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore
df = df._convert_leaf_to_cte()
else:
- raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
+ raise ValueError(
+ "Unknown expression type provided in the SQL. Please create an issue with the SQL."
+ )
return df
@property
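
Note: SparkSession.sql round-trips SQL through the DataFrame layer; a SELECT becomes a DataFrame directly, while INSERT/CREATE keep the outer statement as the output container (the `# type: ignore` above covers that union). A sketch, assuming the employee table registered earlier:

    df = spark.sql("SELECT fname FROM employee")
    print(df.sql()[0])   # the regenerated spark SQL for the same query
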
diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py
index dc5c05a..a63e505 100644
--- a/sqlglot/dataframe/sql/types.py
+++ b/sqlglot/dataframe/sql/types.py
@@ -158,7 +158,11 @@ class MapType(DataType):
class StructField(DataType):
def __init__(
- self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
+ self,
+ name: str,
+ dataType: DataType,
+ nullable: bool = True,
+ metadata: t.Optional[t.Dict[str, t.Any]] = None,
):
self.name = name
self.dataType = dataType
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
index 842f366..c54c07e 100644
--- a/sqlglot/dataframe/sql/window.py
+++ b/sqlglot/dataframe/sql/window.py
@@ -74,8 +74,13 @@ class WindowSpec:
window_spec.expression.args["order"].set("expressions", order_by)
return window_spec
- def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
- kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
+ def _calc_start_end(
+ self, start: int, end: int
+ ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
+ kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
+ "start_side": None,
+ "end_side": None,
+ }
if start == Window.currentRow:
kwargs["start"] = "CURRENT ROW"
else:
@@ -83,7 +88,9 @@ class WindowSpec:
**kwargs,
**{
"start_side": "PRECEDING",
- "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
+ "start": "UNBOUNDED"
+ if start <= Window.unboundedPreceding
+ else F.lit(start).expression,
},
}
if end == Window.currentRow:
@@ -93,7 +100,9 @@ class WindowSpec:
**kwargs,
**{
"end_side": "FOLLOWING",
- "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
+ "end": "UNBOUNDED"
+ if end >= Window.unboundedFollowing
+ else F.lit(end).expression,
},
}
return kwargs
@@ -103,7 +112,10 @@ class WindowSpec:
spec = self._calc_start_end(start, end)
spec["kind"] = "ROWS"
window_spec.expression.set(
- "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+ "spec",
+ exp.WindowSpec(
+ **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+ ),
)
return window_spec
@@ -112,6 +124,9 @@ class WindowSpec:
spec = self._calc_start_end(start, end)
spec["kind"] = "RANGE"
window_spec.expression.set(
- "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+ "spec",
+ exp.WindowSpec(
+ **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+ ),
)
return window_spec
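
Note: rowsBetween and rangeBetween share _calc_start_end; the Window.unboundedPreceding / Window.unboundedFollowing sentinels map to UNBOUNDED, and every other bound becomes a literal. A sketch with hypothetical columns:

    from sqlglot.dataframe.sql.window import Window
    from sqlglot.dataframe.sql import functions as F

    w = (
        Window.partitionBy("store_id")
        .orderBy("sale_date")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    running_total = F.sum("price").over(w)
    # OVER (PARTITION BY store_id ORDER BY sale_date
    #       ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
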