Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r--  sqlglot/dataframe/README.md          |  224
-rw-r--r--  sqlglot/dataframe/__init__.py        |    0
-rw-r--r--  sqlglot/dataframe/sql/__init__.py    |   18
-rw-r--r--  sqlglot/dataframe/sql/_typing.pyi    |   20
-rw-r--r--  sqlglot/dataframe/sql/column.py      |  295
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py   |  730
-rw-r--r--  sqlglot/dataframe/sql/functions.py   | 1258
-rw-r--r--  sqlglot/dataframe/sql/group.py       |   57
-rw-r--r--  sqlglot/dataframe/sql/normalize.py   |   72
-rw-r--r--  sqlglot/dataframe/sql/operations.py  |   53
-rw-r--r--  sqlglot/dataframe/sql/readwriter.py  |   79
-rw-r--r--  sqlglot/dataframe/sql/session.py     |  148
-rw-r--r--  sqlglot/dataframe/sql/transforms.py  |    9
-rw-r--r--  sqlglot/dataframe/sql/types.py       |  208
-rw-r--r--  sqlglot/dataframe/sql/util.py        |   32
-rw-r--r--  sqlglot/dataframe/sql/window.py      |  117
16 files changed, 3320 insertions(+), 0 deletions(-)
diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md
new file mode 100644
index 0000000..54d3856
--- /dev/null
+++ b/sqlglot/dataframe/README.md
@@ -0,0 +1,224 @@
+# PySpark DataFrame SQL Generator
+
+This is a drop-in replacement for the PySpark DataFrame API that generates SQL instead of executing DataFrame operations directly. Combined with the transpiling support in SQLGlot, this allows you to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/).
+
+Many of the common operations are currently covered, and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what gets prioritized next and to make sure your use case is properly supported.
+
+# How to use
+
+## Instructions
+* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install); that is all that is required to generate SQL. [The examples](#examples) also execute the generated SQL on a specific engine, which requires that engine's client library
+* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`
+* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`
+  * The column structure can be defined in the following ways (see the sketch after this list):
+    * A dictionary where the keys are column names and the values are strings of the Spark SQL type names
+      * Ex: `{'cola': 'string', 'colb': 'int'}`
+    * A PySpark DataFrame `StructType`, similar to when using `createDataFrame`
+      * Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])`
+    * A string of names and types, similar to what is supported in `createDataFrame`
+      * Ex: `cola: STRING, colb: INT`
+    * [Not Recommended] A list of string column names without types
+      * Ex: `['cola', 'colb']`
+      * The lack of types may limit functionality in future releases
+  * See [Registering Custom Schema](#registering-custom-schema-class) for how to skip this step if the column information is stored externally
+* Add `.sql(pretty=True)` to your final DataFrame command to return a list of the SQL statements needed to run that command
+  * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames, since caching isn't supported in other dialects
+  * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of supported dialects
+  * Ex: `.sql(pretty=True, dialect='bigquery')`
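+
+As a reference, here is a minimal sketch (using a hypothetical `employee` table) of the equivalent schema registration forms described above. Each call registers the same structure; in practice you would pick just one form:
+
+```python
+import sqlglot
+from sqlglot.dataframe.sql import types
+
+# Dictionary of column name -> Spark SQL type name
+sqlglot.schema.add_table('employee', {'cola': 'string', 'colb': 'int'})
+# PySpark-style StructType
+sqlglot.schema.add_table('employee', types.StructType([
+    types.StructField('cola', types.StringType()),
+    types.StructField('colb', types.IntegerType()),
+]))
+# String of names and types
+sqlglot.schema.add_table('employee', 'cola: STRING, colb: INT')
+# List of column names only (not recommended)
+sqlglot.schema.add_table('employee', ['cola', 'colb'])
+```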
+
+## Examples
+
+```python
+import sqlglot
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import functions as F
+
+sqlglot.schema.add_table('employee', {
+ 'employee_id': 'INT',
+ 'fname': 'STRING',
+ 'lname': 'STRING',
+ 'age': 'INT',
+}) # Register the table structure prior to reading from the table
+
+spark = SparkSession()
+
+df = (
+ spark
+ .table('employee')
+ .groupBy(F.col("age"))
+ .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
+)
+
+print(df.sql(pretty=True)) # Spark will be the dialect used by default
+```
+Output:
+```sparksql
+SELECT
+ `employee`.`age` AS `age`,
+ COUNT(DISTINCT `employee`.`employee_id`) AS `num_employees`
+FROM `employee` AS `employee`
+GROUP BY
+ `employee`.`age`
+```
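+
+Since the output dialect is configurable, the same DataFrame can target a different engine without code changes. For example, this call (illustrative) would emit BigQuery syntax instead:
+
+```python
+print(df.sql(pretty=True, dialect='bigquery'))  # same DataFrame, BigQuery SQL
+```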
+
+## Registering Custom Schema Class
+
+The step of adding `sqlglot.schema.add_table` can be skipped if you have the column structure stored externally like in a file or from an external metadata table. This can be done by writing a class that implements the `sqlglot.schema.Schema` abstract class and then assigning that class to `sqlglot.schema`.
+
+```python
+import sqlglot
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.schema import Schema
+
+
+class ExternalSchema(Schema):
+ ...
+
+sqlglot.schema = ExternalSchema()
+
+spark = SparkSession()
+
+df = (
+ spark
+ .table('employee')
+ .groupBy(F.col("age"))
+ .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
+)
+
+print(df.sql(pretty=True))
+```
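+
+The exact abstract methods to implement depend on your installed SQLGlot version (see `sqlglot/schema.py`). A minimal, hypothetical sketch follows; the method names and signatures here are assumptions used for illustration, not a definitive interface:
+
+```python
+from sqlglot.schema import Schema
+
+
+class ExternalSchema(Schema):
+    """Hypothetical schema backed by an external metadata store."""
+
+    def column_names(self, table, only_visible=False):  # assumed signature
+        return list(self._fetch_columns(table))
+
+    def get_column_type(self, table, column):  # assumed signature
+        return self._fetch_columns(table)[column]
+
+    def _fetch_columns(self, table):
+        # Placeholder: return a {column_name: type_name} mapping from your metadata service.
+        raise NotImplementedError
+```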
+
+## Example Implementations
+
+### BigQuery
+```python
+from google.cloud import bigquery
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import types
+from sqlglot.dataframe.sql import functions as F
+
+client = bigquery.Client()
+
+data = [
+ (1, "Jack", "Shephard", 34),
+ (2, "John", "Locke", 48),
+ (3, "Kate", "Austen", 34),
+ (4, "Claire", "Littleton", 22),
+ (5, "Hugo", "Reyes", 26),
+]
+schema = types.StructType([
+ types.StructField('employee_id', types.IntegerType(), False),
+ types.StructField('fname', types.StringType(), False),
+ types.StructField('lname', types.StringType(), False),
+ types.StructField('age', types.IntegerType(), False),
+])
+
+sql_statements = (
+ SparkSession()
+ .createDataFrame(data, schema)
+ .groupBy(F.col("age"))
+ .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
+ .sql(dialect="bigquery")
+)
+
+result = None
+for sql in sql_statements:
+ result = client.query(sql)
+
+assert result is not None
+for row in result:
+ print(f"Age: {row['age']}, Num Employees: {row['num_employees']}")
+```
+
+### Snowflake
+```python
+import os
+
+import snowflake.connector
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import types
+from sqlglot.dataframe.sql import functions as F
+
+ctx = snowflake.connector.connect(
+ user=os.environ["SNOWFLAKE_USER"],
+ password=os.environ["SNOWFLAKE_PASS"],
+ account=os.environ["SNOWFLAKE_ACCOUNT"]
+)
+cs = ctx.cursor()
+
+data = [
+ (1, "Jack", "Shephard", 34),
+ (2, "John", "Locke", 48),
+ (3, "Kate", "Austen", 34),
+ (4, "Claire", "Littleton", 22),
+ (5, "Hugo", "Reyes", 26),
+]
+schema = types.StructType([
+ types.StructField('employee_id', types.IntegerType(), False),
+ types.StructField('fname', types.StringType(), False),
+ types.StructField('lname', types.StringType(), False),
+ types.StructField('age', types.IntegerType(), False),
+])
+
+sql_statements = (
+ SparkSession()
+ .createDataFrame(data, schema)
+ .groupBy(F.col("age"))
+    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
+ .sql(dialect="snowflake")
+)
+
+try:
+ for sql in sql_statements:
+ cs.execute(sql)
+ results = cs.fetchall()
+ for row in results:
+ print(f"Age: {row[0]}, Num Employees: {row[1]}")
+finally:
+    cs.close()
+    ctx.close()
+```
+
+### Spark
+```python
+from pyspark.sql.session import SparkSession as PySparkSession
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import types
+from sqlglot.dataframe.sql import functions as F
+
+data = [
+ (1, "Jack", "Shephard", 34),
+ (2, "John", "Locke", 48),
+ (3, "Kate", "Austen", 34),
+ (4, "Claire", "Littleton", 22),
+ (5, "Hugo", "Reyes", 26),
+]
+schema = types.StructType([
+ types.StructField('employee_id', types.IntegerType(), False),
+ types.StructField('fname', types.StringType(), False),
+ types.StructField('lname', types.StringType(), False),
+ types.StructField('age', types.IntegerType(), False),
+])
+
+sql_statements = (
+ SparkSession()
+ .createDataFrame(data, schema)
+ .groupBy(F.col("age"))
+ .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
+    .sql(dialect="spark")
+)
+
+pyspark = PySparkSession.builder.master("local[*]").getOrCreate()
+
+df = None
+for sql in sql_statements:
+ df = pyspark.sql(sql)
+
+assert df is not None
+df.show()
+```
+
+# Unsupportable Operations
+
+Any operation that cannot be represented in SQL cannot be supported by this tool. RDD operations are one example. However, since the DataFrame API is largely modeled around SQL concepts, most operations can be supported.
diff --git a/sqlglot/dataframe/__init__.py b/sqlglot/dataframe/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/sqlglot/dataframe/__init__.py
diff --git a/sqlglot/dataframe/sql/__init__.py b/sqlglot/dataframe/sql/__init__.py
new file mode 100644
index 0000000..3f90802
--- /dev/null
+++ b/sqlglot/dataframe/sql/__init__.py
@@ -0,0 +1,18 @@
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
+from sqlglot.dataframe.sql.group import GroupedData
+from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql.window import Window, WindowSpec
+
+__all__ = [
+ "SparkSession",
+ "DataFrame",
+ "GroupedData",
+ "Column",
+ "DataFrameNaFunctions",
+ "Window",
+ "WindowSpec",
+ "DataFrameReader",
+ "DataFrameWriter",
+]
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
new file mode 100644
index 0000000..f1a03ea
--- /dev/null
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+import datetime
+import typing as t
+
+from sqlglot import expressions as exp
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.column import Column
+ from sqlglot.dataframe.sql.types import StructType
+
+ColumnLiterals = t.TypeVar(
+ "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+)
+ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
+ColumnOrLiteral = t.TypeVar(
+ "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+)
+SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
+OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
new file mode 100644
index 0000000..2391080
--- /dev/null
+++ b/sqlglot/dataframe/sql/column.py
@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql.types import DataType
+from sqlglot.helper import flatten
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnOrLiteral
+ from sqlglot.dataframe.sql.window import WindowSpec
+
+
+class Column:
+ def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+ if isinstance(expression, Column):
+ expression = expression.expression # type: ignore
+ elif expression is None or not isinstance(expression, (str, exp.Expression)):
+ expression = self._lit(expression).expression # type: ignore
+ self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+ def __repr__(self):
+ return repr(self.expression)
+
+ def __hash__(self):
+ return hash(self.expression)
+
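+    # The Python operator overloads below build the corresponding sqlglot expression
+    # nodes, e.g. (col + 1) becomes exp.Add and (col == 1) becomes exp.EQ.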
+ def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore
+ return self.binary_op(exp.EQ, other)
+
+ def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore
+ return self.binary_op(exp.NEQ, other)
+
+ def __gt__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.GT, other)
+
+ def __ge__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.GTE, other)
+
+ def __lt__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.LT, other)
+
+ def __le__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.LTE, other)
+
+ def __and__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.And, other)
+
+ def __or__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Or, other)
+
+ def __mod__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Mod, other)
+
+ def __add__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Add, other)
+
+ def __sub__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Sub, other)
+
+ def __mul__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Mul, other)
+
+ def __truediv__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Div, other)
+
+ def __div__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Div, other)
+
+ def __neg__(self) -> Column:
+ return self.unary_op(exp.Neg)
+
+ def __radd__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Add, other)
+
+ def __rsub__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Sub, other)
+
+ def __rmul__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Mul, other)
+
+ def __rdiv__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Div, other)
+
+ def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Div, other)
+
+ def __rmod__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Mod, other)
+
+ def __pow__(self, power: ColumnOrLiteral, modulo=None):
+ return Column(exp.Pow(this=self.expression, power=Column(power).expression))
+
+ def __rpow__(self, power: ColumnOrLiteral):
+ return Column(exp.Pow(this=Column(power).expression, power=self.expression))
+
+ def __invert__(self):
+ return self.unary_op(exp.Not)
+
+ def __rand__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.And, other)
+
+ def __ror__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Or, other)
+
+ @classmethod
+ def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+ return cls(value)
+
+ @classmethod
+ def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
+ return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
+
+ @classmethod
+ def _lit(cls, value: ColumnOrLiteral) -> Column:
+ if isinstance(value, dict):
+ columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
+ return cls(exp.Struct(expressions=columns))
+ return cls(exp.convert(value))
+
+ @classmethod
+ def invoke_anonymous_function(
+ cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
+ ) -> Column:
+ columns = [] if column is None else [cls.ensure_col(column)]
+ column_args = [cls.ensure_col(arg) for arg in args]
+ expressions = [x.expression for x in columns + column_args]
+ new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
+ return Column(new_expression)
+
+ @classmethod
+ def invoke_expression_over_column(
+ cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
+ ) -> Column:
+ ensured_column = None if column is None else cls.ensure_col(column)
+ new_expression = (
+ callable_expression(**kwargs)
+ if ensured_column is None
+ else callable_expression(this=ensured_column.column_expression, **kwargs)
+ )
+ return Column(new_expression)
+
+ def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
+ return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+
+ def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
+ return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+
+ def unary_op(self, klass: t.Callable, **kwargs) -> Column:
+ return Column(klass(this=self.column_expression, **kwargs))
+
+ @property
+ def is_alias(self):
+ return isinstance(self.expression, exp.Alias)
+
+ @property
+ def is_column(self):
+ return isinstance(self.expression, exp.Column)
+
+ @property
+ def column_expression(self) -> exp.Column:
+ return self.expression.unalias()
+
+ @property
+ def alias_or_name(self) -> str:
+ return self.expression.alias_or_name
+
+ @classmethod
+ def ensure_literal(cls, value) -> Column:
+ from sqlglot.dataframe.sql.functions import lit
+
+ if isinstance(value, cls):
+ value = value.expression
+ if not isinstance(value, exp.Literal):
+ return lit(value)
+ return Column(value)
+
+ def copy(self) -> Column:
+ return Column(self.expression.copy())
+
+ def set_table_name(self, table_name: str, copy=False) -> Column:
+ expression = self.expression.copy() if copy else self.expression
+ expression.set("table", exp.to_identifier(table_name))
+ return Column(expression)
+
+    def sql(self, **kwargs) -> str:
+ return self.expression.sql(**{"dialect": "spark", **kwargs})
+
+ def alias(self, name: str) -> Column:
+ new_expression = exp.alias_(self.column_expression, name)
+ return Column(new_expression)
+
+ def asc(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
+ return Column(new_expression)
+
+ def desc(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
+ return Column(new_expression)
+
+ asc_nulls_first = asc
+
+ def asc_nulls_last(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
+ return Column(new_expression)
+
+ def desc_nulls_first(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
+ return Column(new_expression)
+
+ desc_nulls_last = desc
+
+ def when(self, condition: Column, value: t.Any) -> Column:
+ from sqlglot.dataframe.sql.functions import when
+
+ column_with_if = when(condition, value)
+ if not isinstance(self.expression, exp.Case):
+ return column_with_if
+ new_column = self.copy()
+ new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
+ return new_column
+
+ def otherwise(self, value: t.Any) -> Column:
+ from sqlglot.dataframe.sql.functions import lit
+
+ true_value = value if isinstance(value, Column) else lit(value)
+ new_column = self.copy()
+ new_column.expression.set("default", true_value.column_expression)
+ return new_column
+
+ def isNull(self) -> Column:
+ new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
+ return Column(new_expression)
+
+ def isNotNull(self) -> Column:
+ new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
+ return Column(new_expression)
+
+ def cast(self, dataType: t.Union[str, DataType]):
+ """
+        Functionality Difference: PySpark's cast accepts an instance of a DataType class.
+        SQLGlot's Cast expression doesn't replicate those classes, so a DataType instance is converted to its string representation here.
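+
+        Example (illustrative): df.age.cast("string") or df.age.cast(types.StringType())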
+ """
+ if isinstance(dataType, DataType):
+ dataType = dataType.simpleString()
+ new_expression = exp.Cast(this=self.column_expression, to=dataType)
+ return Column(new_expression)
+
+ def startswith(self, value: t.Union[str, Column]) -> Column:
+ value = self._lit(value) if not isinstance(value, Column) else value
+ return self.invoke_anonymous_function(self, "STARTSWITH", value)
+
+ def endswith(self, value: t.Union[str, Column]) -> Column:
+ value = self._lit(value) if not isinstance(value, Column) else value
+ return self.invoke_anonymous_function(self, "ENDSWITH", value)
+
+ def rlike(self, regexp: str) -> Column:
+ return self.invoke_expression_over_column(
+ column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
+ )
+
+ def like(self, other: str):
+ return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+
+ def ilike(self, other: str):
+ return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+
+ def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
+ startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
+ length = self._lit(length) if not isinstance(length, Column) else length
+ return Column.invoke_expression_over_column(
+ self, exp.Substring, start=startPos.expression, length=length.expression
+ )
+
+ def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
+ columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
+ expressions = [self._lit(x).expression for x in columns]
+ return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore
+
+ def between(
+ self,
+ lowerBound: t.Union[ColumnOrLiteral],
+ upperBound: t.Union[ColumnOrLiteral],
+ ) -> Column:
+ lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+ upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+ return Column(
+ exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+ )
+
+ def over(self, window: WindowSpec) -> Column:
+ window_expression = window.expression.copy()
+ window_expression.set("this", self.column_expression)
+ return Column(window_expression)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
new file mode 100644
index 0000000..322dcf2
--- /dev/null
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -0,0 +1,730 @@
+from __future__ import annotations
+
+import functools
+import typing as t
+import zlib
+from copy import copy
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.group import GroupedData
+from sqlglot.dataframe.sql.normalize import normalize
+from sqlglot.dataframe.sql.operations import Operation, operation
+from sqlglot.dataframe.sql.readwriter import DataFrameWriter
+from sqlglot.dataframe.sql.transforms import replace_id_value
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.dataframe.sql.window import Window
+from sqlglot.helper import ensure_list, object_to_dict
+from sqlglot.optimizer import optimize as optimize_func
+from sqlglot.optimizer.qualify_columns import qualify_columns
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
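+# Spark join hint names. In `_hint`, names in this set become exp.JoinHint nodes;
+# any other hint (e.g. repartition) becomes a generic exp.Anonymous expression.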
+JOIN_HINTS = {
+ "BROADCAST",
+ "BROADCASTJOIN",
+ "MAPJOIN",
+ "MERGE",
+ "SHUFFLEMERGE",
+ "MERGEJOIN",
+ "SHUFFLE_HASH",
+ "SHUFFLE_REPLICATE_NL",
+}
+
+
+class DataFrame:
+ def __init__(
+ self,
+ spark: SparkSession,
+ expression: exp.Select,
+ branch_id: t.Optional[str] = None,
+ sequence_id: t.Optional[str] = None,
+ last_op: Operation = Operation.INIT,
+ pending_hints: t.Optional[t.List[exp.Expression]] = None,
+ output_expression_container: t.Optional[OutputExpressionContainer] = None,
+ **kwargs,
+ ):
+ self.spark = spark
+ self.expression = expression
+ self.branch_id = branch_id or self.spark._random_branch_id
+ self.sequence_id = sequence_id or self.spark._random_sequence_id
+ self.last_op = last_op
+ self.pending_hints = pending_hints or []
+ self.output_expression_container = output_expression_container or exp.Select()
+
+ def __getattr__(self, column_name: str) -> Column:
+ return self[column_name]
+
+ def __getitem__(self, column_name: str) -> Column:
+ column_name = f"{self.branch_id}.{column_name}"
+ return Column(column_name)
+
+ def __copy__(self):
+ return self.copy()
+
+ @property
+ def sparkSession(self):
+ return self.spark
+
+ @property
+ def write(self):
+ return DataFrameWriter(self)
+
+ @property
+ def latest_cte_name(self) -> str:
+ if not self.expression.ctes:
+ from_exp = self.expression.args["from"]
+ if from_exp.alias_or_name:
+ return from_exp.alias_or_name
+ table_alias = from_exp.find(exp.TableAlias)
+ if not table_alias:
+ raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+ return table_alias.alias_or_name
+ return self.expression.ctes[-1].alias
+
+ @property
+ def pending_join_hints(self):
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
+
+ @property
+ def pending_partition_hints(self):
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
+
+ @property
+ def columns(self) -> t.List[str]:
+ return self.expression.named_selects
+
+ @property
+ def na(self) -> DataFrameNaFunctions:
+ return DataFrameNaFunctions(self)
+
+ def _replace_cte_names_with_hashes(self, expression: exp.Select):
+ expression = expression.copy()
+ ctes = expression.ctes
+ replacement_mapping = {}
+ for cte in ctes:
+ old_name_id = cte.args["alias"].this
+ new_hashed_id = exp.to_identifier(
+ self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
+ )
+ replacement_mapping[old_name_id] = new_hashed_id
+ cte.set("alias", exp.TableAlias(this=new_hashed_id))
+ expression = expression.transform(replace_id_value, replacement_mapping)
+ return expression
+
+ def _create_cte_from_expression(
+ self,
+ expression: exp.Expression,
+ branch_id: t.Optional[str] = None,
+ sequence_id: t.Optional[str] = None,
+ **kwargs,
+ ) -> t.Tuple[exp.CTE, str]:
+ name = self.spark._random_name
+ expression_to_cte = expression.copy()
+ expression_to_cte.set("with", None)
+ cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
+ cte.set("branch_id", branch_id or self.branch_id)
+ cte.set("sequence_id", sequence_id or self.sequence_id)
+ return cte, name
+
+ def _ensure_list_of_columns(
+ self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
+ ) -> t.List[Column]:
+ columns = ensure_list(cols)
+ columns = Column.ensure_cols(columns)
+ return columns
+
+ def _ensure_and_normalize_cols(self, cols):
+ cols = self._ensure_list_of_columns(cols)
+ normalize(self.spark, self.expression, cols)
+ return cols
+
+ def _ensure_and_normalize_col(self, col):
+ col = Column.ensure_col(col)
+ normalize(self.spark, self.expression, col)
+ return col
+
+ def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
+ df = self._resolve_pending_hints()
+ sequence_id = sequence_id or df.sequence_id
+ expression = df.expression.copy()
+ cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
+ new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+ sel_columns = df._get_outer_select_columns(cte_expression)
+ new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+ return df.copy(expression=new_expression, sequence_id=sequence_id)
+
+ def _resolve_pending_hints(self) -> DataFrame:
+ df = self.copy()
+ if not self.pending_hints:
+ return df
+ expression = df.expression
+ hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
+ for hint in df.pending_partition_hints:
+ hint_expression.args.get("expressions").append(hint)
+ df.pending_hints.remove(hint)
+
+ join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+ if join_aliases:
+ for hint in df.pending_join_hints:
+ for sequence_id_expression in hint.expressions:
+ sequence_id_or_name = sequence_id_expression.alias_or_name
+ sequence_ids_to_match = [sequence_id_or_name]
+ if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
+ sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+ matching_ctes = [
+ cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+ ]
+ for matching_cte in matching_ctes:
+ if matching_cte.alias_or_name in join_aliases:
+ sequence_id_expression.set("this", matching_cte.args["alias"].this)
+ df.pending_hints.remove(hint)
+ break
+ hint_expression.args.get("expressions").append(hint)
+ if hint_expression.expressions:
+ expression.set("hint", hint_expression)
+ return df
+
+ def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
+ hint_name = hint_name.upper()
+ hint_expression = (
+ exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+ if hint_name in JOIN_HINTS
+ else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+ )
+ new_df = self.copy()
+ new_df.pending_hints.append(hint_expression)
+ return new_df
+
+ def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
+ other_df = other._convert_leaf_to_cte()
+ base_expression = self.expression.copy()
+ base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
+ all_ctes = base_expression.ctes
+ other_df.expression.set("with", None)
+ base_expression.set("with", None)
+ operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
+ operation.set("with", exp.With(expressions=all_ctes))
+ return self.copy(expression=operation)._convert_leaf_to_cte()
+
+ def _cache(self, storage_level: str):
+ df = self._convert_leaf_to_cte()
+ df.expression.ctes[-1].set("cache_storage_level", storage_level)
+ return df
+
+ @classmethod
+ def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
+ expression = expression.copy()
+ with_expression = expression.args.get("with")
+ if with_expression:
+ existing_ctes = with_expression.expressions
+            existing_cte_names = {x.alias_or_name for x in existing_ctes}
+            for cte in ctes:
+                if cte.alias_or_name not in existing_cte_names:
+ existing_ctes.append(cte)
+ else:
+ existing_ctes = ctes
+ expression.set("with", exp.With(expressions=existing_ctes))
+ return expression
+
+ @classmethod
+ def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
+ expression = item.expression if isinstance(item, DataFrame) else item
+ return [Column(x) for x in expression.find(exp.Select).expressions]
+
+ @classmethod
+ def _create_hash_from_expression(cls, expression: exp.Select):
+ value = expression.sql(dialect="spark").encode("utf-8")
+ return f"t{zlib.crc32(value)}"[:6]
+
+ def _get_select_expressions(
+ self,
+ ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
+ select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+ main_select_ctes: t.List[exp.CTE] = []
+ for cte in self.expression.ctes:
+ cache_storage_level = cte.args.get("cache_storage_level")
+ if cache_storage_level:
+ select_expression = cte.this.copy()
+ select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
+ select_expression.set("cte_alias_name", cte.alias_or_name)
+ select_expression.set("cache_storage_level", cache_storage_level)
+ select_expressions.append((exp.Cache, select_expression))
+ else:
+ main_select_ctes.append(cte)
+ main_select = self.expression.copy()
+ if main_select_ctes:
+ main_select.set("with", exp.With(expressions=main_select_ctes))
+ expression_select_pair = (type(self.output_expression_container), main_select)
+ select_expressions.append(expression_select_pair) # type: ignore
+ return select_expressions
+
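+    # Generates the SQL statement(s) for this DataFrame: resolves pending hints,
+    # optionally runs the optimizer, replaces CTE names with content hashes, and
+    # expands any cached CTEs into DROP VIEW + CACHE TABLE statements.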
+ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
+ df = self._resolve_pending_hints()
+ select_expressions = df._get_select_expressions()
+ output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
+ replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
+ for expression_type, select_expression in select_expressions:
+ select_expression = select_expression.transform(replace_id_value, replacement_mapping)
+ if optimize:
+ select_expression = optimize_func(select_expression)
+ select_expression = df._replace_cte_names_with_hashes(select_expression)
+ expression: t.Union[exp.Select, exp.Cache, exp.Drop]
+ if expression_type == exp.Cache:
+ cache_table_name = df._create_hash_from_expression(select_expression)
+ cache_table = exp.to_table(cache_table_name)
+ original_alias_name = select_expression.args["cte_alias_name"]
+ replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+ sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
+ cache_storage_level = select_expression.args["cache_storage_level"]
+ options = [
+ exp.Literal.string("storageLevel"),
+ exp.Literal.string(cache_storage_level),
+ ]
+ expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+ # We will drop the "view" if it exists before running the cache table
+ output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
+ elif expression_type == exp.Create:
+ expression = df.output_expression_container.copy()
+ expression.set("expression", select_expression)
+ elif expression_type == exp.Insert:
+ expression = df.output_expression_container.copy()
+ select_without_ctes = select_expression.copy()
+ select_without_ctes.set("with", None)
+ expression.set("expression", select_without_ctes)
+ if select_expression.ctes:
+ expression.set("with", exp.With(expressions=select_expression.ctes))
+ elif expression_type == exp.Select:
+ expression = select_expression
+ else:
+ raise ValueError(f"Invalid expression type: {expression_type}")
+ output_expressions.append(expression)
+
+ return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+
+ def copy(self, **kwargs) -> DataFrame:
+ return DataFrame(**object_to_dict(self, **kwargs))
+
+ @operation(Operation.SELECT)
+ def select(self, *cols, **kwargs) -> DataFrame:
+ cols = self._ensure_and_normalize_cols(cols)
+ kwargs["append"] = kwargs.get("append", False)
+ if self.expression.args.get("joins"):
+ ambiguous_cols = [col for col in cols if not col.column_expression.table]
+ if ambiguous_cols:
+ join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+ cte_names_in_join = [x.this for x in join_table_identifiers]
+ for ambiguous_col in ambiguous_cols:
+ ctes_with_column = [
+ cte
+ for cte in self.expression.ctes
+ if cte.alias_or_name in cte_names_in_join
+ and ambiguous_col.alias_or_name in cte.this.named_selects
+ ]
+ # If the select column does not specify a table and there is a join
+ # then we assume they are referring to the left table
+ if len(ctes_with_column) > 1:
+ table_identifier = self.expression.args["from"].args["expressions"][0].this
+ else:
+ table_identifier = ctes_with_column[0].args["alias"].this
+ ambiguous_col.expression.set("table", table_identifier)
+ expression = self.expression.select(*[x.expression for x in cols], **kwargs)
+ qualify_columns(expression, sqlglot.schema)
+ return self.copy(expression=expression, **kwargs)
+
+ @operation(Operation.NO_OP)
+ def alias(self, name: str, **kwargs) -> DataFrame:
+ new_sequence_id = self.spark._random_sequence_id
+ df = self.copy()
+ for join_hint in df.pending_join_hints:
+ for expression in join_hint.expressions:
+ if expression.alias_or_name == self.sequence_id:
+ expression.set("this", Column.ensure_col(new_sequence_id).expression)
+ df.spark._add_alias_to_mapping(name, new_sequence_id)
+ return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
+
+ @operation(Operation.WHERE)
+ def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
+ col = self._ensure_and_normalize_col(column)
+ return self.copy(expression=self.expression.where(col.expression))
+
+ filter = where
+
+ @operation(Operation.GROUP_BY)
+ def groupBy(self, *cols, **kwargs) -> GroupedData:
+ columns = self._ensure_and_normalize_cols(cols)
+ return GroupedData(self, columns, self.last_op)
+
+ @operation(Operation.SELECT)
+ def agg(self, *exprs, **kwargs) -> DataFrame:
+ cols = self._ensure_and_normalize_cols(exprs)
+ return self.groupBy().agg(*cols)
+
+ @operation(Operation.FROM)
+ def join(
+ self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+ ) -> DataFrame:
+ other_df = other_df._convert_leaf_to_cte()
+ pre_join_self_latest_cte_name = self.latest_cte_name
+ columns = self._ensure_and_normalize_cols(on)
+ join_type = how.replace("_", " ")
+ if isinstance(columns[0].expression, exp.Column):
+ join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+ join_clause = functools.reduce(
+ lambda x, y: x & y,
+ [
+ col.copy().set_table_name(pre_join_self_latest_cte_name)
+ == col.copy().set_table_name(other_df.latest_cte_name)
+ for col in columns
+ ],
+ )
+ else:
+ if len(columns) > 1:
+ columns = [functools.reduce(lambda x, y: x & y, columns)]
+ join_clause = columns[0]
+ join_columns = [
+ Column(x).set_table_name(pre_join_self_latest_cte_name)
+ if i % 2 == 0
+ else Column(x).set_table_name(other_df.latest_cte_name)
+ for i, x in enumerate(join_clause.expression.find_all(exp.Column))
+ ]
+ self_columns = [
+ column.set_table_name(pre_join_self_latest_cte_name, copy=True)
+ for column in self._get_outer_select_columns(self)
+ ]
+ other_columns = [
+ column.set_table_name(other_df.latest_cte_name, copy=True)
+ for column in self._get_outer_select_columns(other_df)
+ ]
+ column_value_mapping = {
+ column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+ for column in other_columns + self_columns + join_columns
+ }
+ all_columns = [
+ column_value_mapping[name]
+ for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
+ ]
+ new_df = self.copy(
+ expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+ )
+ new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
+ new_df.pending_hints.extend(other_df.pending_hints)
+ new_df = new_df.select.__wrapped__(new_df, *all_columns)
+ return new_df
+
+ @operation(Operation.ORDER_BY)
+ def orderBy(
+ self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+ ) -> DataFrame:
+ """
+        This implementation gives any pre-ordered columns priority over whatever is provided in `ascending`. Spark's
+        behavior here is irregular and can result in runtime errors. Users shouldn't mix the two anyway, so this is
+        unlikely to come up.
+ """
+ columns = self._ensure_and_normalize_cols(cols)
+ pre_ordered_col_indexes = [
+ x
+ for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+ if x is not None
+ ]
+ if ascending is None:
+ ascending = [True] * len(columns)
+ elif not isinstance(ascending, list):
+ ascending = [ascending] * len(columns)
+        ascending = [bool(x) for x in ascending]
+ assert len(columns) == len(
+ ascending
+ ), "The length of items in ascending must equal the number of columns provided"
+ col_and_ascending = list(zip(columns, ascending))
+ order_by_columns = [
+ exp.Ordered(this=col.expression, desc=not asc)
+ if i not in pre_ordered_col_indexes
+ else columns[i].column_expression
+ for i, (col, asc) in enumerate(col_and_ascending)
+ ]
+ return self.copy(expression=self.expression.order_by(*order_by_columns))
+
+ sort = orderBy
+
+ @operation(Operation.FROM)
+ def union(self, other: DataFrame) -> DataFrame:
+ return self._set_operation(exp.Union, other, False)
+
+ unionAll = union
+
+ @operation(Operation.FROM)
+ def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
+ l_columns = self.columns
+ r_columns = other.columns
+ if not allowMissingColumns:
+ l_expressions = l_columns
+ r_expressions = l_columns
+ else:
+ l_expressions = []
+ r_expressions = []
+ r_columns_unused = copy(r_columns)
+ for l_column in l_columns:
+ l_expressions.append(l_column)
+ if l_column in r_columns:
+ r_expressions.append(l_column)
+ r_columns_unused.remove(l_column)
+ else:
+ r_expressions.append(exp.alias_(exp.Null(), l_column))
+ for r_column in r_columns_unused:
+ l_expressions.append(exp.alias_(exp.Null(), r_column))
+ r_expressions.append(r_column)
+ r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+ l_df = self.copy()
+ if allowMissingColumns:
+ l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
+ return l_df._set_operation(exp.Union, r_df, False)
+
+ @operation(Operation.FROM)
+ def intersect(self, other: DataFrame) -> DataFrame:
+ return self._set_operation(exp.Intersect, other, True)
+
+ @operation(Operation.FROM)
+ def intersectAll(self, other: DataFrame) -> DataFrame:
+ return self._set_operation(exp.Intersect, other, False)
+
+ @operation(Operation.FROM)
+ def exceptAll(self, other: DataFrame) -> DataFrame:
+ return self._set_operation(exp.Except, other, False)
+
+ @operation(Operation.SELECT)
+ def distinct(self) -> DataFrame:
+ return self.copy(expression=self.expression.distinct())
+
+ @operation(Operation.SELECT)
+ def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
+ if not subset:
+ return self.distinct()
+ column_names = ensure_list(subset)
+ window = Window.partitionBy(*column_names).orderBy(*column_names)
+ return (
+ self.copy()
+ .withColumn("row_num", F.row_number().over(window))
+ .where(F.col("row_num") == F.lit(1))
+ .drop("row_num")
+ )
+
+ @operation(Operation.FROM)
+ def dropna(
+ self,
+ how: str = "any",
+ thresh: t.Optional[int] = None,
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+ ) -> DataFrame:
+ minimum_non_null = thresh or 0 # will be determined later if thresh is null
+ new_df = self.copy()
+ all_columns = self._get_outer_select_columns(new_df.expression)
+ if subset:
+ null_check_columns = self._ensure_and_normalize_cols(subset)
+ else:
+ null_check_columns = all_columns
+ if thresh is None:
+ minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
+ else:
+ minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
+ if minimum_num_nulls > len(null_check_columns):
+ raise RuntimeError(
+ f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
+ f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
+ )
+ if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+ nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
+ num_nulls = nulls_added_together.alias("num_nulls")
+ new_df = new_df.select(num_nulls, append=True)
+ filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
+ final_df = filtered_df.select(*all_columns)
+ return final_df
+
+ @operation(Operation.FROM)
+ def fillna(
+ self,
+ value: t.Union[ColumnLiterals],
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+ ) -> DataFrame:
+ """
+        Functionality Difference: If you provide a replacement value for a null and its type conflicts with the
+        column's type, PySpark will simply ignore the replacement.
+        This implementation will try to cast the two to a common type in some cases, so results won't always match.
+        It's best not to mix types: make sure the replacement value has the same type as the column.
+
+        Possibility for improvement: use the `typeof` function to get the type of the column and check whether it
+        matches the type of the value provided. If not, substitute null.
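+
+        Example (illustrative): df.fillna(0, subset=['age']) replaces nulls in `age` with 0.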
+ """
+ from sqlglot.dataframe.sql.functions import lit
+
+ values = None
+ columns = None
+ new_df = self.copy()
+ all_columns = self._get_outer_select_columns(new_df.expression)
+ all_column_mapping = {column.alias_or_name: column for column in all_columns}
+ if isinstance(value, dict):
+ values = value.values()
+ columns = self._ensure_and_normalize_cols(list(value))
+ if not columns:
+ columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
+ if not values:
+ values = [value] * len(columns)
+ value_columns = [lit(value) for value in values]
+
+ null_replacement_mapping = {
+ column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+ for column, value in zip(columns, value_columns)
+ }
+ null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
+ null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+ new_df = new_df.select(*null_replacement_columns)
+ return new_df
+
+ @operation(Operation.FROM)
+ def replace(
+ self,
+ to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
+ value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
+ subset: t.Optional[t.Union[str, t.List[str]]] = None,
+ ) -> DataFrame:
+ from sqlglot.dataframe.sql.functions import lit
+
+ old_values = None
+ subset = ensure_list(subset)
+ new_df = self.copy()
+ all_columns = self._get_outer_select_columns(new_df.expression)
+ all_column_mapping = {column.alias_or_name: column for column in all_columns}
+
+ columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
+ if isinstance(to_replace, dict):
+ old_values = list(to_replace)
+ new_values = list(to_replace.values())
+        elif isinstance(to_replace, list):
+ assert isinstance(value, list), "value must be a list since the replacements are a list"
+ assert len(to_replace) == len(value), "the replacements and values must be the same length"
+ old_values = to_replace
+ new_values = value
+ else:
+ old_values = [to_replace] * len(columns)
+ new_values = [value] * len(columns)
+ old_values = [lit(value) for value in old_values]
+ new_values = [lit(value) for value in new_values]
+
+ replacement_mapping = {}
+ for column in columns:
+ expression = Column(None)
+ for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
+ if i == 0:
+ expression = F.when(column == old_value, new_value)
+ else:
+ expression = expression.when(column == old_value, new_value) # type: ignore
+ replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
+ column.expression.alias_or_name
+ )
+
+ replacement_mapping = {**all_column_mapping, **replacement_mapping}
+ replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
+ new_df = new_df.select(*replacement_columns)
+ return new_df
+
+ @operation(Operation.SELECT)
+ def withColumn(self, colName: str, col: Column) -> DataFrame:
+ col = self._ensure_and_normalize_col(col)
+ existing_col_names = self.expression.named_selects
+ existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+        if existing_col_index is not None:
+ expression = self.expression.copy()
+ expression.expressions[existing_col_index] = col.expression
+ return self.copy(expression=expression)
+ return self.copy().select(col.alias(colName), append=True)
+
+ @operation(Operation.SELECT)
+ def withColumnRenamed(self, existing: str, new: str):
+ expression = self.expression.copy()
+        existing_columns = [e for e in expression.expressions if e.alias_or_name == existing]
+ if not existing_columns:
+ raise ValueError("Tried to rename a column that doesn't exist")
+ for existing_column in existing_columns:
+ if isinstance(existing_column, exp.Column):
+ existing_column.replace(exp.alias_(existing_column.copy(), new))
+ else:
+ existing_column.set("alias", exp.to_identifier(new))
+ return self.copy(expression=expression)
+
+ @operation(Operation.SELECT)
+ def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
+ all_columns = self._get_outer_select_columns(self.expression)
+ drop_cols = self._ensure_and_normalize_cols(cols)
+ new_columns = [
+ col
+ for col in all_columns
+ if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
+ ]
+ return self.copy().select(*new_columns, append=False)
+
+ @operation(Operation.LIMIT)
+ def limit(self, num: int) -> DataFrame:
+ return self.copy(expression=self.expression.limit(num))
+
+ @operation(Operation.NO_OP)
+ def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
+ parameter_list = ensure_list(parameters)
+ parameter_columns = (
+ self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+ )
+ return self._hint(name, parameter_columns)
+
+ @operation(Operation.NO_OP)
+ def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
+ num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+ columns = self._ensure_and_normalize_cols(cols)
+ args = num_partitions + columns
+ return self._hint("repartition", args)
+
+ @operation(Operation.NO_OP)
+ def coalesce(self, numPartitions: int) -> DataFrame:
+ num_partitions = Column.ensure_cols([numPartitions])
+ return self._hint("coalesce", num_partitions)
+
+ @operation(Operation.NO_OP)
+ def cache(self) -> DataFrame:
+ return self._cache(storage_level="MEMORY_AND_DISK")
+
+ @operation(Operation.NO_OP)
+ def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
+ """
+ Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
+ """
+ return self._cache(storageLevel)
+
+
+class DataFrameNaFunctions:
+ def __init__(self, df: DataFrame):
+ self.df = df
+
+ def drop(
+ self,
+ how: str = "any",
+ thresh: t.Optional[int] = None,
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+ ) -> DataFrame:
+ return self.df.dropna(how=how, thresh=thresh, subset=subset)
+
+ def fill(
+ self,
+ value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+ ) -> DataFrame:
+ return self.df.fillna(value=value, subset=subset)
+
+ def replace(
+ self,
+ to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
+ value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
+ subset: t.Optional[t.Union[str, t.List[str]]] = None,
+ ) -> DataFrame:
+ return self.df.replace(to_replace=to_replace, value=value, subset=subset)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
new file mode 100644
index 0000000..4c6de30
--- /dev/null
+++ b/sqlglot/dataframe/sql/functions.py
@@ -0,0 +1,1258 @@
+from __future__ import annotations
+
+import typing as t
+from inspect import signature
+
+from sqlglot import expressions as glotexp
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.helper import ensure_list
+from sqlglot.helper import flatten as _flatten
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+
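+# Most functions below delegate to one of two Column helpers:
+# invoke_expression_over_column, for functions with a dedicated sqlglot expression, and
+# invoke_anonymous_function, for everything else (the function name is emitted verbatim).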
+def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
+ return Column(column_name)
+
+
+def lit(value: t.Optional[t.Any] = None) -> Column:
+ if isinstance(value, str):
+ return Column(glotexp.Literal.string(str(value)))
+ return Column(value)
+
+
+def greatest(*cols: ColumnOrName) -> Column:
+ columns = [Column.ensure_col(col) for col in cols]
+ return Column.invoke_expression_over_column(
+ columns[0], glotexp.Greatest, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
+ )
+
+
+def least(*cols: ColumnOrName) -> Column:
+ columns = [Column.ensure_col(col) for col in cols]
+ return Column.invoke_expression_over_column(
+ columns[0], glotexp.Least, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
+ )
+
+
+def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
+ columns = [Column.ensure_col(x) for x in [col] + list(cols)]
+ return Column(glotexp.Count(this=glotexp.Distinct(expressions=[x.expression for x in columns])))
+
+
+def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
+ return count_distinct(col, *cols)
+
+
+def when(condition: Column, value: t.Any) -> Column:
+ true_value = value if isinstance(value, Column) else lit(value)
+ return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
+
+
+def asc(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).asc()
+
+
+def desc(col: ColumnOrName):
+ return Column.ensure_col(col).desc()
+
+
+def broadcast(df: DataFrame) -> DataFrame:
+ return df.hint("broadcast")
+
+
+def sqrt(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Sqrt)
+
+
+def abs(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Abs)
+
+
+def max(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Max)
+
+
+def min(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Min)
+
+
+def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAX_BY", ord)
+
+
+def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MIN_BY", ord)
+
+
+def count(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Count)
+
+
+def sum(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Sum)
+
+
+def avg(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Avg)
+
+
+def mean(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MEAN")
+
+
+def sumDistinct(col: ColumnOrName) -> Column:
+ return sum_distinct(col)
+
+
+def sum_distinct(col: ColumnOrName) -> Column:
+ raise NotImplementedError("Sum distinct is not currently implemented")
+
+
+def product(col: ColumnOrName) -> Column:
+ raise NotImplementedError("Product is not currently implemented")
+
+
+def acos(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ACOS")
+
+
+def acosh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ACOSH")
+
+
+def asin(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ASIN")
+
+
+def asinh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ASINH")
+
+
+def atan(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ATAN")
+
+
+def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
+ return Column.invoke_anonymous_function(col1, "ATAN2", col2)
+
+
+def atanh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ATANH")
+
+
+def cbrt(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "CBRT")
+
+
+def ceil(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Ceil)
+
+
+def cos(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "COS")
+
+
+def cosh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "COSH")
+
+
+def cot(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "COT")
+
+
+def csc(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "CSC")
+
+
+def exp(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Exp)
+
+
+def expm1(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "EXPM1")
+
+
+def floor(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Floor)
+
+
+def log10(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Log10)
+
+
+def log1p(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "LOG1P")
+
+
+def log2(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Log2)
+
+
+def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
+ if arg2 is None:
+ return Column.invoke_expression_over_column(arg1, glotexp.Ln)
+ return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=Column.ensure_col(arg2).expression)
+
+
+def rint(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "RINT")
+
+
+def sec(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SEC")
+
+
+def signum(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SIGNUM")
+
+
+def sin(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SIN")
+
+
+def sinh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SINH")
+
+
+def tan(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "TAN")
+
+
+def tanh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "TANH")
+
+
+def toDegrees(col: ColumnOrName) -> Column:
+ return degrees(col)
+
+
+def degrees(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "DEGREES")
+
+
+def toRadians(col: ColumnOrName) -> Column:
+ return radians(col)
+
+
+def radians(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "RADIANS")
+
+
+def bitwiseNOT(col: ColumnOrName) -> Column:
+ return bitwise_not(col)
+
+
+def bitwise_not(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.BitwiseNot)
+
+
+def asc_nulls_first(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).asc_nulls_first()
+
+
+def asc_nulls_last(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).asc_nulls_last()
+
+
+def desc_nulls_first(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).desc_nulls_first()
+
+
+def desc_nulls_last(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).desc_nulls_last()
+
+
+def stddev(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Stddev)
+
+
+def stddev_samp(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.StddevSamp)
+
+
+def stddev_pop(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.StddevPop)
+
+
+def variance(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Variance)
+
+
+def var_samp(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Variance)
+
+
+def var_pop(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.VariancePop)
+
+
+def skewness(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SKEWNESS")
+
+
+def kurtosis(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "KURTOSIS")
+
+
+def collect_list(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.ArrayAgg)
+
+
+def collect_set(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.SetAgg)
+
+
+def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
+ return Column.invoke_anonymous_function(col1, "HYPOT", col2)
+
+
+def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
+ return Column.invoke_anonymous_function(col1, "POW", col2)
+
+
+def row_number() -> Column:
+ return Column(glotexp.Anonymous(this="ROW_NUMBER"))
+
+
+def dense_rank() -> Column:
+ return Column(glotexp.Anonymous(this="DENSE_RANK"))
+
+
+def rank() -> Column:
+ return Column(glotexp.Anonymous(this="RANK"))
+
+
+def cume_dist() -> Column:
+ return Column(glotexp.Anonymous(this="CUME_DIST"))
+
+
+def percent_rank() -> Column:
+ return Column(glotexp.Anonymous(this="PERCENT_RANK"))
+
+
+def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
+ return approx_count_distinct(col, rsd)
+
+
+def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
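+    # rsd is the maximum allowed relative standard deviation; it maps to the
+    # ApproxDistinct accuracy argument.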
+ if rsd is None:
+ return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct)
+ return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=Column.ensure_col(rsd).expression)
+
+
+def coalesce(*cols: ColumnOrName) -> Column:
+ columns = [Column.ensure_col(col) for col in cols]
+ return Column.invoke_expression_over_column(
+ columns[0], glotexp.Coalesce, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
+ )
+
+
+def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "CORR", col2)
+
+
+def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "COVAR_POP", col2)
+
+
+def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2)
+
+
+def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
+ if ignorenulls is not None:
+ return Column.invoke_anonymous_function(col, "FIRST", ignorenulls)
+ return Column.invoke_anonymous_function(col, "FIRST")
+
+
+def grouping_id(*cols: ColumnOrName) -> Column:
+ if not cols:
+ return Column.invoke_anonymous_function(None, "GROUPING_ID")
+ if len(cols) == 1:
+ return Column.invoke_anonymous_function(cols[0], "GROUPING_ID")
+ return Column.invoke_anonymous_function(cols[0], "GROUPING_ID", *cols[1:])
+
+
+def input_file_name() -> Column:
+ return Column.invoke_anonymous_function(None, "INPUT_FILE_NAME")
+
+
+def isnan(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ISNAN")
+
+
+def isnull(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ISNULL")
+
+
+def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
+ if ignorenulls is not None:
+ return Column.invoke_anonymous_function(col, "LAST", ignorenulls)
+ return Column.invoke_anonymous_function(col, "LAST")
+
+
+def monotonically_increasing_id() -> Column:
+ return Column.invoke_anonymous_function(None, "MONOTONICALLY_INCREASING_ID")
+
+
+def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "NANVL", col2)
+
+
+def percentile_approx(
+ col: ColumnOrName,
+    percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float, ...]],
+    accuracy: t.Optional[ColumnOrLiteral] = None,
+) -> Column:
+ if accuracy:
+ return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage, accuracy)
+ return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage)
+
+
+def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
+ return Column.invoke_anonymous_function(seed, "RAND")
+
+
+def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
+ return Column.invoke_anonymous_function(seed, "RANDN")
+
+
+def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
+ if scale is not None:
+ return Column.invoke_expression_over_column(col, glotexp.Round, decimals=glotexp.convert(scale))
+ return Column.invoke_expression_over_column(col, glotexp.Round)
+
+
+def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
+ if scale is not None:
+ return Column.invoke_anonymous_function(col, "BROUND", scale)
+ return Column.invoke_anonymous_function(col, "BROUND")
+
+
+def shiftleft(col: ColumnOrName, numBits: int) -> Column:
+ return Column.invoke_expression_over_column(
+ col, glotexp.BitwiseLeftShift, expression=Column.ensure_col(numBits).expression
+ )
+
+
+def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
+ return shiftleft(col, numBits)
+
+
+def shiftright(col: ColumnOrName, numBits: int) -> Column:
+ return Column.invoke_expression_over_column(
+ col, glotexp.BitwiseRightShift, expression=Column.ensure_col(numBits).expression
+ )
+
+
+def shiftRight(col: ColumnOrName, numBits: int) -> Column:
+ return shiftright(col, numBits)
+
+
+def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column:
+ return Column.invoke_anonymous_function(col, "SHIFTRIGHTUNSIGNED", numBits)
+
+
+def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column:
+ return shiftrightunsigned(col, numBits)
+
+
+def expr(str: str) -> Column:
+ return Column(str)
+
+
+def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
+ columns = ensure_list(col) + list(cols)
+ expressions = [Column.ensure_col(column).expression for column in columns]
+ return Column(glotexp.Struct(expressions=expressions))
+
+
+def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
+ return Column.invoke_anonymous_function(col, "CONV", fromBase, toBase)
+
+
+def factorial(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "FACTORIAL")
+
+
+def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
+ if default is not None:
+ return Column.invoke_anonymous_function(col, "LAG", offset, default)
+ if offset != 1:
+ return Column.invoke_anonymous_function(col, "LAG", offset)
+ return Column.invoke_anonymous_function(col, "LAG")
+
+
+def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+ if default is not None:
+ return Column.invoke_anonymous_function(col, "LEAD", offset, default)
+ if offset != 1:
+ return Column.invoke_anonymous_function(col, "LEAD", offset)
+ return Column.invoke_anonymous_function(col, "LEAD")
+
+
+def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+ if ignoreNulls is not None:
+ raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
+ if offset != 1:
+ return Column.invoke_anonymous_function(col, "NTH_VALUE", offset)
+ return Column.invoke_anonymous_function(col, "NTH_VALUE")
+
+
+def ntile(n: int) -> Column:
+ return Column.invoke_anonymous_function(None, "NTILE", n)
+
+
+def current_date() -> Column:
+ return Column.invoke_expression_over_column(None, glotexp.CurrentDate)
+
+
+def current_timestamp() -> Column:
+ return Column.invoke_expression_over_column(None, glotexp.CurrentTimestamp)
+
+
+def date_format(col: ColumnOrName, format: str) -> Column:
+ return Column.invoke_anonymous_function(col, "DATE_FORMAT", lit(format))
+
+
+def year(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Year)
+
+
+def quarter(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "QUARTER")
+
+
+def month(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Month)
+
+
+def dayofweek(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "DAYOFWEEK")
+
+
+def dayofmonth(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "DAYOFMONTH")
+
+
+def dayofyear(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "DAYOFYEAR")
+
+
+def hour(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "HOUR")
+
+
+def minute(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MINUTE")
+
+
+def second(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SECOND")
+
+
+def weekofyear(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "WEEKOFYEAR")
+
+
+def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(year, "MAKE_DATE", month, day)
+
+
+def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=Column.ensure_col(days).expression)
+
+
+def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=Column.ensure_col(days).expression)
+
+
+def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=Column.ensure_col(start).expression)
+
+
+def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
+ return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
+
+
+def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+ if roundOff is None:
+ return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
+ return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
+
+
+def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
+ if format is not None:
+ return Column.invoke_anonymous_function(col, "TO_DATE", lit(format))
+ return Column.invoke_anonymous_function(col, "TO_DATE")
+
+
+def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
+ if format is not None:
+ return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format))
+ return Column.invoke_anonymous_function(col, "TO_TIMESTAMP")
+
+
+def trunc(col: ColumnOrName, format: str) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format).expression)
+
+
+def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format).expression)
+
+
+def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
+ return Column.invoke_anonymous_function(col, "NEXT_DAY", lit(dayOfWeek))
+
+
+def last_day(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "LAST_DAY")
+
+
+def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
+ if format is not None:
+ return Column.invoke_anonymous_function(col, "FROM_UNIXTIME", lit(format))
+ return Column.invoke_anonymous_function(col, "FROM_UNIXTIME")
+
+
+def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
+ if format is not None:
+ return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP", lit(format))
+ return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP")
+
+
+def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
+ tz_column = tz if isinstance(tz, Column) else lit(tz)
+ return Column.invoke_anonymous_function(timestamp, "FROM_UTC_TIMESTAMP", tz_column)
+
+
+def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
+ tz_column = tz if isinstance(tz, Column) else lit(tz)
+ return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column)
+
+
+def timestamp_seconds(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "TIMESTAMP_SECONDS")
+
+
+def window(
+ timeColumn: ColumnOrName,
+ windowDuration: str,
+ slideDuration: t.Optional[str] = None,
+ startTime: t.Optional[str] = None,
+) -> Column:
+ if slideDuration is not None and startTime is not None:
+ return Column.invoke_anonymous_function(
+ timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
+ )
+ if slideDuration is not None:
+ return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
+ if startTime is not None:
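+        # Only startTime was provided: the slide duration defaults to the window
+        # duration, so windowDuration is passed twice.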
+ return Column.invoke_anonymous_function(
+ timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
+ )
+ return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration))
+
+
+def session_window(timeColumn: ColumnOrName, gapDuration: ColumnOrName) -> Column:
+ gap_duration_column = gapDuration if isinstance(gapDuration, Column) else lit(gapDuration)
+ return Column.invoke_anonymous_function(timeColumn, "SESSION_WINDOW", gap_duration_column)
+
+
+def crc32(col: ColumnOrName) -> Column:
+ column = col if isinstance(col, Column) else lit(col)
+ return Column.invoke_anonymous_function(column, "CRC32")
+
+
+def md5(col: ColumnOrName) -> Column:
+ column = col if isinstance(col, Column) else lit(col)
+ return Column.invoke_anonymous_function(column, "MD5")
+
+
+def sha1(col: ColumnOrName) -> Column:
+ column = col if isinstance(col, Column) else lit(col)
+ return Column.invoke_anonymous_function(column, "SHA1")
+
+
+def sha2(col: ColumnOrName, numBits: int) -> Column:
+ column = col if isinstance(col, Column) else lit(col)
+ num_bits = lit(numBits)
+ return Column.invoke_anonymous_function(column, "SHA2", num_bits)
+
+
+def hash(*cols: ColumnOrName) -> Column:
+ args = cols[1:] if len(cols) > 1 else []
+ return Column.invoke_anonymous_function(cols[0], "HASH", *args)
+
+
+def xxhash64(*cols: ColumnOrName) -> Column:
+ args = cols[1:] if len(cols) > 1 else []
+ return Column.invoke_anonymous_function(cols[0], "XXHASH64", *args)
+
+
+def assert_true(col: ColumnOrName, errorMsg: t.Optional[ColumnOrName] = None) -> Column:
+ if errorMsg is not None:
+ error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg)
+ return Column.invoke_anonymous_function(col, "ASSERT_TRUE", error_msg_col)
+ return Column.invoke_anonymous_function(col, "ASSERT_TRUE")
+
+
+def raise_error(errorMsg: ColumnOrName) -> Column:
+ error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg)
+ return Column.invoke_anonymous_function(error_msg_col, "RAISE_ERROR")
+
+
+def upper(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Upper)
+
+
+def lower(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Lower)
+
+
+def ascii(col: ColumnOrLiteral) -> Column:
+ return Column.invoke_anonymous_function(col, "ASCII")
+
+
+def base64(col: ColumnOrLiteral) -> Column:
+ return Column.invoke_anonymous_function(col, "BASE64")
+
+
+def unbase64(col: ColumnOrLiteral) -> Column:
+ return Column.invoke_anonymous_function(col, "UNBASE64")
+
+
+def ltrim(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "LTRIM")
+
+
+def rtrim(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "RTRIM")
+
+
+def trim(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Trim)
+
+
+def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
+ columns = [Column(col) for col in cols]
+ return Column.invoke_expression_over_column(
+ None, glotexp.ConcatWs, expressions=[x.expression for x in [lit(sep)] + list(columns)]
+ )
+
+
+def decode(col: ColumnOrName, charset: str) -> Column:
+ return Column.invoke_anonymous_function(col, "DECODE", lit(charset))
+
+
+def encode(col: ColumnOrName, charset: str) -> Column:
+ return Column.invoke_anonymous_function(col, "ENCODE", lit(charset))
+
+
+def format_number(col: ColumnOrName, d: int) -> Column:
+ return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", lit(d))
+
+
+def format_string(format: str, *cols: ColumnOrName) -> Column:
+ format_col = lit(format)
+ columns = [Column.ensure_col(x) for x in cols]
+ return Column.invoke_anonymous_function(format_col, "FORMAT_STRING", *columns)
+
+
+def instr(col: ColumnOrName, substr: str) -> Column:
+ return Column.invoke_anonymous_function(col, "INSTR", lit(substr))
+
+
+def overlay(
+ src: ColumnOrName,
+ replace: ColumnOrName,
+ pos: t.Union[ColumnOrName, int],
+ len: t.Optional[t.Union[ColumnOrName, int]] = None,
+) -> Column:
+ if len is not None:
+ return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos, len)
+ return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos)
+
+
+def sentences(
+ string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
+) -> Column:
+ if language is not None and country is not None:
+ return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
+ if language is not None:
+ return Column.invoke_anonymous_function(string, "SENTENCES", language)
+ if country is not None:
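+        # Only country was provided: language defaults to "en".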
+ return Column.invoke_anonymous_function(string, "SENTENCES", lit("en"), country)
+ return Column.invoke_anonymous_function(string, "SENTENCES")
+
+
+def substring(str: ColumnOrName, pos: int, len: int) -> Column:
+ return Column.ensure_col(str).substr(pos, len)
+
+
+def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
+ return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", lit(delim), lit(count))
+
+
+def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(
+ left, glotexp.Levenshtein, expression=Column.ensure_col(right).expression
+ )
+
+
+def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
+    substr_col = lit(substr)
+    str_column = Column.ensure_col(str)
+    if pos is not None:
+        return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column, lit(pos))
+    return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column)
+
+
+def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
+ return Column.invoke_anonymous_function(col, "LPAD", lit(len), lit(pad))
+
+
+def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
+ return Column.invoke_anonymous_function(col, "RPAD", lit(len), lit(pad))
+
+
+def repeat(col: ColumnOrName, n: int) -> Column:
+ return Column.invoke_anonymous_function(col, "REPEAT", n)
+
+
+def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
+ if limit is not None:
+ return Column.invoke_expression_over_column(
+ str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=lit(limit).expression
+ )
+ return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern).expression)
+
+
+def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
+ if idx is not None:
+ return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern), idx)
+ return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern))
+
+
+def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
+ return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement))
+
+
+def initcap(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Initcap)
+
+
+def soundex(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SOUNDEX")
+
+
+def bin(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "BIN")
+
+
+def hex(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "HEX")
+
+
+def unhex(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "UNHEX")
+
+
+def length(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Length)
+
+
+def octet_length(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "OCTET_LENGTH")
+
+
+def bit_length(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "BIT_LENGTH")
+
+
+def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
+ return Column.invoke_anonymous_function(srcCol, "TRANSLATE", lit(matching), lit(replace))
+
+
+def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
+ cols = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
+ cols = [Column.ensure_col(col).expression for col in cols] # type: ignore
+ return Column.invoke_expression_over_column(None, glotexp.Array, expressions=cols)
+
+
+def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
+ cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
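+    # Arguments alternate key1, value1, key2, value2, ...: even positions form the
+    # keys array and odd positions the values array.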
+ return Column.invoke_expression_over_column(
+ None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
+ )
+
+
+def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "MAP_FROM_ARRAYS", col2)
+
+
+def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
+ value_col = value if isinstance(value, Column) else lit(value)
+ return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
+
+
+def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))
+
+
+def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
+ start_col = start if isinstance(start, Column) else lit(start)
+ length_col = length if isinstance(length, Column) else lit(length)
+ return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
+
+
+def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
+ if null_replacement is not None:
+ return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
+ return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
+
+
+def concat(*cols: ColumnOrName) -> Column:
+ if len(cols) == 1:
+ return Column.invoke_anonymous_function(cols[0], "CONCAT")
+ return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
+
+
+def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
+ value_col = value if isinstance(value, Column) else lit(value)
+ return Column.invoke_anonymous_function(col, "ARRAY_POSITION", value_col)
+
+
+def element_at(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
+ value_col = value if isinstance(value, Column) else lit(value)
+ return Column.invoke_anonymous_function(col, "ELEMENT_AT", value_col)
+
+
+def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
+ value_col = value if isinstance(value, Column) else lit(value)
+ return Column.invoke_anonymous_function(col, "ARRAY_REMOVE", value_col)
+
+
+def array_distinct(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ARRAY_DISTINCT")
+
+
+def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", Column.ensure_col(col2))
+
+
+def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "ARRAY_UNION", Column.ensure_col(col2))
+
+
+def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "ARRAY_EXCEPT", Column.ensure_col(col2))
+
+
+def explode(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Explode)
+
+
+def posexplode(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Posexplode)
+
+
+def explode_outer(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "EXPLODE_OUTER")
+
+
+def posexplode_outer(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "POSEXPLODE_OUTER")
+
+
+def get_json_object(col: ColumnOrName, path: str) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path).expression)
+
+
+def json_tuple(col: ColumnOrName, *fields: str) -> Column:
+ return Column.invoke_anonymous_function(col, "JSON_TUPLE", *[lit(field) for field in fields])
+
+
+def from_json(
+ col: ColumnOrName,
+ schema: t.Union[Column, str],
+ options: t.Optional[t.Dict[str, str]] = None,
+) -> Column:
+ schema = schema if isinstance(schema, Column) else lit(schema)
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "FROM_JSON", schema, options_col)
+ return Column.invoke_anonymous_function(col, "FROM_JSON", schema)
+
+
+def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "TO_JSON", options_col)
+ return Column.invoke_anonymous_function(col, "TO_JSON")
+
+
+def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON", options_col)
+ return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON")
+
+
+def schema_of_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV", options_col)
+ return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV")
+
+
+def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "TO_CSV", options_col)
+ return Column.invoke_anonymous_function(col, "TO_CSV")
+
+
+def size(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.ArraySize)
+
+
+def array_min(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ARRAY_MIN")
+
+
+def array_max(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ARRAY_MAX")
+
+
+def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
+ if asc is not None:
+ return Column.invoke_anonymous_function(col, "SORT_ARRAY", lit(asc))
+ return Column.invoke_anonymous_function(col, "SORT_ARRAY")
+
+
+def array_sort(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.ArraySort)
+
+
+def shuffle(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SHUFFLE")
+
+
+def reverse(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "REVERSE")
+
+
+def flatten(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "FLATTEN")
+
+
+def map_keys(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAP_KEYS")
+
+
+def map_values(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAP_VALUES")
+
+
+def map_entries(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAP_ENTRIES")
+
+
+def map_from_entries(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAP_FROM_ENTRIES")
+
+
+def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column:
+ count_col = count if isinstance(count, Column) else lit(count)
+ return Column.invoke_anonymous_function(col, "ARRAY_REPEAT", count_col)
+
+
+def array_zip(*cols: ColumnOrName) -> Column:
+ if len(cols) == 1:
+ return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP")
+ return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP", *cols[1:])
+
+
+def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
+    # Use the iterable helper _flatten, not the FLATTEN SQL function defined above.
+    columns = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols  # type: ignore
+ if len(columns) == 1:
+ return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT")
+ return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
+
+
+def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
+ if step is not None:
+ return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
+ return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
+
+
+def from_csv(
+ col: ColumnOrName,
+ schema: t.Union[Column, str],
+ options: t.Optional[t.Dict[str, str]] = None,
+) -> Column:
+ schema = schema if isinstance(schema, Column) else lit(schema)
+ if options is not None:
+ option_cols = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "FROM_CSV", schema, option_cols)
+ return Column.invoke_anonymous_function(col, "FROM_CSV", schema)
+
+
+def aggregate(
+ col: ColumnOrName,
+ initialValue: ColumnOrName,
+ merge: t.Callable[[Column, Column], Column],
+ finish: t.Optional[t.Callable[[Column], Column]] = None,
+ accumulator_name: str = "acc",
+ target_row_name: str = "x",
+) -> Column:
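+    # Builds SQL lambda expressions from the Python callables; with the default names
+    # the generated call looks roughly like AGGREGATE(col, init, (acc, x) -> ..., acc -> ...).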
+ merge_exp = glotexp.Lambda(
+ this=merge(Column(accumulator_name), Column(target_row_name)).expression,
+ expressions=[
+ glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name)),
+ glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name)),
+ ],
+ )
+ if finish is not None:
+ finish_exp = glotexp.Lambda(
+ this=finish(Column(accumulator_name)).expression,
+ expressions=[glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name))],
+ )
+ return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
+ return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
+
+
+def transform(
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+ target_row_name: str = "x",
+ row_count_name: str = "i",
+) -> Column:
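+    # A one-argument callable receives each element; a two-argument callable also
+    # receives the element's index, mirroring PySpark's transform.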
+ num_arguments = len(signature(f).parameters)
+ expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
+ columns = [Column(target_row_name)]
+ if num_arguments > 1:
+ columns.append(Column(row_count_name))
+ expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
+
+ f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
+ return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
+
+
+def exists(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(target_row_name)).expression,
+ expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
+ )
+ return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression))
+
+
+def forall(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(target_row_name)).expression,
+ expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
+ )
+
+ return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
+
+
+def filter(
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+ target_row_name: str = "x",
+ row_count_name: str = "i",
+) -> Column:
+ num_arguments = len(signature(f).parameters)
+ expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
+ columns = [Column(target_row_name)]
+ if num_arguments > 1:
+ columns.append(Column(row_count_name))
+ expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
+
+ f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
+ return Column.invoke_anonymous_function(col, "FILTER", Column(f_expression))
+
+
+def zip_with(
+ left: ColumnOrName,
+ right: ColumnOrName,
+ f: t.Callable[[Column, Column], Column],
+ left_name: str = "x",
+ right_name: str = "y",
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(left_name), Column(right_name)).expression,
+ expressions=[
+ glotexp.to_identifier(left_name, quoted=_lambda_quoted(left_name)),
+ glotexp.to_identifier(right_name, quoted=_lambda_quoted(right_name)),
+ ],
+ )
+
+ return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
+
+
+def transform_keys(
+    col: ColumnOrName, f: t.Callable[[Column, Column], Column], key_name: str = "k", value_name: str = "v"
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(key_name), Column(value_name)).expression,
+ expressions=[
+ glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
+ glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
+ ],
+ )
+ return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression))
+
+
+def transform_values(
+    col: ColumnOrName, f: t.Callable[[Column, Column], Column], key_name: str = "k", value_name: str = "v"
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(key_name), Column(value_name)).expression,
+ expressions=[
+ glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
+ glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
+ ],
+ )
+ return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression))
+
+
+def map_filter(
+    col: ColumnOrName, f: t.Callable[[Column, Column], Column], key_name: str = "k", value_name: str = "v"
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(key_name), Column(value_name)).expression,
+ expressions=[
+ glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
+ glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
+ ],
+ )
+ return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression))
+
+
+def map_zip_with(
+ col1: ColumnOrName,
+ col2: ColumnOrName,
+    f: t.Callable[[Column, Column, Column], Column],
+ key_name: str = "k",
+ value1: str = "v1",
+ value2: str = "v2",
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(key_name), Column(value1), Column(value2)).expression,
+ expressions=[
+ glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
+ glotexp.to_identifier(value1, quoted=_lambda_quoted(value1)),
+ glotexp.to_identifier(value2, quoted=_lambda_quoted(value2)),
+ ],
+ )
+ return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression))
+
+
+def _lambda_quoted(value: str) -> t.Optional[bool]:
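+    # Never quote a bare underscore; None defers to the dialect's default quoting behavior.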
+ return False if value == "_" else None
diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py
new file mode 100644
index 0000000..947aace
--- /dev/null
+++ b/sqlglot/dataframe/sql/group.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.operations import Operation, operation
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+
+class GroupedData:
+ def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
+ self._df = df.copy()
+ self.spark = df.spark
+ self.last_op = last_op
+ self.group_by_cols = group_by_cols
+
+ def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
+ func_name = func_name.lower()
+ return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
+
+ @operation(Operation.SELECT)
+ def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
+ columns = (
+ [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
+ if isinstance(exprs[0], dict)
+ else exprs
+ )
+ cols = self._df._ensure_and_normalize_cols(columns)
+
+ expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
+ *[x.expression for x in self.group_by_cols + cols], append=False
+ )
+ return self._df.copy(expression=expression)
+
+ def count(self) -> DataFrame:
+ return self.agg(F.count("*").alias("count"))
+
+ def mean(self, *cols: str) -> DataFrame:
+ return self.avg(*cols)
+
+ def avg(self, *cols: str) -> DataFrame:
+ return self.agg(*self._get_function_applied_columns("avg", cols))
+
+ def max(self, *cols: str) -> DataFrame:
+ return self.agg(*self._get_function_applied_columns("max", cols))
+
+ def min(self, *cols: str) -> DataFrame:
+ return self.agg(*self._get_function_applied_columns("min", cols))
+
+ def sum(self, *cols: str) -> DataFrame:
+ return self.agg(*self._get_function_applied_columns("sum", cols))
+
+ def pivot(self, *cols: str) -> DataFrame:
+ raise NotImplementedError("Sum distinct is not currently implemented")
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
new file mode 100644
index 0000000..1513946
--- /dev/null
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.helper import ensure_list
+
+NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
+def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
+ expr = ensure_list(expr)
+ expressions = _ensure_expressions(expr)
+ for expression in expressions:
+ identifiers = expression.find_all(exp.Identifier)
+ for identifier in identifiers:
+ replace_alias_name_with_cte_name(spark, expression_context, identifier)
+ replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
+
+
+def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+ if id.alias_or_name in spark.name_to_sequence_id_mapping:
+ for cte in reversed(expression_context.ctes):
+ if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
+ _set_alias_name(id, cte.alias_or_name)
+ break
+
+
+def replace_branch_and_sequence_ids_with_cte_name(
+ spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
+ if id.alias_or_name in spark.known_ids:
+        # Check if we have a join and if both tables in that join share a common branch id.
+        # If so, this reference should default to the left table, unless the id is a
+        # sequence id, in which case it keeps its reference. This handles a weird edge
+        # case in Spark that shouldn't be common in practice.
+ if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
+ join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
+ ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+            # Both join sides must resolve to CTEs before branch ids can be compared.
+            assert len(ctes_in_join) == 2
+            if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
+                _set_alias_name(id, ctes_in_join[0].alias_or_name)
+ return
+
+ for cte in reversed(expression_context.ctes):
+ if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]):
+ _set_alias_name(id, cte.alias_or_name)
+ return
+
+
+def _set_alias_name(id: exp.Identifier, name: str):
+ id.set("this", name)
+
+
+def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
+ values = ensure_list(values)
+ results = []
+ for value in values:
+ if isinstance(value, str):
+ results.append(Column.ensure_col(value).expression)
+ elif isinstance(value, Column):
+ results.append(value.expression)
+ elif isinstance(value, exp.Expression):
+ results.append(value)
+ else:
+ raise ValueError(f"Got an invalid type to normalize: {type(value)}")
+ return results
diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py
new file mode 100644
index 0000000..d51335c
--- /dev/null
+++ b/sqlglot/dataframe/sql/operations.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import functools
+import typing as t
+from enum import IntEnum
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+ from sqlglot.dataframe.sql.group import GroupedData
+
+
+class Operation(IntEnum):
+ INIT = -1
+ NO_OP = 0
+ FROM = 1
+ WHERE = 2
+ GROUP_BY = 3
+ HAVING = 4
+ SELECT = 5
+ ORDER_BY = 6
+ LIMIT = 7
+
+
+def operation(op: Operation):
+ """
+ Decorator used around DataFrame methods to indicate what type of operation is being performed from the
+ ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
+ included with the previous operation.
+
+ Ex: After a user does a join we want to allow them to select which columns for the different
+ tables that they want to carry through to the following operation. If we put that join in
+ a CTE preemptively then the user would not have a chance to select which column they want
+ in cases where there is overlap in names.
+ """
+
+ def decorator(func: t.Callable):
+ @functools.wraps(func)
+ def wrapper(self: DataFrame, *args, **kwargs):
+ if self.last_op == Operation.INIT:
+ self = self._convert_leaf_to_cte()
+ self.last_op = Operation.NO_OP
+ last_op = self.last_op
+ new_op = op if op != Operation.NO_OP else last_op
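+            # An operation that sorts before the last one (e.g. WHERE after SELECT), or a
+            # SELECT stacked directly on a SELECT, must start a new scope, so the current
+            # expression is converted into a CTE first.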
+ if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT):
+ self = self._convert_leaf_to_cte()
+ df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
+ df.last_op = new_op # type: ignore
+ return df
+
+ wrapper.__wrapped__ = func # type: ignore
+ return wrapper
+
+ return decorator
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
new file mode 100644
index 0000000..4830035
--- /dev/null
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.helper import object_to_dict
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
+class DataFrameReader:
+ def __init__(self, spark: SparkSession):
+ self.spark = spark
+
+ def table(self, tableName: str) -> DataFrame:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+ sqlglot.schema.add_table(tableName)
+ return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+
+
+class DataFrameWriter:
+ def __init__(
+ self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+ ):
+ self._df = df
+ self._spark = spark or df.spark
+ self._mode = mode
+ self._by_name = by_name
+
+ def copy(self, **kwargs) -> DataFrameWriter:
+ return DataFrameWriter(
+ **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+ )
+
+ def sql(self, **kwargs) -> t.List[str]:
+ return self._df.sql(**kwargs)
+
+ def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
+ return self.copy(_mode=saveMode)
+
+ @property
+ def byName(self):
+ return self.copy(by_name=True)
+
+ def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
+ output_expression_container = exp.Insert(
+ **{
+ "this": exp.to_table(tableName),
+ "overwrite": overwrite,
+ }
+ )
+ df = self._df.copy(output_expression_container=output_expression_container)
+ if self._by_name:
+ columns = sqlglot.schema.column_names(tableName, only_visible=True)
+ df = df._convert_leaf_to_cte().select(*columns)
+
+ return self.copy(_df=df)
+
+ def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
+ if format is not None:
+ raise NotImplementedError("Providing Format in the save as table is not supported")
+ exists, replace, mode = None, None, mode or str(self._mode)
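+        # Mode mapping: "append" becomes an INSERT INTO, "ignore" a CREATE TABLE IF NOT
+        # EXISTS, and "overwrite" a CREATE OR REPLACE TABLE.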
+ if mode == "append":
+ return self.insertInto(name)
+ if mode == "ignore":
+ exists = True
+ if mode == "overwrite":
+ replace = True
+ output_expression_container = exp.Create(
+ this=exp.to_table(name),
+ kind="TABLE",
+ exists=exists,
+ replace=replace,
+ )
+ return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
new file mode 100644
index 0000000..1ea86d1
--- /dev/null
+++ b/sqlglot/dataframe/sql/session.py
@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import typing as t
+import uuid
+from collections import defaultdict
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.dataframe import DataFrame
+from sqlglot.dataframe.sql.readwriter import DataFrameReader
+from sqlglot.dataframe.sql.types import StructType
+from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
+
+
+class SparkSession:
+ known_ids: t.ClassVar[t.Set[str]] = set()
+ known_branch_ids: t.ClassVar[t.Set[str]] = set()
+ known_sequence_ids: t.ClassVar[t.Set[str]] = set()
+ name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
+
+ def __init__(self):
+ self.incrementing_id = 1
+
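+    # Attribute access and calls fall through to the session itself, so builder-style
+    # chains like SparkSession.builder.appName("x").getOrCreate() are accepted as no-ops.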
+ def __getattr__(self, name: str) -> SparkSession:
+ return self
+
+ def __call__(self, *args, **kwargs) -> SparkSession:
+ return self
+
+ @property
+ def read(self) -> DataFrameReader:
+ return DataFrameReader(self)
+
+ def table(self, tableName: str) -> DataFrame:
+ return self.read.table(tableName)
+
+ def createDataFrame(
+ self,
+ data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
+ schema: t.Optional[SchemaInput] = None,
+ samplingRatio: t.Optional[float] = None,
+ verifySchema: bool = False,
+ ) -> DataFrame:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+ if samplingRatio is not None or verifySchema:
+ raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
+ if schema is not None and (
+ not isinstance(schema, (StructType, str, list))
+ or (isinstance(schema, list) and not isinstance(schema[0], str))
+ ):
+ raise NotImplementedError("Only schema of either list or string of list supported")
+ if not data:
+ raise ValueError("Must provide data to create into a DataFrame")
+
+ column_mapping: t.Dict[str, t.Optional[str]]
+ if schema is not None:
+ column_mapping = get_column_mapping_from_schema_input(schema)
+ elif isinstance(data[0], dict):
+ column_mapping = {col_name.strip(): None for col_name in data[0]}
+ else:
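+            # No schema and non-dict rows: name columns positionally (_1, _2, ...),
+            # matching PySpark's default for untyped rows.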
+ column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
+
+ data_expressions = [
+ exp.Tuple(
+ expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
+ )
+ for row in data
+ ]
+
+ sel_columns = [
+ F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
+ for name, data_type in column_mapping.items()
+ ]
+
+ select_kwargs = {
+ "expressions": sel_columns,
+ "from": exp.From(
+ expressions=[
+ exp.Subquery(
+ this=exp.Values(expressions=data_expressions),
+ alias=exp.TableAlias(
+ this=exp.to_identifier(self._auto_incrementing_name),
+ columns=[exp.to_identifier(col_name) for col_name in column_mapping],
+ ),
+ )
+ ]
+ ),
+ }
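+        # The generated query is shaped like:
+        #   SELECT <casted columns> FROM (VALUES (...), (...)) AS a1(col1, col2, ...)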
+
+ sel_expression = exp.Select(**select_kwargs)
+ return DataFrame(self, sel_expression)
+
+ def sql(self, sqlQuery: str) -> DataFrame:
+ expression = sqlglot.parse_one(sqlQuery, read="spark")
+ if isinstance(expression, exp.Select):
+ df = DataFrame(self, expression)
+ df = df._convert_leaf_to_cte()
+ elif isinstance(expression, (exp.Create, exp.Insert)):
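+            # For CREATE/INSERT, hoist the inner SELECT (carrying over any WITH clause on
+            # an INSERT) so the DataFrame tracks the query and the DDL/DML becomes the
+            # output expression container.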
+ select_expression = expression.expression.copy()
+ if isinstance(expression, exp.Insert):
+ select_expression.set("with", expression.args.get("with"))
+ expression.set("with", None)
+ del expression.args["expression"]
+ df = DataFrame(self, select_expression, output_expression_container=expression)
+ df = df._convert_leaf_to_cte()
+ else:
+ raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
+ return df
+
+ @property
+ def _auto_incrementing_name(self) -> str:
+ name = f"a{self.incrementing_id}"
+ self.incrementing_id += 1
+ return name
+
+ @property
+ def _random_name(self) -> str:
+ return f"a{str(uuid.uuid4())[:8]}"
+
+ @property
+ def _random_branch_id(self) -> str:
+ id = self._random_id
+ self.known_branch_ids.add(id)
+ return id
+
+ @property
+ def _random_sequence_id(self):
+ id = self._random_id
+ self.known_sequence_ids.add(id)
+ return id
+
+ @property
+ def _random_id(self) -> str:
+ id = f"a{str(uuid.uuid4())[:8]}"
+ self.known_ids.add(id)
+ return id
+
+ @property
+ def _join_hint_names(self) -> t.Set[str]:
+ return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}
+
+ def _add_alias_to_mapping(self, name: str, sequence_id: str):
+ self.name_to_sequence_id_mapping[name].append(sequence_id)
diff --git a/sqlglot/dataframe/sql/transforms.py b/sqlglot/dataframe/sql/transforms.py
new file mode 100644
index 0000000..b3dcc12
--- /dev/null
+++ b/sqlglot/dataframe/sql/transforms.py
@@ -0,0 +1,9 @@
+import typing as t
+
+from sqlglot import expressions as exp
+
+
+def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
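+    # Replace any identifier found in the mapping with a copy of its mapped replacement.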
+ if isinstance(node, exp.Identifier) and node in replacement_mapping:
+ node = node.replace(replacement_mapping[node].copy())
+ return node
diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py
new file mode 100644
index 0000000..dc5c05a
--- /dev/null
+++ b/sqlglot/dataframe/sql/types.py
@@ -0,0 +1,208 @@
+import typing as t
+
+
+class DataType:
+ def __repr__(self) -> str:
+ return self.__class__.__name__ + "()"
+
+ def __hash__(self) -> int:
+ return hash(str(self))
+
+ def __eq__(self, other: t.Any) -> bool:
+ return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+
+ def __ne__(self, other: t.Any) -> bool:
+ return not self.__eq__(other)
+
+ def __str__(self) -> str:
+ return self.typeName()
+
+ @classmethod
+ def typeName(cls) -> str:
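+        # Strip the trailing "Type" from the class name, e.g. StringType -> "string".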
+ return cls.__name__[:-4].lower()
+
+ def simpleString(self) -> str:
+ return str(self)
+
+ def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]:
+ return str(self)
+
+
+class DataTypeWithLength(DataType):
+ def __init__(self, length: int):
+ self.length = length
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self.length})"
+
+ def __str__(self) -> str:
+ return f"{self.typeName()}({self.length})"
+
+
+class StringType(DataType):
+ pass
+
+
+class CharType(DataTypeWithLength):
+ pass
+
+
+class VarcharType(DataTypeWithLength):
+ pass
+
+
+class BinaryType(DataType):
+ pass
+
+
+class BooleanType(DataType):
+ pass
+
+
+class DateType(DataType):
+ pass
+
+
+class TimestampType(DataType):
+ pass
+
+
+class TimestampNTZType(DataType):
+ @classmethod
+ def typeName(cls) -> str:
+ return "timestamp_ntz"
+
+
+class DecimalType(DataType):
+ def __init__(self, precision: int = 10, scale: int = 0):
+ self.precision = precision
+ self.scale = scale
+
+ def simpleString(self) -> str:
+ return f"decimal({self.precision}, {self.scale})"
+
+ def jsonValue(self) -> str:
+ return f"decimal({self.precision}, {self.scale})"
+
+ def __repr__(self) -> str:
+ return f"DecimalType({self.precision}, {self.scale})"
+
+
+class DoubleType(DataType):
+ pass
+
+
+class FloatType(DataType):
+ pass
+
+
+class ByteType(DataType):
+ def __str__(self) -> str:
+ return "tinyint"
+
+
+class IntegerType(DataType):
+ def __str__(self) -> str:
+ return "int"
+
+
+class LongType(DataType):
+ def __str__(self) -> str:
+ return "bigint"
+
+
+class ShortType(DataType):
+ def __str__(self) -> str:
+ return "smallint"
+
+
+class ArrayType(DataType):
+ def __init__(self, elementType: DataType, containsNull: bool = True):
+ self.elementType = elementType
+ self.containsNull = containsNull
+
+ def __repr__(self) -> str:
+ return f"ArrayType({self.elementType, str(self.containsNull)}"
+
+ def simpleString(self) -> str:
+ return f"array<{self.elementType.simpleString()}>"
+
+ def jsonValue(self) -> t.Dict[str, t.Any]:
+ return {
+ "type": self.typeName(),
+ "elementType": self.elementType.jsonValue(),
+ "containsNull": self.containsNull,
+ }
+
+
+class MapType(DataType):
+ def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
+ self.keyType = keyType
+ self.valueType = valueType
+ self.valueContainsNull = valueContainsNull
+
+ def __repr__(self) -> str:
+ return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})"
+
+ def simpleString(self) -> str:
+ return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>"
+
+ def jsonValue(self) -> t.Dict[str, t.Any]:
+ return {
+ "type": self.typeName(),
+ "keyType": self.keyType.jsonValue(),
+ "valueType": self.valueType.jsonValue(),
+ "valueContainsNull": self.valueContainsNull,
+ }
+
+
+class StructField(DataType):
+ def __init__(
+ self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
+ ):
+ self.name = name
+ self.dataType = dataType
+ self.nullable = nullable
+ self.metadata = metadata or {}
+
+ def __repr__(self) -> str:
+ return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})"
+
+ def simpleString(self) -> str:
+ return f"{self.name}:{self.dataType.simpleString()}"
+
+ def jsonValue(self) -> t.Dict[str, t.Any]:
+ return {
+ "name": self.name,
+ "type": self.dataType.jsonValue(),
+ "nullable": self.nullable,
+ "metadata": self.metadata,
+ }
+
+
+class StructType(DataType):
+ def __init__(self, fields: t.Optional[t.List[StructField]] = None):
+ if not fields:
+ self.fields = []
+ self.names = []
+ else:
+ self.fields = fields
+ self.names = [f.name for f in fields]
+
+ def __iter__(self) -> t.Iterator[StructField]:
+ return iter(self.fields)
+
+ def __len__(self) -> int:
+ return len(self.fields)
+
+ def __repr__(self) -> str:
+ return f"StructType({', '.join(str(field) for field in self)})"
+
+ def simpleString(self) -> str:
+ return f"struct<{', '.join(x.simpleString() for x in self)}>"
+
+ def jsonValue(self) -> t.Dict[str, t.Any]:
+ return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]}
+
+ def fieldNames(self) -> t.List[str]:
+ return list(self.names)
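+
+# Illustrative usage sketch: building a schema like the ones passed to
+# sqlglot.schema.add_table or createDataFrame (assumes StringType's
+# simpleString() is 'string', per this module's conventions).
+#   schema = StructType([
+#       StructField("cola", StringType()),
+#       StructField("colb", IntegerType()),
+#   ])
+#   schema.simpleString()  # -> 'struct<cola:string, colb:int>'
+#   schema.fieldNames()    # -> ['cola', 'colb']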
diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py
new file mode 100644
index 0000000..575d18a
--- /dev/null
+++ b/sqlglot/dataframe/sql/util.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import types
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import SchemaInput
+
+
+def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]:
+    if isinstance(schema, dict):
+        # Already a mapping of column name -> type string
+        return schema
+    elif isinstance(schema, str):
+        # A "cola: STRING, colb: INT" style string
+        col_name_type_strs = [x.strip() for x in schema.split(",")]
+        return {
+            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
+            for name_type_str in col_name_type_strs
+        }
+    elif isinstance(schema, types.StructType):
+        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
+    # Fall back to a list of column names with unknown types
+    return {x.strip(): None for x in schema}  # type: ignore
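+
+# Illustrative behavior sketch: every supported SchemaInput normalizes to the
+# same shape of mapping.
+#   get_column_mapping_from_schema_input("cola: STRING, colb: INT")
+#   # -> {'cola': 'STRING', 'colb': 'INT'}
+#   get_column_mapping_from_schema_input(["cola", "colb"])
+#   # -> {'cola': None, 'colb': None}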
+
+
+def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
+ if not expression.args.get("joins"):
+ return []
+
+ left_table = expression.args["from"].args["expressions"][0]
+ other_tables = [join.this for join in expression.args["joins"]]
+ return [left_table] + other_tables
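+
+# Illustrative behavior sketch (assumes a parsed SELECT containing a join):
+#   import sqlglot
+#   select_expr = sqlglot.parse_one("SELECT * FROM a JOIN b ON a.id = b.id")
+#   get_tables_from_expression_with_join(select_expr)
+#   # -> the FROM table followed by each joined table: [a, b]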
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
new file mode 100644
index 0000000..842f366
--- /dev/null
+++ b/sqlglot/dataframe/sql/window.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+import sys
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.helper import flatten
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnOrName
+
+
+class Window:
+ _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808
+ _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807
+ _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
+ _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
+
+ unboundedPreceding: int = _JAVA_MIN_LONG
+
+ unboundedFollowing: int = _JAVA_MAX_LONG
+
+ currentRow: int = 0
+
+ @classmethod
+ def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+ return WindowSpec().partitionBy(*cols)
+
+ @classmethod
+ def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+ return WindowSpec().orderBy(*cols)
+
+ @classmethod
+ def rowsBetween(cls, start: int, end: int) -> WindowSpec:
+ return WindowSpec().rowsBetween(start, end)
+
+ @classmethod
+ def rangeBetween(cls, start: int, end: int) -> WindowSpec:
+ return WindowSpec().rangeBetween(start, end)
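+
+# Illustrative usage sketch, mirroring the PySpark API (assumes a DataFrame
+# `df` and the functions module imported as F):
+#   window = Window.partitionBy("cola").orderBy("colb")
+#   df.withColumn("row_num", F.row_number().over(window))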
+
+
+class WindowSpec:
+    def __init__(self, expression: t.Optional[exp.Expression] = None):
+        # Default to None to avoid sharing one mutable exp.Window across instances
+        self.expression = expression if expression is not None else exp.Window()
+
+ def copy(self):
+ return WindowSpec(self.expression.copy())
+
+ def sql(self, **kwargs) -> str:
+ return self.expression.sql(dialect="spark", **kwargs)
+
+ def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+ from sqlglot.dataframe.sql.column import Column
+
+ cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
+ expressions = [Column.ensure_col(x).expression for x in cols]
+ window_spec = self.copy()
+ partition_by_expressions = window_spec.expression.args.get("partition_by", [])
+ partition_by_expressions.extend(expressions)
+ window_spec.expression.set("partition_by", partition_by_expressions)
+ return window_spec
+
+ def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+ from sqlglot.dataframe.sql.column import Column
+
+ cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
+ expressions = [Column.ensure_col(x).expression for x in cols]
+ window_spec = self.copy()
+ if window_spec.expression.args.get("order") is None:
+ window_spec.expression.set("order", exp.Order(expressions=[]))
+ order_by = window_spec.expression.args["order"].expressions
+ order_by.extend(expressions)
+ window_spec.expression.args["order"].set("expressions", order_by)
+ return window_spec
+
+    def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
+        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
+        if start == Window.currentRow:
+            kwargs["start"] = "CURRENT ROW"
+        else:
+            kwargs["start_side"] = "PRECEDING"
+            kwargs["start"] = "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression
+        if end == Window.currentRow:
+            kwargs["end"] = "CURRENT ROW"
+        else:
+            kwargs["end_side"] = "FOLLOWING"
+            kwargs["end"] = "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
+        return kwargs
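+
+# Illustrative behavior sketch:
+#   WindowSpec()._calc_start_end(Window.unboundedPreceding, Window.currentRow)
+#   # -> {'start_side': 'PRECEDING', 'end_side': None,
+#   #     'start': 'UNBOUNDED', 'end': 'CURRENT ROW'}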
+
+ def rowsBetween(self, start: int, end: int) -> WindowSpec:
+ window_spec = self.copy()
+ spec = self._calc_start_end(start, end)
+ spec["kind"] = "ROWS"
+ window_spec.expression.set(
+ "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+ )
+ return window_spec
+
+ def rangeBetween(self, start: int, end: int) -> WindowSpec:
+ window_spec = self.copy()
+ spec = self._calc_start_end(start, end)
+ spec["kind"] = "RANGE"
+ window_spec.expression.set(
+ "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+ )
+ return window_spec
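+
+# Illustrative usage sketch: frame clauses attach to the spec and are rendered
+# by .sql() in the Spark dialect, producing roughly
+#   spec = Window.partitionBy("cola").orderBy("colb").rowsBetween(
+#       Window.unboundedPreceding, Window.currentRow
+#   )
+#   spec.sql()
+#   # -> '... PARTITION BY cola ORDER BY colb
+#   #      ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW'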