diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-21 09:29:23 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-21 09:29:23 +0000 |
commit | dab6ba29e8eb9a5c2890ac3be8eab6e994aeb10e (patch) | |
tree | 0d209cfc6f7b9c794c254601c29aa5d8b9414876 | |
parent | Adding upstream version 7.1.3. (diff) | |
download | sqlglot-dab6ba29e8eb9a5c2890ac3be8eab6e994aeb10e.tar.xz sqlglot-dab6ba29e8eb9a5c2890ac3be8eab6e994aeb10e.zip |
Adding upstream version 9.0.1.upstream/9.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
87 files changed, 7995 insertions, 422 deletions
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 18ce8b6..a3f151b 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -20,7 +20,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -r requirements.txt + python -m pip install -r dev-requirements.txt - name: Run checks (linter, code style, tests) run: | ./run_checks.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index cb76969..4fc508f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,21 @@ Changelog ========= +v9.0.0 +------ + +Changes: + +- Breaking : Changed AST hierarchy of exp.Table with exp.Alias. Before Tables were children's of their aliases, but in order to simplify the AST and fix some issues, Tables now have an alias property. + +v8.0.0 +------ + +Changes: + +- Breaking : New add\_table method in Schema ABC. +- New: SQLGlot now supports the [PySpark](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dataframe) dataframe API. This is still relatively experimental. + v7.1.0 ------ @@ -317,6 +317,7 @@ Dialect["custom"] ## Run Tests and Lint ``` pip install -r requirements.txt +# set `SKIP_INTEGRATION=1` to skip integration tests ./run_checks.sh ``` diff --git a/requirements.txt b/dev-requirements.txt index b2308e5..336ecf4 100644 --- a/requirements.txt +++ b/dev-requirements.txt @@ -2,5 +2,7 @@ autoflake black duckdb isort +mypy pandas +pyspark python-dateutil diff --git a/run_checks.sh b/run_checks.sh index b6e559d..b13a61c 100755 --- a/run_checks.sh +++ b/run_checks.sh @@ -11,4 +11,5 @@ python -m autoflake -i -r ${RETURN_ERROR_CODE} \ sqlglot/ tests/ python -m isort --profile black sqlglot/ tests/ python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/ +python -m mypy sqlglot tests python -m unittest diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..bbd20d0 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,15 @@ +[mypy] +disallow_untyped_calls = False +no_implicit_optional = True + +[mypy-sqlglot.*] +ignore_errors = True + +[mypy-sqlglot.dataframe.*] +ignore_errors = False + +[mypy-tests.*] +ignore_errors = True + +[mypy-tests.dataframe.*] +ignore_errors = False diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 247085b..7841c11 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -21,12 +21,15 @@ from sqlglot.expressions import table_ as table from sqlglot.expressions import union from sqlglot.generator import Generator from sqlglot.parser import Parser +from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "7.1.3" +__version__ = "9.0.1" pretty = False +schema = MappingSchema() + def parse(sql, read=None, **opts): """ diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py index 4161259..c0fa380 100644 --- a/sqlglot/__main__.py +++ b/sqlglot/__main__.py @@ -40,8 +40,8 @@ parser.add_argument( "--error-level", dest="error_level", type=str, - default="RAISE", - help="IGNORE, WARN, RAISE (default)", + default="IMMEDIATE", + help="IGNORE, WARN, RAISE, IMMEDIATE (default)", ) diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md new file mode 100644 index 0000000..54d3856 --- /dev/null +++ b/sqlglot/dataframe/README.md @@ -0,0 +1,224 @@ +# PySpark DataFrame SQL Generator + +This is a drop-in replacement for the PysPark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). + +Currently many of the common operations are covered and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what should be prioritized next and make sure your use case is properly supported. + +# How to use + +## Instructions +* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library +* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe` +* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)` + * The column structure can be defined the following ways: + * Dictionary where the keys are column names and values are string of the Spark SQL type name + * Ex: {'cola': 'string', 'colb': 'int'} + * PySpark DataFrame `StructType` similar to when using `createDataFrame` + * Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])` + * A string of names and types similar to what is supported in `createDataFrame` + * Ex: `cola: STRING, colb: INT` + * [Not Recommended] A list of string column names without type + * Ex: ['cola', 'colb'] + * The lack of types may limit functionality in future releases + * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally +* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command + * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects. + * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects + * Ex: `.sql(pretty=True, dialect='bigquery')` + +## Examples + +```python +import sqlglot +from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.dataframe.sql import functions as F + +sqlglot.schema.add_table('employee', { + 'employee_id': 'INT', + 'fname': 'STRING', + 'lname': 'STRING', + 'age': 'INT', +}) # Register the table structure prior to reading from the table + +spark = SparkSession() + +df = ( + spark + .table('employee') + .groupBy(F.col("age")) + .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) +) + +print(df.sql(pretty=True)) # Spark will be the dialect used by default +``` +Output: +```sparksql +SELECT + `employee`.`age` AS `age`, + COUNT(DISTINCT `employee`.`employee_id`) AS `num_employees` +FROM `employee` AS `employee` +GROUP BY + `employee`.`age` +``` + +## Registering Custom Schema Class + +The step of adding `sqlglot.schema.add_table` can be skipped if you have the column structure stored externally like in a file or from an external metadata table. This can be done by writing a class that implements the `sqlglot.schema.Schema` abstract class and then assigning that class to `sqlglot.schema`. + +```python +import sqlglot +from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.dataframe.sql import functions as F +from sqlglot.schema import Schema + + +class ExternalSchema(Schema): + ... + +sqlglot.schema = ExternalSchema() + +spark = SparkSession() + +df = ( + spark + .table('employee') + .groupBy(F.col("age")) + .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) +) + +print(df.sql(pretty=True)) +``` + +## Example Implementations + +### Bigquery +```python +from google.cloud import bigquery +from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.dataframe.sql import types +from sqlglot.dataframe.sql import functions as F + +client = bigquery.Client() + +data = [ + (1, "Jack", "Shephard", 34), + (2, "John", "Locke", 48), + (3, "Kate", "Austen", 34), + (4, "Claire", "Littleton", 22), + (5, "Hugo", "Reyes", 26), +] +schema = types.StructType([ + types.StructField('employee_id', types.IntegerType(), False), + types.StructField('fname', types.StringType(), False), + types.StructField('lname', types.StringType(), False), + types.StructField('age', types.IntegerType(), False), +]) + +sql_statements = ( + SparkSession() + .createDataFrame(data, schema) + .groupBy(F.col("age")) + .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) + .sql(dialect="bigquery") +) + +result = None +for sql in sql_statements: + result = client.query(sql) + +assert result is not None +for row in client.query(result): + print(f"Age: {row['age']}, Num Employees: {row['num_employees']}") +``` + +### Snowflake +```python +import os + +import snowflake.connector +from sqlglot.dataframe.session import SparkSession +from sqlglot.dataframe import types +from sqlglot.dataframe import functions as F + +ctx = snowflake.connector.connect( + user=os.environ["SNOWFLAKE_USER"], + password=os.environ["SNOWFLAKE_PASS"], + account=os.environ["SNOWFLAKE_ACCOUNT"] +) +cs = ctx.cursor() + +data = [ + (1, "Jack", "Shephard", 34), + (2, "John", "Locke", 48), + (3, "Kate", "Austen", 34), + (4, "Claire", "Littleton", 22), + (5, "Hugo", "Reyes", 26), +] +schema = types.StructType([ + types.StructField('employee_id', types.IntegerType(), False), + types.StructField('fname', types.StringType(), False), + types.StructField('lname', types.StringType(), False), + types.StructField('age', types.IntegerType(), False), +]) + +sql_statements = ( + SparkSession() + .createDataFrame(data, schema) + .groupBy(F.col("age")) + .agg(F.countDistinct(F.col("lname")).alias("num_employees")) + .sql(dialect="snowflake") +) + +try: + for sql in sql_statements: + cs.execute(sql) + results = cs.fetchall() + for row in results: + print(f"Age: {row[0]}, Num Employees: {row[1]}") +finally: + cs.close() +ctx.close() +``` + +### Spark +```python +from pyspark.sql.session import SparkSession as PySparkSession +from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.dataframe.sql import types +from sqlglot.dataframe.sql import functions as F + +data = [ + (1, "Jack", "Shephard", 34), + (2, "John", "Locke", 48), + (3, "Kate", "Austen", 34), + (4, "Claire", "Littleton", 22), + (5, "Hugo", "Reyes", 26), +] +schema = types.StructType([ + types.StructField('employee_id', types.IntegerType(), False), + types.StructField('fname', types.StringType(), False), + types.StructField('lname', types.StringType(), False), + types.StructField('age', types.IntegerType(), False), +]) + +sql_statements = ( + SparkSession() + .createDataFrame(data, schema) + .groupBy(F.col("age")) + .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) + .sql(dialect="bigquery") +) + +pyspark = PySparkSession.builder.master("local[*]").getOrCreate() + +df = None +for sql in sql_statements: + df = pyspark.sql(sql) + +assert df is not None +df.show() +``` + +# Unsupportable Operations + +Any operation that lacks a way to represent it in SQL cannot be supported by this tool. An example of this would be rdd operations. Since the DataFrame API though is mostly modeled around SQL concepts most operations can be supported. diff --git a/sqlglot/dataframe/__init__.py b/sqlglot/dataframe/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/sqlglot/dataframe/__init__.py diff --git a/sqlglot/dataframe/sql/__init__.py b/sqlglot/dataframe/sql/__init__.py new file mode 100644 index 0000000..3f90802 --- /dev/null +++ b/sqlglot/dataframe/sql/__init__.py @@ -0,0 +1,18 @@ +from sqlglot.dataframe.sql.column import Column +from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions +from sqlglot.dataframe.sql.group import GroupedData +from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter +from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.dataframe.sql.window import Window, WindowSpec + +__all__ = [ + "SparkSession", + "DataFrame", + "GroupedData", + "Column", + "DataFrameNaFunctions", + "Window", + "WindowSpec", + "DataFrameReader", + "DataFrameWriter", +] diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi new file mode 100644 index 0000000..f1a03ea --- /dev/null +++ b/sqlglot/dataframe/sql/_typing.pyi @@ -0,0 +1,20 @@ +from __future__ import annotations + +import datetime +import typing as t + +from sqlglot import expressions as exp + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.column import Column + from sqlglot.dataframe.sql.types import StructType + +ColumnLiterals = t.TypeVar( + "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] +) +ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str]) +ColumnOrLiteral = t.TypeVar( + "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] +) +SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]) +OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]) diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py new file mode 100644 index 0000000..2391080 --- /dev/null +++ b/sqlglot/dataframe/sql/column.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import typing as t + +import sqlglot +from sqlglot import expressions as exp +from sqlglot.dataframe.sql.types import DataType +from sqlglot.helper import flatten + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql._typing import ColumnOrLiteral + from sqlglot.dataframe.sql.window import WindowSpec + + +class Column: + def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): + if isinstance(expression, Column): + expression = expression.expression # type: ignore + elif expression is None or not isinstance(expression, (str, exp.Expression)): + expression = self._lit(expression).expression # type: ignore + self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark") + + def __repr__(self): + return repr(self.expression) + + def __hash__(self): + return hash(self.expression) + + def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore + return self.binary_op(exp.EQ, other) + + def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore + return self.binary_op(exp.NEQ, other) + + def __gt__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.GT, other) + + def __ge__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.GTE, other) + + def __lt__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.LT, other) + + def __le__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.LTE, other) + + def __and__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.And, other) + + def __or__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Or, other) + + def __mod__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Mod, other) + + def __add__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Add, other) + + def __sub__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Sub, other) + + def __mul__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Mul, other) + + def __truediv__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Div, other) + + def __div__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Div, other) + + def __neg__(self) -> Column: + return self.unary_op(exp.Neg) + + def __radd__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Add, other) + + def __rsub__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Sub, other) + + def __rmul__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Mul, other) + + def __rdiv__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Div, other) + + def __rtruediv__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Div, other) + + def __rmod__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Mod, other) + + def __pow__(self, power: ColumnOrLiteral, modulo=None): + return Column(exp.Pow(this=self.expression, power=Column(power).expression)) + + def __rpow__(self, power: ColumnOrLiteral): + return Column(exp.Pow(this=Column(power).expression, power=self.expression)) + + def __invert__(self): + return self.unary_op(exp.Not) + + def __rand__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.And, other) + + def __ror__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Or, other) + + @classmethod + def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): + return cls(value) + + @classmethod + def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]: + return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args] + + @classmethod + def _lit(cls, value: ColumnOrLiteral) -> Column: + if isinstance(value, dict): + columns = [cls._lit(v).alias(k).expression for k, v in value.items()] + return cls(exp.Struct(expressions=columns)) + return cls(exp.convert(value)) + + @classmethod + def invoke_anonymous_function( + cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] + ) -> Column: + columns = [] if column is None else [cls.ensure_col(column)] + column_args = [cls.ensure_col(arg) for arg in args] + expressions = [x.expression for x in columns + column_args] + new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) + return Column(new_expression) + + @classmethod + def invoke_expression_over_column( + cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs + ) -> Column: + ensured_column = None if column is None else cls.ensure_col(column) + new_expression = ( + callable_expression(**kwargs) + if ensured_column is None + else callable_expression(this=ensured_column.column_expression, **kwargs) + ) + return Column(new_expression) + + def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: + return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)) + + def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: + return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)) + + def unary_op(self, klass: t.Callable, **kwargs) -> Column: + return Column(klass(this=self.column_expression, **kwargs)) + + @property + def is_alias(self): + return isinstance(self.expression, exp.Alias) + + @property + def is_column(self): + return isinstance(self.expression, exp.Column) + + @property + def column_expression(self) -> exp.Column: + return self.expression.unalias() + + @property + def alias_or_name(self) -> str: + return self.expression.alias_or_name + + @classmethod + def ensure_literal(cls, value) -> Column: + from sqlglot.dataframe.sql.functions import lit + + if isinstance(value, cls): + value = value.expression + if not isinstance(value, exp.Literal): + return lit(value) + return Column(value) + + def copy(self) -> Column: + return Column(self.expression.copy()) + + def set_table_name(self, table_name: str, copy=False) -> Column: + expression = self.expression.copy() if copy else self.expression + expression.set("table", exp.to_identifier(table_name)) + return Column(expression) + + def sql(self, **kwargs) -> Column: + return self.expression.sql(**{"dialect": "spark", **kwargs}) + + def alias(self, name: str) -> Column: + new_expression = exp.alias_(self.column_expression, name) + return Column(new_expression) + + def asc(self) -> Column: + new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True) + return Column(new_expression) + + def desc(self) -> Column: + new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False) + return Column(new_expression) + + asc_nulls_first = asc + + def asc_nulls_last(self) -> Column: + new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False) + return Column(new_expression) + + def desc_nulls_first(self) -> Column: + new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True) + return Column(new_expression) + + desc_nulls_last = desc + + def when(self, condition: Column, value: t.Any) -> Column: + from sqlglot.dataframe.sql.functions import when + + column_with_if = when(condition, value) + if not isinstance(self.expression, exp.Case): + return column_with_if + new_column = self.copy() + new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) + return new_column + + def otherwise(self, value: t.Any) -> Column: + from sqlglot.dataframe.sql.functions import lit + + true_value = value if isinstance(value, Column) else lit(value) + new_column = self.copy() + new_column.expression.set("default", true_value.column_expression) + return new_column + + def isNull(self) -> Column: + new_expression = exp.Is(this=self.column_expression, expression=exp.Null()) + return Column(new_expression) + + def isNotNull(self) -> Column: + new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) + return Column(new_expression) + + def cast(self, dataType: t.Union[str, DataType]): + """ + Functionality Difference: PySpark cast accepts a datatype instance of the datatype class + Sqlglot doesn't currently replicate this class so it only accepts a string + """ + if isinstance(dataType, DataType): + dataType = dataType.simpleString() + new_expression = exp.Cast(this=self.column_expression, to=dataType) + return Column(new_expression) + + def startswith(self, value: t.Union[str, Column]) -> Column: + value = self._lit(value) if not isinstance(value, Column) else value + return self.invoke_anonymous_function(self, "STARTSWITH", value) + + def endswith(self, value: t.Union[str, Column]) -> Column: + value = self._lit(value) if not isinstance(value, Column) else value + return self.invoke_anonymous_function(self, "ENDSWITH", value) + + def rlike(self, regexp: str) -> Column: + return self.invoke_expression_over_column( + column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression + ) + + def like(self, other: str): + return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression) + + def ilike(self, other: str): + return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression) + + def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: + startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos + length = self._lit(length) if not isinstance(length, Column) else length + return Column.invoke_expression_over_column( + self, exp.Substring, start=startPos.expression, length=length.expression + ) + + def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): + columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore + expressions = [self._lit(x).expression for x in columns] + return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore + + def between( + self, + lowerBound: t.Union[ColumnOrLiteral], + upperBound: t.Union[ColumnOrLiteral], + ) -> Column: + lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound + upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound + return Column( + exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression) + ) + + def over(self, window: WindowSpec) -> Column: + window_expression = window.expression.copy() + window_expression.set("this", self.column_expression) + return Column(window_expression) diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py new file mode 100644 index 0000000..322dcf2 --- /dev/null +++ b/sqlglot/dataframe/sql/dataframe.py @@ -0,0 +1,730 @@ +from __future__ import annotations + +import functools +import typing as t +import zlib +from copy import copy + +import sqlglot +from sqlglot import expressions as exp +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.column import Column +from sqlglot.dataframe.sql.group import GroupedData +from sqlglot.dataframe.sql.normalize import normalize +from sqlglot.dataframe.sql.operations import Operation, operation +from sqlglot.dataframe.sql.readwriter import DataFrameWriter +from sqlglot.dataframe.sql.transforms import replace_id_value +from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join +from sqlglot.dataframe.sql.window import Window +from sqlglot.helper import ensure_list, object_to_dict +from sqlglot.optimizer import optimize as optimize_func +from sqlglot.optimizer.qualify_columns import qualify_columns + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer + from sqlglot.dataframe.sql.session import SparkSession + + +JOIN_HINTS = { + "BROADCAST", + "BROADCASTJOIN", + "MAPJOIN", + "MERGE", + "SHUFFLEMERGE", + "MERGEJOIN", + "SHUFFLE_HASH", + "SHUFFLE_REPLICATE_NL", +} + + +class DataFrame: + def __init__( + self, + spark: SparkSession, + expression: exp.Select, + branch_id: t.Optional[str] = None, + sequence_id: t.Optional[str] = None, + last_op: Operation = Operation.INIT, + pending_hints: t.Optional[t.List[exp.Expression]] = None, + output_expression_container: t.Optional[OutputExpressionContainer] = None, + **kwargs, + ): + self.spark = spark + self.expression = expression + self.branch_id = branch_id or self.spark._random_branch_id + self.sequence_id = sequence_id or self.spark._random_sequence_id + self.last_op = last_op + self.pending_hints = pending_hints or [] + self.output_expression_container = output_expression_container or exp.Select() + + def __getattr__(self, column_name: str) -> Column: + return self[column_name] + + def __getitem__(self, column_name: str) -> Column: + column_name = f"{self.branch_id}.{column_name}" + return Column(column_name) + + def __copy__(self): + return self.copy() + + @property + def sparkSession(self): + return self.spark + + @property + def write(self): + return DataFrameWriter(self) + + @property + def latest_cte_name(self) -> str: + if not self.expression.ctes: + from_exp = self.expression.args["from"] + if from_exp.alias_or_name: + return from_exp.alias_or_name + table_alias = from_exp.find(exp.TableAlias) + if not table_alias: + raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}") + return table_alias.alias_or_name + return self.expression.ctes[-1].alias + + @property + def pending_join_hints(self): + return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)] + + @property + def pending_partition_hints(self): + return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)] + + @property + def columns(self) -> t.List[str]: + return self.expression.named_selects + + @property + def na(self) -> DataFrameNaFunctions: + return DataFrameNaFunctions(self) + + def _replace_cte_names_with_hashes(self, expression: exp.Select): + expression = expression.copy() + ctes = expression.ctes + replacement_mapping = {} + for cte in ctes: + old_name_id = cte.args["alias"].this + new_hashed_id = exp.to_identifier( + self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"] + ) + replacement_mapping[old_name_id] = new_hashed_id + cte.set("alias", exp.TableAlias(this=new_hashed_id)) + expression = expression.transform(replace_id_value, replacement_mapping) + return expression + + def _create_cte_from_expression( + self, + expression: exp.Expression, + branch_id: t.Optional[str] = None, + sequence_id: t.Optional[str] = None, + **kwargs, + ) -> t.Tuple[exp.CTE, str]: + name = self.spark._random_name + expression_to_cte = expression.copy() + expression_to_cte.set("with", None) + cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0] + cte.set("branch_id", branch_id or self.branch_id) + cte.set("sequence_id", sequence_id or self.sequence_id) + return cte, name + + def _ensure_list_of_columns( + self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]] + ) -> t.List[Column]: + columns = ensure_list(cols) + columns = Column.ensure_cols(columns) + return columns + + def _ensure_and_normalize_cols(self, cols): + cols = self._ensure_list_of_columns(cols) + normalize(self.spark, self.expression, cols) + return cols + + def _ensure_and_normalize_col(self, col): + col = Column.ensure_col(col) + normalize(self.spark, self.expression, col) + return col + + def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame: + df = self._resolve_pending_hints() + sequence_id = sequence_id or df.sequence_id + expression = df.expression.copy() + cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id) + new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression]) + sel_columns = df._get_outer_select_columns(cte_expression) + new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns]) + return df.copy(expression=new_expression, sequence_id=sequence_id) + + def _resolve_pending_hints(self) -> DataFrame: + df = self.copy() + if not self.pending_hints: + return df + expression = df.expression + hint_expression = expression.args.get("hint") or exp.Hint(expressions=[]) + for hint in df.pending_partition_hints: + hint_expression.args.get("expressions").append(hint) + df.pending_hints.remove(hint) + + join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)} + if join_aliases: + for hint in df.pending_join_hints: + for sequence_id_expression in hint.expressions: + sequence_id_or_name = sequence_id_expression.alias_or_name + sequence_ids_to_match = [sequence_id_or_name] + if sequence_id_or_name in df.spark.name_to_sequence_id_mapping: + sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name] + matching_ctes = [ + cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match + ] + for matching_cte in matching_ctes: + if matching_cte.alias_or_name in join_aliases: + sequence_id_expression.set("this", matching_cte.args["alias"].this) + df.pending_hints.remove(hint) + break + hint_expression.args.get("expressions").append(hint) + if hint_expression.expressions: + expression.set("hint", hint_expression) + return df + + def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame: + hint_name = hint_name.upper() + hint_expression = ( + exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args]) + if hint_name in JOIN_HINTS + else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args]) + ) + new_df = self.copy() + new_df.pending_hints.append(hint_expression) + return new_df + + def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool): + other_df = other._convert_leaf_to_cte() + base_expression = self.expression.copy() + base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes) + all_ctes = base_expression.ctes + other_df.expression.set("with", None) + base_expression.set("with", None) + operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression) + operation.set("with", exp.With(expressions=all_ctes)) + return self.copy(expression=operation)._convert_leaf_to_cte() + + def _cache(self, storage_level: str): + df = self._convert_leaf_to_cte() + df.expression.ctes[-1].set("cache_storage_level", storage_level) + return df + + @classmethod + def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select: + expression = expression.copy() + with_expression = expression.args.get("with") + if with_expression: + existing_ctes = with_expression.expressions + existsing_cte_names = {x.alias_or_name for x in existing_ctes} + for cte in ctes: + if cte.alias_or_name not in existsing_cte_names: + existing_ctes.append(cte) + else: + existing_ctes = ctes + expression.set("with", exp.With(expressions=existing_ctes)) + return expression + + @classmethod + def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]: + expression = item.expression if isinstance(item, DataFrame) else item + return [Column(x) for x in expression.find(exp.Select).expressions] + + @classmethod + def _create_hash_from_expression(cls, expression: exp.Select): + value = expression.sql(dialect="spark").encode("utf-8") + return f"t{zlib.crc32(value)}"[:6] + + def _get_select_expressions( + self, + ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]: + select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = [] + main_select_ctes: t.List[exp.CTE] = [] + for cte in self.expression.ctes: + cache_storage_level = cte.args.get("cache_storage_level") + if cache_storage_level: + select_expression = cte.this.copy() + select_expression.set("with", exp.With(expressions=copy(main_select_ctes))) + select_expression.set("cte_alias_name", cte.alias_or_name) + select_expression.set("cache_storage_level", cache_storage_level) + select_expressions.append((exp.Cache, select_expression)) + else: + main_select_ctes.append(cte) + main_select = self.expression.copy() + if main_select_ctes: + main_select.set("with", exp.With(expressions=main_select_ctes)) + expression_select_pair = (type(self.output_expression_container), main_select) + select_expressions.append(expression_select_pair) # type: ignore + return select_expressions + + def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: + df = self._resolve_pending_hints() + select_expressions = df._get_select_expressions() + output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] + replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} + for expression_type, select_expression in select_expressions: + select_expression = select_expression.transform(replace_id_value, replacement_mapping) + if optimize: + select_expression = optimize_func(select_expression) + select_expression = df._replace_cte_names_with_hashes(select_expression) + expression: t.Union[exp.Select, exp.Cache, exp.Drop] + if expression_type == exp.Cache: + cache_table_name = df._create_hash_from_expression(select_expression) + cache_table = exp.to_table(cache_table_name) + original_alias_name = select_expression.args["cte_alias_name"] + replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name) + sqlglot.schema.add_table(cache_table_name, select_expression.named_selects) + cache_storage_level = select_expression.args["cache_storage_level"] + options = [ + exp.Literal.string("storageLevel"), + exp.Literal.string(cache_storage_level), + ] + expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options) + # We will drop the "view" if it exists before running the cache table + output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) + elif expression_type == exp.Create: + expression = df.output_expression_container.copy() + expression.set("expression", select_expression) + elif expression_type == exp.Insert: + expression = df.output_expression_container.copy() + select_without_ctes = select_expression.copy() + select_without_ctes.set("with", None) + expression.set("expression", select_without_ctes) + if select_expression.ctes: + expression.set("with", exp.With(expressions=select_expression.ctes)) + elif expression_type == exp.Select: + expression = select_expression + else: + raise ValueError(f"Invalid expression type: {expression_type}") + output_expressions.append(expression) + + return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions] + + def copy(self, **kwargs) -> DataFrame: + return DataFrame(**object_to_dict(self, **kwargs)) + + @operation(Operation.SELECT) + def select(self, *cols, **kwargs) -> DataFrame: + cols = self._ensure_and_normalize_cols(cols) + kwargs["append"] = kwargs.get("append", False) + if self.expression.args.get("joins"): + ambiguous_cols = [col for col in cols if not col.column_expression.table] + if ambiguous_cols: + join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)] + cte_names_in_join = [x.this for x in join_table_identifiers] + for ambiguous_col in ambiguous_cols: + ctes_with_column = [ + cte + for cte in self.expression.ctes + if cte.alias_or_name in cte_names_in_join + and ambiguous_col.alias_or_name in cte.this.named_selects + ] + # If the select column does not specify a table and there is a join + # then we assume they are referring to the left table + if len(ctes_with_column) > 1: + table_identifier = self.expression.args["from"].args["expressions"][0].this + else: + table_identifier = ctes_with_column[0].args["alias"].this + ambiguous_col.expression.set("table", table_identifier) + expression = self.expression.select(*[x.expression for x in cols], **kwargs) + qualify_columns(expression, sqlglot.schema) + return self.copy(expression=expression, **kwargs) + + @operation(Operation.NO_OP) + def alias(self, name: str, **kwargs) -> DataFrame: + new_sequence_id = self.spark._random_sequence_id + df = self.copy() + for join_hint in df.pending_join_hints: + for expression in join_hint.expressions: + if expression.alias_or_name == self.sequence_id: + expression.set("this", Column.ensure_col(new_sequence_id).expression) + df.spark._add_alias_to_mapping(name, new_sequence_id) + return df._convert_leaf_to_cte(sequence_id=new_sequence_id) + + @operation(Operation.WHERE) + def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame: + col = self._ensure_and_normalize_col(column) + return self.copy(expression=self.expression.where(col.expression)) + + filter = where + + @operation(Operation.GROUP_BY) + def groupBy(self, *cols, **kwargs) -> GroupedData: + columns = self._ensure_and_normalize_cols(cols) + return GroupedData(self, columns, self.last_op) + + @operation(Operation.SELECT) + def agg(self, *exprs, **kwargs) -> DataFrame: + cols = self._ensure_and_normalize_cols(exprs) + return self.groupBy().agg(*cols) + + @operation(Operation.FROM) + def join( + self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs + ) -> DataFrame: + other_df = other_df._convert_leaf_to_cte() + pre_join_self_latest_cte_name = self.latest_cte_name + columns = self._ensure_and_normalize_cols(on) + join_type = how.replace("_", " ") + if isinstance(columns[0].expression, exp.Column): + join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns] + join_clause = functools.reduce( + lambda x, y: x & y, + [ + col.copy().set_table_name(pre_join_self_latest_cte_name) + == col.copy().set_table_name(other_df.latest_cte_name) + for col in columns + ], + ) + else: + if len(columns) > 1: + columns = [functools.reduce(lambda x, y: x & y, columns)] + join_clause = columns[0] + join_columns = [ + Column(x).set_table_name(pre_join_self_latest_cte_name) + if i % 2 == 0 + else Column(x).set_table_name(other_df.latest_cte_name) + for i, x in enumerate(join_clause.expression.find_all(exp.Column)) + ] + self_columns = [ + column.set_table_name(pre_join_self_latest_cte_name, copy=True) + for column in self._get_outer_select_columns(self) + ] + other_columns = [ + column.set_table_name(other_df.latest_cte_name, copy=True) + for column in self._get_outer_select_columns(other_df) + ] + column_value_mapping = { + column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column + for column in other_columns + self_columns + join_columns + } + all_columns = [ + column_value_mapping[name] + for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns} + ] + new_df = self.copy( + expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type) + ) + new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes) + new_df.pending_hints.extend(other_df.pending_hints) + new_df = new_df.select.__wrapped__(new_df, *all_columns) + return new_df + + @operation(Operation.ORDER_BY) + def orderBy( + self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None + ) -> DataFrame: + """ + This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark + has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this + is unlikely to come up. + """ + columns = self._ensure_and_normalize_cols(cols) + pre_ordered_col_indexes = [ + x + for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)] + if x is not None + ] + if ascending is None: + ascending = [True] * len(columns) + elif not isinstance(ascending, list): + ascending = [ascending] * len(columns) + ascending = [bool(x) for i, x in enumerate(ascending)] + assert len(columns) == len( + ascending + ), "The length of items in ascending must equal the number of columns provided" + col_and_ascending = list(zip(columns, ascending)) + order_by_columns = [ + exp.Ordered(this=col.expression, desc=not asc) + if i not in pre_ordered_col_indexes + else columns[i].column_expression + for i, (col, asc) in enumerate(col_and_ascending) + ] + return self.copy(expression=self.expression.order_by(*order_by_columns)) + + sort = orderBy + + @operation(Operation.FROM) + def union(self, other: DataFrame) -> DataFrame: + return self._set_operation(exp.Union, other, False) + + unionAll = union + + @operation(Operation.FROM) + def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): + l_columns = self.columns + r_columns = other.columns + if not allowMissingColumns: + l_expressions = l_columns + r_expressions = l_columns + else: + l_expressions = [] + r_expressions = [] + r_columns_unused = copy(r_columns) + for l_column in l_columns: + l_expressions.append(l_column) + if l_column in r_columns: + r_expressions.append(l_column) + r_columns_unused.remove(l_column) + else: + r_expressions.append(exp.alias_(exp.Null(), l_column)) + for r_column in r_columns_unused: + l_expressions.append(exp.alias_(exp.Null(), r_column)) + r_expressions.append(r_column) + r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) + l_df = self.copy() + if allowMissingColumns: + l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) + return l_df._set_operation(exp.Union, r_df, False) + + @operation(Operation.FROM) + def intersect(self, other: DataFrame) -> DataFrame: + return self._set_operation(exp.Intersect, other, True) + + @operation(Operation.FROM) + def intersectAll(self, other: DataFrame) -> DataFrame: + return self._set_operation(exp.Intersect, other, False) + + @operation(Operation.FROM) + def exceptAll(self, other: DataFrame) -> DataFrame: + return self._set_operation(exp.Except, other, False) + + @operation(Operation.SELECT) + def distinct(self) -> DataFrame: + return self.copy(expression=self.expression.distinct()) + + @operation(Operation.SELECT) + def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): + if not subset: + return self.distinct() + column_names = ensure_list(subset) + window = Window.partitionBy(*column_names).orderBy(*column_names) + return ( + self.copy() + .withColumn("row_num", F.row_number().over(window)) + .where(F.col("row_num") == F.lit(1)) + .drop("row_num") + ) + + @operation(Operation.FROM) + def dropna( + self, + how: str = "any", + thresh: t.Optional[int] = None, + subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, + ) -> DataFrame: + minimum_non_null = thresh or 0 # will be determined later if thresh is null + new_df = self.copy() + all_columns = self._get_outer_select_columns(new_df.expression) + if subset: + null_check_columns = self._ensure_and_normalize_cols(subset) + else: + null_check_columns = all_columns + if thresh is None: + minimum_num_nulls = 1 if how == "any" else len(null_check_columns) + else: + minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 + if minimum_num_nulls > len(null_check_columns): + raise RuntimeError( + f"The minimum num nulls for dropna must be less than or equal to the number of columns. " + f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" + ) + if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns] + nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) + num_nulls = nulls_added_together.alias("num_nulls") + new_df = new_df.select(num_nulls, append=True) + filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) + final_df = filtered_df.select(*all_columns) + return final_df + + @operation(Operation.FROM) + def fillna( + self, + value: t.Union[ColumnLiterals], + subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, + ) -> DataFrame: + """ + Functionality Difference: If you provide a value to replace a null and that type conflicts + with the type of the column then PySpark will just ignore your replacement. + This will try to cast them to be the same in some cases. So they won't always match. + Best to not mix types so make sure replacement is the same type as the column + + Possibility for improvement: Use `typeof` function to get the type of the column + and check if it matches the type of the value provided. If not then make it null. + """ + from sqlglot.dataframe.sql.functions import lit + + values = None + columns = None + new_df = self.copy() + all_columns = self._get_outer_select_columns(new_df.expression) + all_column_mapping = {column.alias_or_name: column for column in all_columns} + if isinstance(value, dict): + values = value.values() + columns = self._ensure_and_normalize_cols(list(value)) + if not columns: + columns = self._ensure_and_normalize_cols(subset) if subset else all_columns + if not values: + values = [value] * len(columns) + value_columns = [lit(value) for value in values] + + null_replacement_mapping = { + column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)) + for column, value in zip(columns, value_columns) + } + null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} + null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns] + new_df = new_df.select(*null_replacement_columns) + return new_df + + @operation(Operation.FROM) + def replace( + self, + to_replace: t.Union[bool, int, float, str, t.List, t.Dict], + value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, + subset: t.Optional[t.Union[str, t.List[str]]] = None, + ) -> DataFrame: + from sqlglot.dataframe.sql.functions import lit + + old_values = None + subset = ensure_list(subset) + new_df = self.copy() + all_columns = self._get_outer_select_columns(new_df.expression) + all_column_mapping = {column.alias_or_name: column for column in all_columns} + + columns = self._ensure_and_normalize_cols(subset) if subset else all_columns + if isinstance(to_replace, dict): + old_values = list(to_replace) + new_values = list(to_replace.values()) + elif not old_values and isinstance(to_replace, list): + assert isinstance(value, list), "value must be a list since the replacements are a list" + assert len(to_replace) == len(value), "the replacements and values must be the same length" + old_values = to_replace + new_values = value + else: + old_values = [to_replace] * len(columns) + new_values = [value] * len(columns) + old_values = [lit(value) for value in old_values] + new_values = [lit(value) for value in new_values] + + replacement_mapping = {} + for column in columns: + expression = Column(None) + for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): + if i == 0: + expression = F.when(column == old_value, new_value) + else: + expression = expression.when(column == old_value, new_value) # type: ignore + replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( + column.expression.alias_or_name + ) + + replacement_mapping = {**all_column_mapping, **replacement_mapping} + replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] + new_df = new_df.select(*replacement_columns) + return new_df + + @operation(Operation.SELECT) + def withColumn(self, colName: str, col: Column) -> DataFrame: + col = self._ensure_and_normalize_col(col) + existing_col_names = self.expression.named_selects + existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None + if existing_col_index: + expression = self.expression.copy() + expression.expressions[existing_col_index] = col.expression + return self.copy(expression=expression) + return self.copy().select(col.alias(colName), append=True) + + @operation(Operation.SELECT) + def withColumnRenamed(self, existing: str, new: str): + expression = self.expression.copy() + existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing] + if not existing_columns: + raise ValueError("Tried to rename a column that doesn't exist") + for existing_column in existing_columns: + if isinstance(existing_column, exp.Column): + existing_column.replace(exp.alias_(existing_column.copy(), new)) + else: + existing_column.set("alias", exp.to_identifier(new)) + return self.copy(expression=expression) + + @operation(Operation.SELECT) + def drop(self, *cols: t.Union[str, Column]) -> DataFrame: + all_columns = self._get_outer_select_columns(self.expression) + drop_cols = self._ensure_and_normalize_cols(cols) + new_columns = [ + col + for col in all_columns + if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] + ] + return self.copy().select(*new_columns, append=False) + + @operation(Operation.LIMIT) + def limit(self, num: int) -> DataFrame: + return self.copy(expression=self.expression.limit(num)) + + @operation(Operation.NO_OP) + def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: + parameter_list = ensure_list(parameters) + parameter_columns = ( + self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id]) + ) + return self._hint(name, parameter_columns) + + @operation(Operation.NO_OP) + def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame: + num_partitions = Column.ensure_cols(ensure_list(numPartitions)) + columns = self._ensure_and_normalize_cols(cols) + args = num_partitions + columns + return self._hint("repartition", args) + + @operation(Operation.NO_OP) + def coalesce(self, numPartitions: int) -> DataFrame: + num_partitions = Column.ensure_cols([numPartitions]) + return self._hint("coalesce", num_partitions) + + @operation(Operation.NO_OP) + def cache(self) -> DataFrame: + return self._cache(storage_level="MEMORY_AND_DISK") + + @operation(Operation.NO_OP) + def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame: + """ + Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html + """ + return self._cache(storageLevel) + + +class DataFrameNaFunctions: + def __init__(self, df: DataFrame): + self.df = df + + def drop( + self, + how: str = "any", + thresh: t.Optional[int] = None, + subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, + ) -> DataFrame: + return self.df.dropna(how=how, thresh=thresh, subset=subset) + + def fill( + self, + value: t.Union[int, bool, float, str, t.Dict[str, t.Any]], + subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, + ) -> DataFrame: + return self.df.fillna(value=value, subset=subset) + + def replace( + self, + to_replace: t.Union[bool, int, float, str, t.List, t.Dict], + value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, + subset: t.Optional[t.Union[str, t.List[str]]] = None, + ) -> DataFrame: + return self.df.replace(to_replace=to_replace, value=value, subset=subset) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py new file mode 100644 index 0000000..4c6de30 --- /dev/null +++ b/sqlglot/dataframe/sql/functions.py @@ -0,0 +1,1258 @@ +from __future__ import annotations + +import typing as t +from inspect import signature + +from sqlglot import expressions as glotexp +from sqlglot.dataframe.sql.column import Column +from sqlglot.helper import ensure_list +from sqlglot.helper import flatten as _flatten + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName + from sqlglot.dataframe.sql.dataframe import DataFrame + + +def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column: + return Column(column_name) + + +def lit(value: t.Optional[t.Any] = None) -> Column: + if isinstance(value, str): + return Column(glotexp.Literal.string(str(value))) + return Column(value) + + +def greatest(*cols: ColumnOrName) -> Column: + columns = [Column.ensure_col(col) for col in cols] + return Column.invoke_expression_over_column( + columns[0], glotexp.Greatest, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None + ) + + +def least(*cols: ColumnOrName) -> Column: + columns = [Column.ensure_col(col) for col in cols] + return Column.invoke_expression_over_column( + columns[0], glotexp.Least, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None + ) + + +def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: + columns = [Column.ensure_col(x) for x in [col] + list(cols)] + return Column(glotexp.Count(this=glotexp.Distinct(expressions=[x.expression for x in columns]))) + + +def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: + return count_distinct(col, *cols) + + +def when(condition: Column, value: t.Any) -> Column: + true_value = value if isinstance(value, Column) else lit(value) + return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)])) + + +def asc(col: ColumnOrName) -> Column: + return Column.ensure_col(col).asc() + + +def desc(col: ColumnOrName): + return Column.ensure_col(col).desc() + + +def broadcast(df: DataFrame) -> DataFrame: + return df.hint("broadcast") + + +def sqrt(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Sqrt) + + +def abs(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Abs) + + +def max(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Max) + + +def min(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Min) + + +def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MAX_BY", ord) + + +def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MIN_BY", ord) + + +def count(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Count) + + +def sum(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Sum) + + +def avg(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Avg) + + +def mean(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MEAN") + + +def sumDistinct(col: ColumnOrName) -> Column: + return sum_distinct(col) + + +def sum_distinct(col: ColumnOrName) -> Column: + raise NotImplementedError("Sum distinct is not currently implemented") + + +def product(col: ColumnOrName) -> Column: + raise NotImplementedError("Product is not currently implemented") + + +def acos(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ACOS") + + +def acosh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ACOSH") + + +def asin(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ASIN") + + +def asinh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ASINH") + + +def atan(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ATAN") + + +def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: + return Column.invoke_anonymous_function(col1, "ATAN2", col2) + + +def atanh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ATANH") + + +def cbrt(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "CBRT") + + +def ceil(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Ceil) + + +def cos(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "COS") + + +def cosh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "COSH") + + +def cot(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "COT") + + +def csc(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "CSC") + + +def exp(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Exp) + + +def expm1(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "EXPM1") + + +def floor(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Floor) + + +def log10(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Log10) + + +def log1p(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "LOG1P") + + +def log2(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Log2) + + +def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column: + if arg2 is None: + return Column.invoke_expression_over_column(arg1, glotexp.Ln) + return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=Column.ensure_col(arg2).expression) + + +def rint(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "RINT") + + +def sec(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SEC") + + +def signum(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SIGNUM") + + +def sin(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SIN") + + +def sinh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SINH") + + +def tan(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "TAN") + + +def tanh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "TANH") + + +def toDegrees(col: ColumnOrName) -> Column: + return degrees(col) + + +def degrees(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "DEGREES") + + +def toRadians(col: ColumnOrName) -> Column: + return radians(col) + + +def radians(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "RADIANS") + + +def bitwiseNOT(col: ColumnOrName) -> Column: + return bitwise_not(col) + + +def bitwise_not(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.BitwiseNot) + + +def asc_nulls_first(col: ColumnOrName) -> Column: + return Column.ensure_col(col).asc_nulls_first() + + +def asc_nulls_last(col: ColumnOrName) -> Column: + return Column.ensure_col(col).asc_nulls_last() + + +def desc_nulls_first(col: ColumnOrName) -> Column: + return Column.ensure_col(col).desc_nulls_first() + + +def desc_nulls_last(col: ColumnOrName) -> Column: + return Column.ensure_col(col).desc_nulls_last() + + +def stddev(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Stddev) + + +def stddev_samp(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.StddevSamp) + + +def stddev_pop(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.StddevPop) + + +def variance(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Variance) + + +def var_samp(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Variance) + + +def var_pop(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.VariancePop) + + +def skewness(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SKEWNESS") + + +def kurtosis(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "KURTOSIS") + + +def collect_list(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.ArrayAgg) + + +def collect_set(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.SetAgg) + + +def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: + return Column.invoke_anonymous_function(col1, "HYPOT", col2) + + +def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: + return Column.invoke_anonymous_function(col1, "POW", col2) + + +def row_number() -> Column: + return Column(glotexp.Anonymous(this="ROW_NUMBER")) + + +def dense_rank() -> Column: + return Column(glotexp.Anonymous(this="DENSE_RANK")) + + +def rank() -> Column: + return Column(glotexp.Anonymous(this="RANK")) + + +def cume_dist() -> Column: + return Column(glotexp.Anonymous(this="CUME_DIST")) + + +def percent_rank() -> Column: + return Column(glotexp.Anonymous(this="PERCENT_RANK")) + + +def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: + return approx_count_distinct(col, rsd) + + +def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: + if rsd is None: + return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct) + return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=Column.ensure_col(rsd).expression) + + +def coalesce(*cols: ColumnOrName) -> Column: + columns = [Column.ensure_col(col) for col in cols] + return Column.invoke_expression_over_column( + columns[0], glotexp.Coalesce, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None + ) + + +def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "CORR", col2) + + +def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "COVAR_POP", col2) + + +def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2) + + +def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: + if ignorenulls is not None: + return Column.invoke_anonymous_function(col, "FIRST", ignorenulls) + return Column.invoke_anonymous_function(col, "FIRST") + + +def grouping_id(*cols: ColumnOrName) -> Column: + if not cols: + return Column.invoke_anonymous_function(None, "GROUPING_ID") + if len(cols) == 1: + return Column.invoke_anonymous_function(cols[0], "GROUPING_ID") + return Column.invoke_anonymous_function(cols[0], "GROUPING_ID", *cols[1:]) + + +def input_file_name() -> Column: + return Column.invoke_anonymous_function(None, "INPUT_FILE_NAME") + + +def isnan(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ISNAN") + + +def isnull(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ISNULL") + + +def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: + if ignorenulls is not None: + return Column.invoke_anonymous_function(col, "LAST", ignorenulls) + return Column.invoke_anonymous_function(col, "LAST") + + +def monotonically_increasing_id() -> Column: + return Column.invoke_anonymous_function(None, "MONOTONICALLY_INCREASING_ID") + + +def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "NANVL", col2) + + +def percentile_approx( + col: ColumnOrName, + percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]], + accuracy: t.Optional[t.Union[ColumnOrLiteral]] = None, +) -> Column: + if accuracy: + return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage, accuracy) + return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage) + + +def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: + return Column.invoke_anonymous_function(seed, "RAND") + + +def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column: + return Column.invoke_anonymous_function(seed, "RANDN") + + +def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: + if scale is not None: + return Column.invoke_expression_over_column(col, glotexp.Round, decimals=glotexp.convert(scale)) + return Column.invoke_expression_over_column(col, glotexp.Round) + + +def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: + if scale is not None: + return Column.invoke_anonymous_function(col, "BROUND", scale) + return Column.invoke_anonymous_function(col, "BROUND") + + +def shiftleft(col: ColumnOrName, numBits: int) -> Column: + return Column.invoke_expression_over_column( + col, glotexp.BitwiseLeftShift, expression=Column.ensure_col(numBits).expression + ) + + +def shiftLeft(col: ColumnOrName, numBits: int) -> Column: + return shiftleft(col, numBits) + + +def shiftright(col: ColumnOrName, numBits: int) -> Column: + return Column.invoke_expression_over_column( + col, glotexp.BitwiseRightShift, expression=Column.ensure_col(numBits).expression + ) + + +def shiftRight(col: ColumnOrName, numBits: int) -> Column: + return shiftright(col, numBits) + + +def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column: + return Column.invoke_anonymous_function(col, "SHIFTRIGHTUNSIGNED", numBits) + + +def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: + return shiftrightunsigned(col, numBits) + + +def expr(str: str) -> Column: + return Column(str) + + +def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column: + columns = ensure_list(col) + list(cols) + expressions = [Column.ensure_col(column).expression for column in columns] + return Column(glotexp.Struct(expressions=expressions)) + + +def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column: + return Column.invoke_anonymous_function(col, "CONV", fromBase, toBase) + + +def factorial(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "FACTORIAL") + + +def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column: + if default is not None: + return Column.invoke_anonymous_function(col, "LAG", offset, default) + if offset != 1: + return Column.invoke_anonymous_function(col, "LAG", offset) + return Column.invoke_anonymous_function(col, "LAG") + + +def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column: + if default is not None: + return Column.invoke_anonymous_function(col, "LEAD", offset, default) + if offset != 1: + return Column.invoke_anonymous_function(col, "LEAD", offset) + return Column.invoke_anonymous_function(col, "LEAD") + + +def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column: + if ignoreNulls is not None: + raise NotImplementedError("There is currently not support for `ignoreNulls` parameter") + if offset != 1: + return Column.invoke_anonymous_function(col, "NTH_VALUE", offset) + return Column.invoke_anonymous_function(col, "NTH_VALUE") + + +def ntile(n: int) -> Column: + return Column.invoke_anonymous_function(None, "NTILE", n) + + +def current_date() -> Column: + return Column.invoke_expression_over_column(None, glotexp.CurrentDate) + + +def current_timestamp() -> Column: + return Column.invoke_expression_over_column(None, glotexp.CurrentTimestamp) + + +def date_format(col: ColumnOrName, format: str) -> Column: + return Column.invoke_anonymous_function(col, "DATE_FORMAT", lit(format)) + + +def year(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Year) + + +def quarter(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "QUARTER") + + +def month(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Month) + + +def dayofweek(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "DAYOFWEEK") + + +def dayofmonth(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "DAYOFMONTH") + + +def dayofyear(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "DAYOFYEAR") + + +def hour(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "HOUR") + + +def minute(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MINUTE") + + +def second(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SECOND") + + +def weekofyear(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "WEEKOFYEAR") + + +def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(year, "MAKE_DATE", month, day) + + +def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: + return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=Column.ensure_col(days).expression) + + +def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: + return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=Column.ensure_col(days).expression) + + +def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=Column.ensure_col(start).expression) + + +def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column: + return Column.invoke_anonymous_function(start, "ADD_MONTHS", months) + + +def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column: + if roundOff is None: + return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2) + return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff) + + +def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: + if format is not None: + return Column.invoke_anonymous_function(col, "TO_DATE", lit(format)) + return Column.invoke_anonymous_function(col, "TO_DATE") + + +def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: + if format is not None: + return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format)) + return Column.invoke_anonymous_function(col, "TO_TIMESTAMP") + + +def trunc(col: ColumnOrName, format: str) -> Column: + return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format).expression) + + +def date_trunc(format: str, timestamp: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format).expression) + + +def next_day(col: ColumnOrName, dayOfWeek: str) -> Column: + return Column.invoke_anonymous_function(col, "NEXT_DAY", lit(dayOfWeek)) + + +def last_day(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "LAST_DAY") + + +def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: + if format is not None: + return Column.invoke_anonymous_function(col, "FROM_UNIXTIME", lit(format)) + return Column.invoke_anonymous_function(col, "FROM_UNIXTIME") + + +def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column: + if format is not None: + return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP", lit(format)) + return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP") + + +def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: + tz_column = tz if isinstance(tz, Column) else lit(tz) + return Column.invoke_anonymous_function(timestamp, "FROM_UTC_TIMESTAMP", tz_column) + + +def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: + tz_column = tz if isinstance(tz, Column) else lit(tz) + return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column) + + +def timestamp_seconds(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "TIMESTAMP_SECONDS") + + +def window( + timeColumn: ColumnOrName, + windowDuration: str, + slideDuration: t.Optional[str] = None, + startTime: t.Optional[str] = None, +) -> Column: + if slideDuration is not None and startTime is not None: + return Column.invoke_anonymous_function( + timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime) + ) + if slideDuration is not None: + return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)) + if startTime is not None: + return Column.invoke_anonymous_function( + timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime) + ) + return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration)) + + +def session_window(timeColumn: ColumnOrName, gapDuration: ColumnOrName) -> Column: + gap_duration_column = gapDuration if isinstance(gapDuration, Column) else lit(gapDuration) + return Column.invoke_anonymous_function(timeColumn, "SESSION_WINDOW", gap_duration_column) + + +def crc32(col: ColumnOrName) -> Column: + column = col if isinstance(col, Column) else lit(col) + return Column.invoke_anonymous_function(column, "CRC32") + + +def md5(col: ColumnOrName) -> Column: + column = col if isinstance(col, Column) else lit(col) + return Column.invoke_anonymous_function(column, "MD5") + + +def sha1(col: ColumnOrName) -> Column: + column = col if isinstance(col, Column) else lit(col) + return Column.invoke_anonymous_function(column, "SHA1") + + +def sha2(col: ColumnOrName, numBits: int) -> Column: + column = col if isinstance(col, Column) else lit(col) + num_bits = lit(numBits) + return Column.invoke_anonymous_function(column, "SHA2", num_bits) + + +def hash(*cols: ColumnOrName) -> Column: + args = cols[1:] if len(cols) > 1 else [] + return Column.invoke_anonymous_function(cols[0], "HASH", *args) + + +def xxhash64(*cols: ColumnOrName) -> Column: + args = cols[1:] if len(cols) > 1 else [] + return Column.invoke_anonymous_function(cols[0], "XXHASH64", *args) + + +def assert_true(col: ColumnOrName, errorMsg: t.Optional[ColumnOrName] = None) -> Column: + if errorMsg is not None: + error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg) + return Column.invoke_anonymous_function(col, "ASSERT_TRUE", error_msg_col) + return Column.invoke_anonymous_function(col, "ASSERT_TRUE") + + +def raise_error(errorMsg: ColumnOrName) -> Column: + error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg) + return Column.invoke_anonymous_function(error_msg_col, "RAISE_ERROR") + + +def upper(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Upper) + + +def lower(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Lower) + + +def ascii(col: ColumnOrLiteral) -> Column: + return Column.invoke_anonymous_function(col, "ASCII") + + +def base64(col: ColumnOrLiteral) -> Column: + return Column.invoke_anonymous_function(col, "BASE64") + + +def unbase64(col: ColumnOrLiteral) -> Column: + return Column.invoke_anonymous_function(col, "UNBASE64") + + +def ltrim(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "LTRIM") + + +def rtrim(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "RTRIM") + + +def trim(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Trim) + + +def concat_ws(sep: str, *cols: ColumnOrName) -> Column: + columns = [Column(col) for col in cols] + return Column.invoke_expression_over_column( + None, glotexp.ConcatWs, expressions=[x.expression for x in [lit(sep)] + list(columns)] + ) + + +def decode(col: ColumnOrName, charset: str) -> Column: + return Column.invoke_anonymous_function(col, "DECODE", lit(charset)) + + +def encode(col: ColumnOrName, charset: str) -> Column: + return Column.invoke_anonymous_function(col, "ENCODE", lit(charset)) + + +def format_number(col: ColumnOrName, d: int) -> Column: + return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", lit(d)) + + +def format_string(format: str, *cols: ColumnOrName) -> Column: + format_col = lit(format) + columns = [Column.ensure_col(x) for x in cols] + return Column.invoke_anonymous_function(format_col, "FORMAT_STRING", *columns) + + +def instr(col: ColumnOrName, substr: str) -> Column: + return Column.invoke_anonymous_function(col, "INSTR", lit(substr)) + + +def overlay( + src: ColumnOrName, + replace: ColumnOrName, + pos: t.Union[ColumnOrName, int], + len: t.Optional[t.Union[ColumnOrName, int]] = None, +) -> Column: + if len is not None: + return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos, len) + return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos) + + +def sentences( + string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None +) -> Column: + if language is not None and country is not None: + return Column.invoke_anonymous_function(string, "SENTENCES", language, country) + if language is not None: + return Column.invoke_anonymous_function(string, "SENTENCES", language) + if country is not None: + return Column.invoke_anonymous_function(string, "SENTENCES", lit("en"), country) + return Column.invoke_anonymous_function(string, "SENTENCES") + + +def substring(str: ColumnOrName, pos: int, len: int) -> Column: + return Column.ensure_col(str).substr(pos, len) + + +def substring_index(str: ColumnOrName, delim: str, count: int) -> Column: + return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", lit(delim), lit(count)) + + +def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: + return Column.invoke_expression_over_column( + left, glotexp.Levenshtein, expression=Column.ensure_col(right).expression + ) + + +def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: + substr_col = lit(substr) + pos_column = lit(pos) + str_column = Column.ensure_col(str) + if pos is not None: + return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column, pos_column) + return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column) + + +def lpad(col: ColumnOrName, len: int, pad: str) -> Column: + return Column.invoke_anonymous_function(col, "LPAD", lit(len), lit(pad)) + + +def rpad(col: ColumnOrName, len: int, pad: str) -> Column: + return Column.invoke_anonymous_function(col, "RPAD", lit(len), lit(pad)) + + +def repeat(col: ColumnOrName, n: int) -> Column: + return Column.invoke_anonymous_function(col, "REPEAT", n) + + +def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column: + if limit is not None: + return Column.invoke_expression_over_column( + str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=lit(limit).expression + ) + return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern).expression) + + +def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column: + if idx is not None: + return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern), idx) + return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern)) + + +def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: + return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement)) + + +def initcap(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Initcap) + + +def soundex(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SOUNDEX") + + +def bin(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "BIN") + + +def hex(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "HEX") + + +def unhex(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "UNHEX") + + +def length(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Length) + + +def octet_length(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "OCTET_LENGTH") + + +def bit_length(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "BIT_LENGTH") + + +def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: + return Column.invoke_anonymous_function(srcCol, "TRANSLATE", lit(matching), lit(replace)) + + +def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: + cols = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols # type: ignore + cols = [Column.ensure_col(col).expression for col in cols] # type: ignore + return Column.invoke_expression_over_column(None, glotexp.Array, expressions=cols) + + +def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: + cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore + return Column.invoke_expression_over_column( + None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression + ) + + +def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "MAP_FROM_ARRAYS", col2) + + +def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + value_col = value if isinstance(value, Column) else lit(value) + return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression) + + +def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2)) + + +def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column: + start_col = start if isinstance(start, Column) else lit(start) + length_col = length if isinstance(length, Column) else lit(length) + return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col) + + +def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column: + if null_replacement is not None: + return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)) + return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter)) + + +def concat(*cols: ColumnOrName) -> Column: + if len(cols) == 1: + return Column.invoke_anonymous_function(cols[0], "CONCAT") + return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]) + + +def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + value_col = value if isinstance(value, Column) else lit(value) + return Column.invoke_anonymous_function(col, "ARRAY_POSITION", value_col) + + +def element_at(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + value_col = value if isinstance(value, Column) else lit(value) + return Column.invoke_anonymous_function(col, "ELEMENT_AT", value_col) + + +def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + value_col = value if isinstance(value, Column) else lit(value) + return Column.invoke_anonymous_function(col, "ARRAY_REMOVE", value_col) + + +def array_distinct(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ARRAY_DISTINCT") + + +def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", Column.ensure_col(col2)) + + +def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "ARRAY_UNION", Column.ensure_col(col2)) + + +def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "ARRAY_EXCEPT", Column.ensure_col(col2)) + + +def explode(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Explode) + + +def posexplode(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Posexplode) + + +def explode_outer(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "EXPLODE_OUTER") + + +def posexplode_outer(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "POSEXPLODE_OUTER") + + +def get_json_object(col: ColumnOrName, path: str) -> Column: + return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path).expression) + + +def json_tuple(col: ColumnOrName, *fields: str) -> Column: + return Column.invoke_anonymous_function(col, "JSON_TUPLE", *[lit(field) for field in fields]) + + +def from_json( + col: ColumnOrName, + schema: t.Union[Column, str], + options: t.Optional[t.Dict[str, str]] = None, +) -> Column: + schema = schema if isinstance(schema, Column) else lit(schema) + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "FROM_JSON", schema, options_col) + return Column.invoke_anonymous_function(col, "FROM_JSON", schema) + + +def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "TO_JSON", options_col) + return Column.invoke_anonymous_function(col, "TO_JSON") + + +def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON", options_col) + return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON") + + +def schema_of_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV", options_col) + return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV") + + +def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "TO_CSV", options_col) + return Column.invoke_anonymous_function(col, "TO_CSV") + + +def size(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.ArraySize) + + +def array_min(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ARRAY_MIN") + + +def array_max(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ARRAY_MAX") + + +def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column: + if asc is not None: + return Column.invoke_anonymous_function(col, "SORT_ARRAY", lit(asc)) + return Column.invoke_anonymous_function(col, "SORT_ARRAY") + + +def array_sort(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.ArraySort) + + +def shuffle(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SHUFFLE") + + +def reverse(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "REVERSE") + + +def flatten(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "FLATTEN") + + +def map_keys(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MAP_KEYS") + + +def map_values(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MAP_VALUES") + + +def map_entries(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MAP_ENTRIES") + + +def map_from_entries(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MAP_FROM_ENTRIES") + + +def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column: + count_col = count if isinstance(count, Column) else lit(count) + return Column.invoke_anonymous_function(col, "ARRAY_REPEAT", count_col) + + +def array_zip(*cols: ColumnOrName) -> Column: + if len(cols) == 1: + return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP") + return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP", *cols[1:]) + + +def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: + columns = list(flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore + if len(columns) == 1: + return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT") + return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:]) + + +def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column: + if step is not None: + return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step) + return Column.invoke_anonymous_function(start, "SEQUENCE", stop) + + +def from_csv( + col: ColumnOrName, + schema: t.Union[Column, str], + options: t.Optional[t.Dict[str, str]] = None, +) -> Column: + schema = schema if isinstance(schema, Column) else lit(schema) + if options is not None: + option_cols = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "FROM_CSV", schema, option_cols) + return Column.invoke_anonymous_function(col, "FROM_CSV", schema) + + +def aggregate( + col: ColumnOrName, + initialValue: ColumnOrName, + merge: t.Callable[[Column, Column], Column], + finish: t.Optional[t.Callable[[Column], Column]] = None, + accumulator_name: str = "acc", + target_row_name: str = "x", +) -> Column: + merge_exp = glotexp.Lambda( + this=merge(Column(accumulator_name), Column(target_row_name)).expression, + expressions=[ + glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name)), + glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name)), + ], + ) + if finish is not None: + finish_exp = glotexp.Lambda( + this=finish(Column(accumulator_name)).expression, + expressions=[glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name))], + ) + return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)) + return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp)) + + +def transform( + col: ColumnOrName, + f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], + target_row_name: str = "x", + row_count_name: str = "i", +) -> Column: + num_arguments = len(signature(f).parameters) + expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))] + columns = [Column(target_row_name)] + if num_arguments > 1: + columns.append(Column(row_count_name)) + expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name))) + + f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions) + return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression)) + + +def exists(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column: + f_expression = glotexp.Lambda( + this=f(Column(target_row_name)).expression, + expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))], + ) + return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression)) + + +def forall(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column: + f_expression = glotexp.Lambda( + this=f(Column(target_row_name)).expression, + expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))], + ) + + return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression)) + + +def filter( + col: ColumnOrName, + f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], + target_row_name: str = "x", + row_count_name: str = "i", +) -> Column: + num_arguments = len(signature(f).parameters) + expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))] + columns = [Column(target_row_name)] + if num_arguments > 1: + columns.append(Column(row_count_name)) + expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name))) + + f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions) + return Column.invoke_anonymous_function(col, "FILTER", Column(f_expression)) + + +def zip_with( + left: ColumnOrName, + right: ColumnOrName, + f: t.Callable[[Column, Column], Column], + left_name: str = "x", + right_name: str = "y", +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(left_name), Column(right_name)).expression, + expressions=[ + glotexp.to_identifier(left_name, quoted=_lambda_quoted(left_name)), + glotexp.to_identifier(right_name, quoted=_lambda_quoted(right_name)), + ], + ) + + return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression)) + + +def transform_keys( + col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(key_name), Column(value_name)).expression, + expressions=[ + glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), + glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)), + ], + ) + return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression)) + + +def transform_values( + col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(key_name), Column(value_name)).expression, + expressions=[ + glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), + glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)), + ], + ) + return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression)) + + +def map_filter( + col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(key_name), Column(value_name)).expression, + expressions=[ + glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), + glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)), + ], + ) + return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression)) + + +def map_zip_with( + col1: ColumnOrName, + col2: ColumnOrName, + f: t.Union[t.Callable[[Column, Column, Column], Column]], + key_name: str = "k", + value1: str = "v1", + value2: str = "v2", +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(key_name), Column(value1), Column(value2)).expression, + expressions=[ + glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), + glotexp.to_identifier(value1, quoted=_lambda_quoted(value1)), + glotexp.to_identifier(value2, quoted=_lambda_quoted(value2)), + ], + ) + return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression)) + + +def _lambda_quoted(value: str) -> t.Optional[bool]: + return False if value == "_" else None diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py new file mode 100644 index 0000000..947aace --- /dev/null +++ b/sqlglot/dataframe/sql/group.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import typing as t + +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.column import Column +from sqlglot.dataframe.sql.operations import Operation, operation + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.dataframe import DataFrame + + +class GroupedData: + def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation): + self._df = df.copy() + self.spark = df.spark + self.last_op = last_op + self.group_by_cols = group_by_cols + + def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]: + func_name = func_name.lower() + return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] + + @operation(Operation.SELECT) + def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: + columns = ( + [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] + if isinstance(exprs[0], dict) + else exprs + ) + cols = self._df._ensure_and_normalize_cols(columns) + + expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select( + *[x.expression for x in self.group_by_cols + cols], append=False + ) + return self._df.copy(expression=expression) + + def count(self) -> DataFrame: + return self.agg(F.count("*").alias("count")) + + def mean(self, *cols: str) -> DataFrame: + return self.avg(*cols) + + def avg(self, *cols: str) -> DataFrame: + return self.agg(*self._get_function_applied_columns("avg", cols)) + + def max(self, *cols: str) -> DataFrame: + return self.agg(*self._get_function_applied_columns("max", cols)) + + def min(self, *cols: str) -> DataFrame: + return self.agg(*self._get_function_applied_columns("min", cols)) + + def sum(self, *cols: str) -> DataFrame: + return self.agg(*self._get_function_applied_columns("sum", cols)) + + def pivot(self, *cols: str) -> DataFrame: + raise NotImplementedError("Sum distinct is not currently implemented") diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py new file mode 100644 index 0000000..1513946 --- /dev/null +++ b/sqlglot/dataframe/sql/normalize.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import expressions as exp +from sqlglot.dataframe.sql.column import Column +from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join +from sqlglot.helper import ensure_list + +NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.session import SparkSession + + +def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]): + expr = ensure_list(expr) + expressions = _ensure_expressions(expr) + for expression in expressions: + identifiers = expression.find_all(exp.Identifier) + for identifier in identifiers: + replace_alias_name_with_cte_name(spark, expression_context, identifier) + replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier) + + +def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier): + if id.alias_or_name in spark.name_to_sequence_id_mapping: + for cte in reversed(expression_context.ctes): + if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]: + _set_alias_name(id, cte.alias_or_name) + break + + +def replace_branch_and_sequence_ids_with_cte_name( + spark: SparkSession, expression_context: exp.Select, id: exp.Identifier +): + if id.alias_or_name in spark.known_ids: + # Check if we have a join and if both the tables in that join share a common branch id + # If so we need to have this reference the left table by default unless the id is a sequence + # id then it keeps that reference. This handles the weird edge case in spark that shouldn't + # be common in practice + if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids: + join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)] + ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases] + if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]: + assert len(ctes_in_join) == 2 + _set_alias_name(id, ctes_in_join[0].alias_or_name) + return + + for cte in reversed(expression_context.ctes): + if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]): + _set_alias_name(id, cte.alias_or_name) + return + + +def _set_alias_name(id: exp.Identifier, name: str): + id.set("this", name) + + +def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]: + values = ensure_list(values) + results = [] + for value in values: + if isinstance(value, str): + results.append(Column.ensure_col(value).expression) + elif isinstance(value, Column): + results.append(value.expression) + elif isinstance(value, exp.Expression): + results.append(value) + else: + raise ValueError(f"Got an invalid type to normalize: {type(value)}") + return results diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py new file mode 100644 index 0000000..d51335c --- /dev/null +++ b/sqlglot/dataframe/sql/operations.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import functools +import typing as t +from enum import IntEnum + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.dataframe import DataFrame + from sqlglot.dataframe.sql.group import GroupedData + + +class Operation(IntEnum): + INIT = -1 + NO_OP = 0 + FROM = 1 + WHERE = 2 + GROUP_BY = 3 + HAVING = 4 + SELECT = 5 + ORDER_BY = 6 + LIMIT = 7 + + +def operation(op: Operation): + """ + Decorator used around DataFrame methods to indicate what type of operation is being performed from the + ordered Operation enums. This is used to determine which operations should be performed on a CTE vs. + included with the previous operation. + + Ex: After a user does a join we want to allow them to select which columns for the different + tables that they want to carry through to the following operation. If we put that join in + a CTE preemptively then the user would not have a chance to select which column they want + in cases where there is overlap in names. + """ + + def decorator(func: t.Callable): + @functools.wraps(func) + def wrapper(self: DataFrame, *args, **kwargs): + if self.last_op == Operation.INIT: + self = self._convert_leaf_to_cte() + self.last_op = Operation.NO_OP + last_op = self.last_op + new_op = op if op != Operation.NO_OP else last_op + if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT): + self = self._convert_leaf_to_cte() + df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs) + df.last_op = new_op # type: ignore + return df + + wrapper.__wrapped__ = func # type: ignore + return wrapper + + return decorator diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py new file mode 100644 index 0000000..4830035 --- /dev/null +++ b/sqlglot/dataframe/sql/readwriter.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import typing as t + +import sqlglot +from sqlglot import expressions as exp +from sqlglot.helper import object_to_dict + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.dataframe import DataFrame + from sqlglot.dataframe.sql.session import SparkSession + + +class DataFrameReader: + def __init__(self, spark: SparkSession): + self.spark = spark + + def table(self, tableName: str) -> DataFrame: + from sqlglot.dataframe.sql.dataframe import DataFrame + + sqlglot.schema.add_table(tableName) + return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName))) + + +class DataFrameWriter: + def __init__( + self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False + ): + self._df = df + self._spark = spark or df.spark + self._mode = mode + self._by_name = by_name + + def copy(self, **kwargs) -> DataFrameWriter: + return DataFrameWriter( + **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()} + ) + + def sql(self, **kwargs) -> t.List[str]: + return self._df.sql(**kwargs) + + def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter: + return self.copy(_mode=saveMode) + + @property + def byName(self): + return self.copy(by_name=True) + + def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: + output_expression_container = exp.Insert( + **{ + "this": exp.to_table(tableName), + "overwrite": overwrite, + } + ) + df = self._df.copy(output_expression_container=output_expression_container) + if self._by_name: + columns = sqlglot.schema.column_names(tableName, only_visible=True) + df = df._convert_leaf_to_cte().select(*columns) + + return self.copy(_df=df) + + def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): + if format is not None: + raise NotImplementedError("Providing Format in the save as table is not supported") + exists, replace, mode = None, None, mode or str(self._mode) + if mode == "append": + return self.insertInto(name) + if mode == "ignore": + exists = True + if mode == "overwrite": + replace = True + output_expression_container = exp.Create( + this=exp.to_table(name), + kind="TABLE", + exists=exists, + replace=replace, + ) + return self.copy(_df=self._df.copy(output_expression_container=output_expression_container)) diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py new file mode 100644 index 0000000..1ea86d1 --- /dev/null +++ b/sqlglot/dataframe/sql/session.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import typing as t +import uuid +from collections import defaultdict + +import sqlglot +from sqlglot import expressions as exp +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.dataframe import DataFrame +from sqlglot.dataframe.sql.readwriter import DataFrameReader +from sqlglot.dataframe.sql.types import StructType +from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput + + +class SparkSession: + known_ids: t.ClassVar[t.Set[str]] = set() + known_branch_ids: t.ClassVar[t.Set[str]] = set() + known_sequence_ids: t.ClassVar[t.Set[str]] = set() + name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list) + + def __init__(self): + self.incrementing_id = 1 + + def __getattr__(self, name: str) -> SparkSession: + return self + + def __call__(self, *args, **kwargs) -> SparkSession: + return self + + @property + def read(self) -> DataFrameReader: + return DataFrameReader(self) + + def table(self, tableName: str) -> DataFrame: + return self.read.table(tableName) + + def createDataFrame( + self, + data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]], + schema: t.Optional[SchemaInput] = None, + samplingRatio: t.Optional[float] = None, + verifySchema: bool = False, + ) -> DataFrame: + from sqlglot.dataframe.sql.dataframe import DataFrame + + if samplingRatio is not None or verifySchema: + raise NotImplementedError("Sampling Ratio and Verify Schema are not supported") + if schema is not None and ( + not isinstance(schema, (StructType, str, list)) + or (isinstance(schema, list) and not isinstance(schema[0], str)) + ): + raise NotImplementedError("Only schema of either list or string of list supported") + if not data: + raise ValueError("Must provide data to create into a DataFrame") + + column_mapping: t.Dict[str, t.Optional[str]] + if schema is not None: + column_mapping = get_column_mapping_from_schema_input(schema) + elif isinstance(data[0], dict): + column_mapping = {col_name.strip(): None for col_name in data[0]} + else: + column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)} + + data_expressions = [ + exp.Tuple( + expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values())) + ) + for row in data + ] + + sel_columns = [ + F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression + for name, data_type in column_mapping.items() + ] + + select_kwargs = { + "expressions": sel_columns, + "from": exp.From( + expressions=[ + exp.Subquery( + this=exp.Values(expressions=data_expressions), + alias=exp.TableAlias( + this=exp.to_identifier(self._auto_incrementing_name), + columns=[exp.to_identifier(col_name) for col_name in column_mapping], + ), + ) + ] + ), + } + + sel_expression = exp.Select(**select_kwargs) + return DataFrame(self, sel_expression) + + def sql(self, sqlQuery: str) -> DataFrame: + expression = sqlglot.parse_one(sqlQuery, read="spark") + if isinstance(expression, exp.Select): + df = DataFrame(self, expression) + df = df._convert_leaf_to_cte() + elif isinstance(expression, (exp.Create, exp.Insert)): + select_expression = expression.expression.copy() + if isinstance(expression, exp.Insert): + select_expression.set("with", expression.args.get("with")) + expression.set("with", None) + del expression.args["expression"] + df = DataFrame(self, select_expression, output_expression_container=expression) + df = df._convert_leaf_to_cte() + else: + raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.") + return df + + @property + def _auto_incrementing_name(self) -> str: + name = f"a{self.incrementing_id}" + self.incrementing_id += 1 + return name + + @property + def _random_name(self) -> str: + return f"a{str(uuid.uuid4())[:8]}" + + @property + def _random_branch_id(self) -> str: + id = self._random_id + self.known_branch_ids.add(id) + return id + + @property + def _random_sequence_id(self): + id = self._random_id + self.known_sequence_ids.add(id) + return id + + @property + def _random_id(self) -> str: + id = f"a{str(uuid.uuid4())[:8]}" + self.known_ids.add(id) + return id + + @property + def _join_hint_names(self) -> t.Set[str]: + return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"} + + def _add_alias_to_mapping(self, name: str, sequence_id: str): + self.name_to_sequence_id_mapping[name].append(sequence_id) diff --git a/sqlglot/dataframe/sql/transforms.py b/sqlglot/dataframe/sql/transforms.py new file mode 100644 index 0000000..b3dcc12 --- /dev/null +++ b/sqlglot/dataframe/sql/transforms.py @@ -0,0 +1,9 @@ +import typing as t + +from sqlglot import expressions as exp + + +def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]): + if isinstance(node, exp.Identifier) and node in replacement_mapping: + node = node.replace(replacement_mapping[node].copy()) + return node diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py new file mode 100644 index 0000000..dc5c05a --- /dev/null +++ b/sqlglot/dataframe/sql/types.py @@ -0,0 +1,208 @@ +import typing as t + + +class DataType: + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: t.Any) -> bool: + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other: t.Any) -> bool: + return not self.__eq__(other) + + def __str__(self) -> str: + return self.typeName() + + @classmethod + def typeName(cls) -> str: + return cls.__name__[:-4].lower() + + def simpleString(self) -> str: + return str(self) + + def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]: + return str(self) + + +class DataTypeWithLength(DataType): + def __init__(self, length: int): + self.length = length + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.length})" + + def __str__(self) -> str: + return f"{self.typeName()}({self.length})" + + +class StringType(DataType): + pass + + +class CharType(DataTypeWithLength): + pass + + +class VarcharType(DataTypeWithLength): + pass + + +class BinaryType(DataType): + pass + + +class BooleanType(DataType): + pass + + +class DateType(DataType): + pass + + +class TimestampType(DataType): + pass + + +class TimestampNTZType(DataType): + @classmethod + def typeName(cls) -> str: + return "timestamp_ntz" + + +class DecimalType(DataType): + def __init__(self, precision: int = 10, scale: int = 0): + self.precision = precision + self.scale = scale + + def simpleString(self) -> str: + return f"decimal({self.precision}, {self.scale})" + + def jsonValue(self) -> str: + return f"decimal({self.precision}, {self.scale})" + + def __repr__(self) -> str: + return f"DecimalType({self.precision}, {self.scale})" + + +class DoubleType(DataType): + pass + + +class FloatType(DataType): + pass + + +class ByteType(DataType): + def __str__(self) -> str: + return "tinyint" + + +class IntegerType(DataType): + def __str__(self) -> str: + return "int" + + +class LongType(DataType): + def __str__(self) -> str: + return "bigint" + + +class ShortType(DataType): + def __str__(self) -> str: + return "smallint" + + +class ArrayType(DataType): + def __init__(self, elementType: DataType, containsNull: bool = True): + self.elementType = elementType + self.containsNull = containsNull + + def __repr__(self) -> str: + return f"ArrayType({self.elementType, str(self.containsNull)}" + + def simpleString(self) -> str: + return f"array<{self.elementType.simpleString()}>" + + def jsonValue(self) -> t.Dict[str, t.Any]: + return { + "type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull, + } + + +class MapType(DataType): + def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True): + self.keyType = keyType + self.valueType = valueType + self.valueContainsNull = valueContainsNull + + def __repr__(self) -> str: + return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})" + + def simpleString(self) -> str: + return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>" + + def jsonValue(self) -> t.Dict[str, t.Any]: + return { + "type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull, + } + + +class StructField(DataType): + def __init__( + self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None + ): + self.name = name + self.dataType = dataType + self.nullable = nullable + self.metadata = metadata or {} + + def __repr__(self) -> str: + return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})" + + def simpleString(self) -> str: + return f"{self.name}:{self.dataType.simpleString()}" + + def jsonValue(self) -> t.Dict[str, t.Any]: + return { + "name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable, + "metadata": self.metadata, + } + + +class StructType(DataType): + def __init__(self, fields: t.Optional[t.List[StructField]] = None): + if not fields: + self.fields = [] + self.names = [] + else: + self.fields = fields + self.names = [f.name for f in fields] + + def __iter__(self) -> t.Iterator[StructField]: + return iter(self.fields) + + def __len__(self) -> int: + return len(self.fields) + + def __repr__(self) -> str: + return f"StructType({', '.join(str(field) for field in self)})" + + def simpleString(self) -> str: + return f"struct<{', '.join(x.simpleString() for x in self)}>" + + def jsonValue(self) -> t.Dict[str, t.Any]: + return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]} + + def fieldNames(self) -> t.List[str]: + return list(self.names) diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py new file mode 100644 index 0000000..575d18a --- /dev/null +++ b/sqlglot/dataframe/sql/util.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import expressions as exp +from sqlglot.dataframe.sql import types + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql._typing import SchemaInput + + +def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]: + if isinstance(schema, dict): + return schema + elif isinstance(schema, str): + col_name_type_strs = [x.strip() for x in schema.split(",")] + return { + name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() + for name_type_str in col_name_type_strs + } + elif isinstance(schema, types.StructType): + return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema} + return {x.strip(): None for x in schema} # type: ignore + + +def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]: + if not expression.args.get("joins"): + return [] + + left_table = expression.args["from"].args["expressions"][0] + other_tables = [join.this for join in expression.args["joins"]] + return [left_table] + other_tables diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py new file mode 100644 index 0000000..842f366 --- /dev/null +++ b/sqlglot/dataframe/sql/window.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import sys +import typing as t + +from sqlglot import expressions as exp +from sqlglot.dataframe.sql import functions as F +from sqlglot.helper import flatten + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql._typing import ColumnOrName + + +class Window: + _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 + _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 + _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG) + _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG) + + unboundedPreceding: int = _JAVA_MIN_LONG + + unboundedFollowing: int = _JAVA_MAX_LONG + + currentRow: int = 0 + + @classmethod + def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: + return WindowSpec().partitionBy(*cols) + + @classmethod + def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: + return WindowSpec().orderBy(*cols) + + @classmethod + def rowsBetween(cls, start: int, end: int) -> WindowSpec: + return WindowSpec().rowsBetween(start, end) + + @classmethod + def rangeBetween(cls, start: int, end: int) -> WindowSpec: + return WindowSpec().rangeBetween(start, end) + + +class WindowSpec: + def __init__(self, expression: exp.Expression = exp.Window()): + self.expression = expression + + def copy(self): + return WindowSpec(self.expression.copy()) + + def sql(self, **kwargs) -> str: + return self.expression.sql(dialect="spark", **kwargs) + + def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: + from sqlglot.dataframe.sql.column import Column + + cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore + expressions = [Column.ensure_col(x).expression for x in cols] + window_spec = self.copy() + partition_by_expressions = window_spec.expression.args.get("partition_by", []) + partition_by_expressions.extend(expressions) + window_spec.expression.set("partition_by", partition_by_expressions) + return window_spec + + def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: + from sqlglot.dataframe.sql.column import Column + + cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore + expressions = [Column.ensure_col(x).expression for x in cols] + window_spec = self.copy() + if window_spec.expression.args.get("order") is None: + window_spec.expression.set("order", exp.Order(expressions=[])) + order_by = window_spec.expression.args["order"].expressions + order_by.extend(expressions) + window_spec.expression.args["order"].set("expressions", order_by) + return window_spec + + def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]: + kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None} + if start == Window.currentRow: + kwargs["start"] = "CURRENT ROW" + else: + kwargs = { + **kwargs, + **{ + "start_side": "PRECEDING", + "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression, + }, + } + if end == Window.currentRow: + kwargs["end"] = "CURRENT ROW" + else: + kwargs = { + **kwargs, + **{ + "end_side": "FOLLOWING", + "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression, + }, + } + return kwargs + + def rowsBetween(self, start: int, end: int) -> WindowSpec: + window_spec = self.copy() + spec = self._calc_start_end(start, end) + spec["kind"] = "ROWS" + window_spec.expression.set( + "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}) + ) + return window_spec + + def rangeBetween(self, start: int, end: int) -> WindowSpec: + window_spec = self.copy() + spec = self._calc_start_end(start, end) + spec["kind"] = "RANGE" + window_spec.expression.set( + "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}) + ) + return window_spec diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 86e46cf..62d042e 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -78,6 +78,16 @@ def _create_sql(self, expression): class BigQuery(Dialect): unnest_column_only = True + time_mapping = { + "%M": "%-M", + "%d": "%-d", + "%m": "%-m", + "%y": "%-y", + "%H": "%-H", + "%I": "%-I", + "%S": "%-S", + "%j": "%-j", + } class Tokenizer(Tokenizer): QUOTES = [ @@ -113,6 +123,7 @@ class BigQuery(Dialect): "DATETIME_SUB": _date_add(exp.DatetimeSub), "TIME_SUB": _date_add(exp.TimeSub), "TIMESTAMP_SUB": _date_add(exp.TimestampSub), + "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)), } NO_PAREN_FUNCTIONS = { @@ -137,6 +148,7 @@ class BigQuery(Dialect): exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.ILike: no_ilike_sql, + exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 531c72a..46661cf 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -2,7 +2,7 @@ from enum import Enum from sqlglot import exp from sqlglot.generator import Generator -from sqlglot.helper import list_get +from sqlglot.helper import flatten, list_get from sqlglot.parser import Parser from sqlglot.time import format_time from sqlglot.tokens import Tokenizer @@ -67,6 +67,11 @@ class _Dialect(type): klass.generator_class.TRANSFORMS[ exp.HexString ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}" + if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS: + be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0] + klass.generator_class.TRANSFORMS[ + exp.ByteString + ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}" return klass @@ -176,11 +181,7 @@ class Dialect(metaclass=_Dialect): def rename_func(name): def _rename(self, expression): - args = ( - expression.expressions - if isinstance(expression, exp.Func) and expression.is_var_len_args - else expression.args.values() - ) + args = flatten(expression.args.values()) return f"{name}({self.format_args(*args)})" return _rename diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 8888df8..0810e0c 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -121,6 +121,9 @@ class Hive(Dialect): "ss": "%S", "s": "%-S", "S": "%f", + "a": "%p", + "DD": "%j", + "D": "%-j", } date_format = "'yyyy-MM-dd'" @@ -200,6 +203,7 @@ class Hive(Dialect): exp.AnonymousProperty: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayAgg: rename_func("COLLECT_LIST"), + exp.ArrayConcat: rename_func("CONCAT"), exp.ArraySize: rename_func("SIZE"), exp.ArraySort: _array_sort, exp.With: no_recursive_cte_sql, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 8449379..524390f 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -97,6 +97,8 @@ class MySQL(Dialect): "%s": "%S", "%S": "%S", "%u": "%W", + "%k": "%-H", + "%l": "%-I", } class Tokenizer(Tokenizer): @@ -145,6 +147,9 @@ class MySQL(Dialect): "_TIS620": TokenType.INTRODUCER, "_UCS2": TokenType.INTRODUCER, "_UJIS": TokenType.INTRODUCER, + # https://dev.mysql.com/doc/refman/8.0/en/string-literals.html + "N": TokenType.INTRODUCER, + "n": TokenType.INTRODUCER, "_UTF8": TokenType.INTRODUCER, "_UTF16": TokenType.INTRODUCER, "_UTF16LE": TokenType.INTRODUCER, diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 8041ff0..144dba5 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -80,17 +80,12 @@ class Oracle(Dialect): sep="", ) - def alias_sql(self, expression): - if isinstance(expression.this, exp.Table): - to_sql = self.sql(expression, "alias") - # oracle does not allow "AS" between table and alias - to_sql = f" {to_sql}" if to_sql else "" - return f"{self.sql(expression, 'this')}{to_sql}" - return super().alias_sql(expression) - def offset_sql(self, expression): return f"{super().offset_sql(expression)} ROWS" + def table_sql(self, expression): + return super().table_sql(expression, sep=" ") + class Tokenizer(Tokenizer): KEYWORDS = { **Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index c91ff4b..459e926 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -163,6 +163,7 @@ class Postgres(Dialect): class Tokenizer(Tokenizer): BIT_STRINGS = [("b'", "'"), ("B'", "'")] HEX_STRINGS = [("x'", "'"), ("X'", "'")] + BYTE_STRINGS = [("e'", "'"), ("E'", "'")] KEYWORDS = { **Tokenizer.KEYWORDS, "ALWAYS": TokenType.ALWAYS, @@ -176,6 +177,11 @@ class Postgres(Dialect): "SMALLSERIAL": TokenType.SMALLSERIAL, "UUID": TokenType.UUID, } + QUOTES = ["'", "$$"] + SINGLE_TOKENS = { + **Tokenizer.SINGLE_TOKENS, + "$": TokenType.PARAMETER, + } class Parser(Parser): STRICT_CAST = False diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 8dfb2fd..41c0db1 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -172,6 +172,7 @@ class Presto(Dialect): **transforms.UNALIAS_GROUP, exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", + exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayContains: rename_func("CONTAINS"), exp.ArraySize: rename_func("CARDINALITY"), exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})", diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 19a427c..627258f 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -69,6 +69,35 @@ def _unix_to_time(self, expression): raise ValueError("Improper scale for timestamp") +# https://docs.snowflake.com/en/sql-reference/functions/date_part.html +# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts +def _parse_date_part(self): + this = self._parse_var() or self._parse_type() + self._match(TokenType.COMMA) + expression = self._parse_bitwise() + + name = this.name.upper() + if name.startswith("EPOCH"): + if name.startswith("EPOCH_MILLISECOND"): + scale = 10**3 + elif name.startswith("EPOCH_MICROSECOND"): + scale = 10**6 + elif name.startswith("EPOCH_NANOSECOND"): + scale = 10**9 + else: + scale = None + + ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) + to_unix = self.expression(exp.TimeToUnix, this=ts) + + if scale: + to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) + + return to_unix + + return self.expression(exp.Extract, this=this, expression=expression) + + class Snowflake(Dialect): null_ordering = "nulls_are_large" time_format = "'yyyy-mm-dd hh24:mi:ss'" @@ -115,7 +144,7 @@ class Snowflake(Dialect): FUNCTION_PARSERS = { **Parser.FUNCTION_PARSERS, - "DATE_PART": lambda self: self._parse_extract(), + "DATE_PART": _parse_date_part, } FUNC_TOKENS = { @@ -161,9 +190,11 @@ class Snowflake(Dialect): class Generator(Generator): TRANSFORMS = { **Generator.TRANSFORMS, + exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.If: rename_func("IFF"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: _unix_to_time, + exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Array: inline_array_sql, exp.StrPosition: rename_func("POSITION"), exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 95a7ab4..6bf4ff0 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,9 +1,5 @@ from sqlglot import exp -from sqlglot.dialects.dialect import ( - create_with_partitions_sql, - no_ilike_sql, - rename_func, -) +from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func from sqlglot.dialects.hive import Hive from sqlglot.helper import list_get from sqlglot.parser import Parser @@ -98,13 +94,14 @@ class Spark(Hive): } TRANSFORMS = { - **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}}, + **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}}, + exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}", exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), + exp.DateTrunc: rename_func("TRUNC"), exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", - exp.ILike: no_ilike_sql, exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: _unix_to_time, @@ -112,6 +109,8 @@ class Spark(Hive): exp.Map: _map_sql, exp.Reduce: rename_func("AGGREGATE"), exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", + exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})", + exp.VariancePop: rename_func("VAR_POP"), } WRAP_DERIVED_VALUES = False diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 73b232e..1f2e50d 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -32,6 +32,11 @@ class TSQL(Dialect): } class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "CHARINDEX": exp.StrPosition.from_arg_list, + } + def _parse_convert(self): to = self._parse_types() self._match(TokenType.COMMA) diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 72b0558..9c49dd1 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -19,6 +19,7 @@ ENV = { "datetime": datetime, "locals": locals, "re": re, + "bool": bool, "float": float, "int": int, "str": str, diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 8ef6cf0..fcb016b 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -80,9 +80,10 @@ class PythonExecutor: source = step.source if isinstance(source, exp.Expression): - source = source.this.name or source.alias + source = source.name or source.alias else: source = step.name + condition = self.generate(step.condition) projections = self.generate_tuple(step.projections) @@ -121,7 +122,7 @@ class PythonExecutor: source = step.source alias = source.alias - with csv_reader(source.this) as reader: + with csv_reader(source) as reader: columns = next(reader) table = Table(columns) context = self.context({alias: table}) @@ -308,7 +309,7 @@ def _interval_py(self, expression): def _like_py(self, expression): this = self.sql(expression, "this") expression = self.sql(expression, "expression") - return f"""re.match({expression}.replace("_", ".").replace("%", ".*"), {this})""" + return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))""" def _ordered_py(self, expression): @@ -330,6 +331,7 @@ class Python(Dialect): exp.Cast: _cast_py, exp.Column: _column_py, exp.EQ: lambda self, e: self.binary(e, "=="), + exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}", exp.Interval: _interval_py, exp.Is: lambda self, e: self.binary(e, "is"), exp.Like: _like_py, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 39f4452..f7717c8 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -11,6 +11,7 @@ from sqlglot.helper import ( camel_to_snake_case, ensure_list, list_get, + split_num_words, subclasses, ) @@ -108,6 +109,8 @@ class Expression(metaclass=_Expression): @property def alias_or_name(self): + if isinstance(self, Null): + return "NULL" return self.alias or self.name def __deepcopy__(self, memo): @@ -659,6 +662,10 @@ class HexString(Condition): pass +class ByteString(Condition): + pass + + class Column(Condition): arg_types = {"this": True, "table": False} @@ -725,7 +732,7 @@ class Constraint(Expression): class Delete(Expression): - arg_types = {"with": False, "this": True, "where": False} + arg_types = {"with": False, "this": True, "using": False, "where": False} class Drop(Expression): @@ -1192,6 +1199,7 @@ QUERY_MODIFIERS = { class Table(Expression): arg_types = { "this": True, + "alias": False, "db": False, "catalog": False, "laterals": False, @@ -1323,6 +1331,7 @@ class Select(Subqueryable): *expressions (str or Expression): the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Group`. + If nothing is passed in then a group by is not applied to the expression append (bool): if `True`, add to any existing expressions. Otherwise, this flattens all the `Group` expression into a single expression. dialect (str): the dialect used to parse the input expression. @@ -1332,6 +1341,8 @@ class Select(Subqueryable): Returns: Select: the modified expression. """ + if not expressions: + return self if not copy else self.copy() return _apply_child_list_builder( *expressions, instance=self, @@ -2239,6 +2250,11 @@ class ArrayAny(Func): arg_types = {"this": True, "expression": True} +class ArrayConcat(Func): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + class ArrayContains(Func): arg_types = {"this": True, "expression": True} @@ -2570,7 +2586,7 @@ class SortArray(Func): class Split(Func): - arg_types = {"this": True, "expression": True} + arg_types = {"this": True, "expression": True, "limit": False} # Start may be omitted in the case of postgres @@ -3209,29 +3225,49 @@ def to_identifier(alias, quoted=None): return identifier -def to_table(sql_path, **kwargs): +def to_table(sql_path: str, **kwargs) -> Table: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. - Example: - >>> to_table('catalog.db.table_name').sql() - 'catalog.db.table_name' + + If a table is passed in then that table is returned. Args: - sql_path(str): `[catalog].[schema].[table]` string + sql_path(str|Table): `[catalog].[schema].[table]` string Returns: Table: A table expression """ - table_parts = sql_path.split(".") - catalog, db, table_name = [ - to_identifier(x) if x is not None else x for x in [None] * (3 - len(table_parts)) + table_parts - ] + if sql_path is None or isinstance(sql_path, Table): + return sql_path + if not isinstance(sql_path, str): + raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") + + catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)] return Table(this=table_name, db=db, catalog=catalog, **kwargs) +def to_column(sql_path: str, **kwargs) -> Column: + """ + Create a column from a `[table].[column]` sql path. Schema is optional. + + If a column is passed in then that column is returned. + + Args: + sql_path: `[table].[column]` string + Returns: + Table: A column expression + """ + if sql_path is None or isinstance(sql_path, Column): + return sql_path + if not isinstance(sql_path, str): + raise ValueError(f"Invalid type provided for column: {type(sql_path)}") + table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)] + return Column(this=column_name, table=table_name, **kwargs) + + def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): """ Create an Alias expression. - Expample: + Example: >>> alias_('foo', 'bar').sql() 'foo AS bar' @@ -3249,7 +3285,16 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): """ exp = maybe_parse(expression, dialect=dialect, **opts) alias = to_identifier(alias, quoted=quoted) - alias = TableAlias(this=alias) if table else alias + + if table: + expression.set("alias", TableAlias(this=alias)) + return expression + + # We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in + # the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node + # for the complete Window expression. + # + # [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls if "alias" in exp.arg_types and not isinstance(exp, Window): exp = exp.copy() @@ -3295,7 +3340,7 @@ def column(col, table=None, quoted=None): ) -def table_(table, db=None, catalog=None, quoted=None): +def table_(table, db=None, catalog=None, quoted=None, alias=None): """Build a Table. Args: @@ -3310,6 +3355,7 @@ def table_(table, db=None, catalog=None, quoted=None): this=to_identifier(table, quoted=quoted), db=to_identifier(db, quoted=quoted), catalog=to_identifier(catalog, quoted=quoted), + alias=TableAlias(this=to_identifier(alias)) if alias else None, ) @@ -3453,7 +3499,7 @@ def replace_tables(expression, mapping): Examples: >>> from sqlglot import exp, parse_one >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql() - 'SELECT * FROM "c"' + 'SELECT * FROM c' Returns: The mapped expression @@ -3463,7 +3509,10 @@ def replace_tables(expression, mapping): if isinstance(node, Table): new_name = mapping.get(table_name(node)) if new_name: - return table_(*reversed(new_name.split(".")), quoted=True) + return to_table( + new_name, + **{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")}, + ) return node return expression.transform(_replace_tables) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index bb7fd71..6decd16 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -47,6 +47,8 @@ class Generator: The default is on the smaller end because the length only represents a segment and not the true line length. Default: 80 + annotations: Whether or not to show annotations in the SQL. + Default: True """ TRANSFORMS = { @@ -116,6 +118,7 @@ class Generator: "_escaped_quote_end", "_leading_comma", "_max_text_width", + "_annotations", ) def __init__( @@ -141,6 +144,7 @@ class Generator: max_unsupported=3, leading_comma=False, max_text_width=80, + annotations=True, ): import sqlglot @@ -169,6 +173,7 @@ class Generator: self._escaped_quote_end = self.escape + self.quote_end self._leading_comma = leading_comma self._max_text_width = max_text_width + self._annotations = annotations def generate(self, expression): """ @@ -275,7 +280,9 @@ class Generator: raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") def annotation_sql(self, expression): - return f"{self.sql(expression, 'expression')} # {expression.name.strip()}" + if self._annotations: + return f"{self.sql(expression, 'expression')} # {expression.name}" + return self.sql(expression, "expression") def uncache_sql(self, expression): table = self.sql(expression, "this") @@ -423,8 +430,11 @@ class Generator: def delete_sql(self, expression): this = self.sql(expression, "this") + using_sql = ( + f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else "" + ) where_sql = self.sql(expression, "where") - sql = f"DELETE FROM {this}{where_sql}" + sql = f"DELETE FROM {this}{using_sql}{where_sql}" return self.prepend_ctes(expression, sql) def drop_sql(self, expression): @@ -571,7 +581,7 @@ class Generator: null = f" NULL DEFINED AS {null}" if null else "" return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}" - def table_sql(self, expression): + def table_sql(self, expression, sep=" AS "): table = ".".join( part for part in [ @@ -582,13 +592,20 @@ class Generator: if part ) + alias = self.sql(expression, "alias") + alias = f"{sep}{alias}" if alias else "" laterals = self.expressions(expression, key="laterals", sep="") joins = self.expressions(expression, key="joins", sep="") pivots = self.expressions(expression, key="pivots", sep="") - return f"{table}{laterals}{joins}{pivots}" + + if alias and pivots: + pivots = f"{pivots}{alias}" + alias = "" + + return f"{table}{alias}{laterals}{joins}{pivots}" def tablesample_sql(self, expression): - if self.alias_post_tablesample and isinstance(expression.this, exp.Alias): + if self.alias_post_tablesample and expression.this.alias: this = self.sql(expression.this, "this") alias = f" AS {self.sql(expression.this, 'alias')}" else: @@ -1188,7 +1205,7 @@ class Generator: if isinstance(arg_value, list): for value in arg_value: args.append(value) - elif arg_value: + else: args.append(arg_value) return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})" diff --git a/sqlglot/helper.py b/sqlglot/helper.py index c4dd91e..c3a23d3 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -2,7 +2,9 @@ import inspect import logging import re import sys +import typing as t from contextlib import contextmanager +from copy import copy from enum import Enum CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])") @@ -162,3 +164,54 @@ def find_new_name(taken, base): i += 1 new = f"{base}_{i}" return new + + +def object_to_dict(obj, **kwargs): + return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs} + + +def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]: + """ + Perform a split on a value and return N words as a result with None used for words that don't exist. + + Args: + value: The value to be split + sep: The value to use to split on + min_num_words: The minimum number of words that are going to be in the result + fill_from_start: Indicates that if None values should be inserted at the start or end of the list + + Examples: + >>> split_num_words("db.table", ".", 3) + [None, 'db', 'table'] + >>> split_num_words("db.table", ".", 3, fill_from_start=False) + ['db', 'table', None] + >>> split_num_words("db.table", ".", 1) + ['db', 'table'] + """ + words = value.split(sep) + if fill_from_start: + return [None] * (min_num_words - len(words)) + words + return words + [None] * (min_num_words - len(words)) + + +def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]: + """ + Flattens a list that can contain both iterables and non-iterable elements + + Examples: + >>> list(flatten([[1, 2], 3])) + [1, 2, 3] + >>> list(flatten([1, 2, 3])) + [1, 2, 3] + + Args: + values: The value to be flattened + + Returns: + Yields non-iterable elements (not including str or byte as iterable) + """ + for value in values: + if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): + yield from flatten(value) + else: + yield value diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py index d1146ca..bba0878 100644 --- a/sqlglot/optimizer/__init__.py +++ b/sqlglot/optimizer/__init__.py @@ -1,2 +1 @@ from sqlglot.optimizer.optimizer import RULES, optimize -from sqlglot.optimizer.schema import Schema diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index a2cef37..30055bc 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,7 +1,7 @@ from sqlglot import exp from sqlglot.helper import ensure_list, subclasses -from sqlglot.optimizer.schema import ensure_schema from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import ensure_schema def annotate_types(expression, schema=None, annotators=None, coerces_to=None): diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 44cdc94..e30c263 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -86,7 +86,7 @@ def _eliminate(scope, existing_ctes, taken): if scope.is_union: return _eliminate_union(scope, existing_ctes, taken) - if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)): + if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF): return _eliminate_derived_table(scope, existing_ctes, taken) diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index e060739..652cdef 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -12,18 +12,16 @@ def isolate_table_selects(expression): if not isinstance(source, exp.Table): continue - if not isinstance(source.parent, exp.Alias): + if not source.alias: raise OptimizeError("Tables require an alias. Run qualify_tables optimization.") - parent = source.parent - - parent.replace( + source.replace( exp.select("*") .from_( - alias(source, source.name or parent.alias, table=True), + alias(source.copy(), source.name or source.alias, table=True), copy=False, ) - .subquery(parent.alias, copy=False) + .subquery(source.alias, copy=False) ) return expression diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 3c51c18..70e4629 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -70,15 +70,10 @@ def merge_ctes(expression, leave_tables_isolated=False): inner_select = inner_scope.expression.unnest() from_or_join = table.find_ancestor(exp.From, exp.Join) if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): - node_to_replace = table - if isinstance(node_to_replace.parent, exp.Alias): - node_to_replace = node_to_replace.parent - alias = node_to_replace.alias - else: - alias = table.name + alias = table.alias_or_name _rename_inner_sources(outer_scope, inner_scope, alias) - _merge_from(outer_scope, inner_scope, node_to_replace, alias) + _merge_from(outer_scope, inner_scope, table, alias) _merge_expressions(outer_scope, inner_scope, alias) _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) @@ -179,8 +174,8 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): if isinstance(source, exp.Subquery): source.set("alias", exp.TableAlias(this=new_alias)) - elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias): - source.parent.set("alias", new_alias) + elif isinstance(source, exp.Table) and source.alias: + source.set("alias", new_alias) elif isinstance(source, exp.Table): source.replace(exp.alias_(source.copy(), new_alias)) @@ -206,8 +201,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): tables = join_hint.find_all(exp.Table) for table in tables: if table.alias_or_name == node_to_replace.alias_or_name: - new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery - table.set("this", exp.to_identifier(new_table.alias_or_name)) + table.set("this", exp.to_identifier(new_subquery.alias_or_name)) outer_scope.remove_source(alias) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 2c28ab8..5ad8f46 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,3 +1,4 @@ +import sqlglot from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries @@ -43,6 +44,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar 1. {table: {col: type}} 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} + If no schema is provided then the default schema defined at `sqlgot.schema` will be used db (str): specify the default database, as might be set by a `USE DATABASE db` statement catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement rules (list): sequence of optimizer rules to use @@ -50,13 +52,12 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar Returns: sqlglot.Expression: optimized expression """ - possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} + possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs} expression = expression.copy() for rule in rules: # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs} - expression = rule(expression, **rule_kwargs) return expression diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 5584830..5820851 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -6,6 +6,9 @@ from sqlglot.optimizer.scope import Scope, traverse_scope # Sentinel value that means an outer query selecting ALL columns SELECT_ALL = object() +# SELECTION TO USE IF SELECTION LIST IS EMPTY +DEFAULT_SELECTION = alias("1", "_") + def pushdown_projections(expression): """ @@ -25,7 +28,8 @@ def pushdown_projections(expression): """ # Map of Scope to all columns being selected by outer queries. referenced_columns = defaultdict(set) - + left_union = None + right_union = None # We build the scope tree (which is traversed in DFS postorder), then iterate # over the result in reverse order. This should ensure that the set of selected # columns for a particular scope are completely build by the time we get to it. @@ -37,12 +41,16 @@ def pushdown_projections(expression): parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): - left, right = scope.union_scopes - referenced_columns[left] = parent_selections - referenced_columns[right] = parent_selections + left_union, right_union = scope.union_scopes + referenced_columns[left_union] = parent_selections + referenced_columns[right_union] = parent_selections - if isinstance(scope.expression, exp.Select): - _remove_unused_selections(scope, parent_selections) + if isinstance(scope.expression, exp.Select) and scope != right_union: + removed_indexes = _remove_unused_selections(scope, parent_selections) + # The left union is used for column names to select and if we remove columns from the left + # we need to also remove those same columns in the right that were at the same position + if scope is left_union: + _remove_indexed_selections(right_union, removed_indexes) # Group columns by source name selects = defaultdict(set) @@ -61,6 +69,7 @@ def pushdown_projections(expression): def _remove_unused_selections(scope, parent_selections): + removed_indexes = [] order = scope.expression.args.get("order") if order: @@ -70,16 +79,26 @@ def _remove_unused_selections(scope, parent_selections): order_refs = set() new_selections = [] - for selection in scope.selects: + for i, selection in enumerate(scope.selects): if ( SELECT_ALL in parent_selections or selection.alias_or_name in parent_selections or selection.alias_or_name in order_refs ): new_selections.append(selection) + else: + removed_indexes.append(i) # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(alias("1", "_")) + new_selections.append(DEFAULT_SELECTION) + + scope.expression.set("expressions", new_selections) + return removed_indexes + +def _remove_indexed_selections(scope, indexes_to_remove): + new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove] + if not new_selections: + new_selections.append(DEFAULT_SELECTION) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 7d77ef1..36ba028 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -2,8 +2,8 @@ import itertools from sqlglot import alias, exp from sqlglot.errors import OptimizeError -from sqlglot.optimizer.schema import ensure_schema -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import ensure_schema def qualify_columns(expression, schema): @@ -48,7 +48,7 @@ def _pop_table_column_aliases(derived_tables): (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2) """ for derived_table in derived_tables: - if isinstance(derived_table, exp.UDTF): + if isinstance(derived_table.unnest(), exp.UDTF): continue table_alias = derived_table.args.get("alias") if table_alias: @@ -211,6 +211,22 @@ def _qualify_columns(scope, resolver): if column_table: column.set("table", exp.to_identifier(column_table)) + # Determine whether each reference in the order by clause is to a column or an alias. + for ordered in scope.find_all(exp.Ordered): + for column in ordered.find_all(exp.Column): + column_table = column.table + column_name = column.name + + if column_table or column.parent is ordered or column_name not in resolver.all_columns: + continue + + column_table = resolver.get_table(column_name) + + if column_table is None: + raise OptimizeError(f"Ambiguous column: {column_name}") + + column.set("table", exp.to_identifier(column_table)) + def _expand_stars(scope, resolver): """Expand stars to lists of column selections""" @@ -346,6 +362,11 @@ class _Resolver: except Exception as e: raise OptimizeError(str(e)) from e + if isinstance(source, Scope) and isinstance(source.expression, exp.Values): + values_alias = source.expression.parent + if hasattr(values_alias, "alias_column_names"): + return values_alias.alias_column_names + # Otherwise, if referencing another scope, return that scope's named selects return source.expression.named_selects diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 30e93ba..0e467d3 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -40,7 +40,7 @@ def qualify_tables(expression, db=None, catalog=None): if not source.args.get("catalog"): source.set("catalog", exp.to_identifier(catalog)) - if not isinstance(source.parent, exp.Alias): + if not source.alias: source.replace( alias( source.copy(), diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py deleted file mode 100644 index d7743c9..0000000 --- a/sqlglot/optimizer/schema.py +++ /dev/null @@ -1,180 +0,0 @@ -import abc - -from sqlglot import exp -from sqlglot.errors import OptimizeError -from sqlglot.helper import csv_reader - - -class Schema(abc.ABC): - """Abstract base class for database schemas""" - - @abc.abstractmethod - def column_names(self, table, only_visible=False): - """ - Get the column names for a table. - Args: - table (sqlglot.expressions.Table): Table expression instance - only_visible (bool): Whether to include invisible columns - Returns: - list[str]: list of column names - """ - - @abc.abstractmethod - def get_column_type(self, table, column): - """ - Get the exp.DataType type of a column in the schema. - - Args: - table (sqlglot.expressions.Table): The source table. - column (sqlglot.expressions.Column): The target column. - Returns: - sqlglot.expressions.DataType.Type: The resulting column type. - """ - - -class MappingSchema(Schema): - """ - Schema based on a nested mapping. - - Args: - schema (dict): Mapping in one of the following forms: - 1. {table: {col: type}} - 2. {db: {table: {col: type}}} - 3. {catalog: {db: {table: {col: type}}}} - visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns - are assumed to be visible. The nesting should mirror that of the schema: - 1. {table: set(*cols)}} - 2. {db: {table: set(*cols)}}} - 3. {catalog: {db: {table: set(*cols)}}}} - dialect (str): The dialect to be used for custom type mappings. - """ - - def __init__(self, schema, visible=None, dialect=None): - self.schema = schema - self.visible = visible - self.dialect = dialect - self._type_mapping_cache = {} - - depth = _dict_depth(schema) - - if not depth: # {} - self.supported_table_args = [] - elif depth == 2: # {table: {col: type}} - self.supported_table_args = ("this",) - elif depth == 3: # {db: {table: {col: type}}} - self.supported_table_args = ("db", "this") - elif depth == 4: # {catalog: {db: {table: {col: type}}}} - self.supported_table_args = ("catalog", "db", "this") - else: - raise OptimizeError(f"Invalid schema shape. Depth: {depth}") - - self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args) - - def column_names(self, table, only_visible=False): - if not isinstance(table.this, exp.Identifier): - return fs_get(table) - - args = tuple(table.text(p) for p in self.supported_table_args) - - for forbidden in self.forbidden_args: - if table.text(forbidden): - raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") - - columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) - if not only_visible or not self.visible: - return columns - - visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) - return [col for col in columns if col in visible] - - def get_column_type(self, table, column): - try: - schema_type = self.schema.get(table.name, {}).get(column.name).upper() - return self._convert_type(schema_type) - except: - raise OptimizeError(f"Failed to get type for column {column.sql()}") - - def _convert_type(self, schema_type): - """ - Convert a type represented as a string to the corresponding exp.DataType.Type object. - - Args: - schema_type (str): The type we want to convert. - Returns: - sqlglot.expressions.DataType.Type: The resulting expression type. - """ - if schema_type not in self._type_mapping_cache: - try: - self._type_mapping_cache[schema_type] = exp.maybe_parse( - schema_type, into=exp.DataType, dialect=self.dialect - ).this - except AttributeError: - raise OptimizeError(f"Failed to convert type {schema_type}") - - return self._type_mapping_cache[schema_type] - - -def ensure_schema(schema): - if isinstance(schema, Schema): - return schema - - return MappingSchema(schema) - - -def fs_get(table): - name = table.this.name - - if name.upper() == "READ_CSV": - with csv_reader(table) as reader: - return next(reader) - - raise ValueError(f"Cannot read schema for {table}") - - -def _nested_get(d, *path): - """ - Get a value for a nested dictionary. - - Args: - d (dict): dictionary - *path (tuple[str, str]): tuples of (name, key) - `key` is the key in the dictionary to get. - `name` is a string to use in the error if `key` isn't found. - """ - for name, key in path: - d = d.get(key) - if d is None: - name = "table" if name == "this" else name - raise ValueError(f"Unknown {name}") - return d - - -def _dict_depth(d): - """ - Get the nesting depth of a dictionary. - - For example: - >>> _dict_depth(None) - 0 - >>> _dict_depth({}) - 1 - >>> _dict_depth({"a": "b"}) - 1 - >>> _dict_depth({"a": {}}) - 2 - >>> _dict_depth({"a": {"b": {}}}) - 3 - - Args: - d (dict): dictionary - Returns: - int: depth - """ - try: - return 1 + _dict_depth(next(iter(d.values()))) - except AttributeError: - # d doesn't have attribute "values" - return 0 - except StopIteration: - # d.values() returns an empty sequence - return 1 diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 68298a0..b7eb6c2 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -257,12 +257,7 @@ class Scope: referenced_names = [] for table in self.tables: - referenced_names.append( - ( - table.parent.alias if isinstance(table.parent, exp.Alias) else table.name, - table, - ) - ) + referenced_names.append((table.alias_or_name, table)) for derived_table in self.derived_tables: referenced_names.append((derived_table.alias, derived_table.unnest())) @@ -538,8 +533,8 @@ def _add_table_sources(scope): for table in scope.tables: table_name = table.name - if isinstance(table.parent, exp.Alias): - source_name = table.parent.alias + if table.alias: + source_name = table.alias else: source_name = table_name diff --git a/sqlglot/parser.py b/sqlglot/parser.py index b378f12..47c1c1d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -329,6 +329,7 @@ class Parser: exp.DataType: lambda self: self._parse_types(), exp.From: lambda self: self._parse_from(), exp.Group: lambda self: self._parse_group(), + exp.Identifier: lambda self: self._parse_id_var(), exp.Lateral: lambda self: self._parse_lateral(), exp.Join: lambda self: self._parse_join(), exp.Order: lambda self: self._parse_order(), @@ -371,11 +372,8 @@ class Parser: TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()), TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), - TokenType.INTRODUCER: lambda self, token: self.expression( - exp.Introducer, - this=token.text, - expression=self._parse_var_or_string(), - ), + TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text), + TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), } RANGE_PARSERS = { @@ -500,7 +498,7 @@ class Parser: max_errors=3, null_ordering=None, ): - self.error_level = error_level or ErrorLevel.RAISE + self.error_level = error_level or ErrorLevel.IMMEDIATE self.error_message_context = error_message_context self.index_offset = index_offset self.unnest_column_only = unnest_column_only @@ -928,6 +926,7 @@ class Parser: return self.expression( exp.Delete, this=self._parse_table(schema=True), + using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)), where=self._parse_where(), ) @@ -1148,7 +1147,7 @@ class Parser: def _parse_annotation(self, expression): if self._match(TokenType.ANNOTATION): - return self.expression(exp.Annotation, this=self._prev.text, expression=expression) + return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression) return expression @@ -1277,7 +1276,7 @@ class Parser: alias = self._parse_table_alias() if alias: - this = self.expression(exp.Alias, this=this, alias=alias) + this.set("alias", alias) if not self.alias_post_tablesample: table_sample = self._parse_table_sample() @@ -1876,6 +1875,17 @@ class Parser: self._match_r_paren() return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + def _parse_introducer(self, token): + literal = self._parse_primary() + if literal: + return self.expression( + exp.Introducer, + this=token.text, + expression=literal, + ) + + return self.expression(exp.Identifier, this=token.text) + def _parse_udf_kwarg(self): this = self._parse_id_var() kind = self._parse_types() diff --git a/sqlglot/planner.py b/sqlglot/planner.py index efabc15..ea995d8 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -199,13 +199,14 @@ class Step: class Scan(Step): @classmethod def from_expression(cls, expression, ctes=None): - table = expression.this + table = expression alias_ = expression.alias if not alias_: raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer") if isinstance(expression, exp.Subquery): + table = expression.this step = Step.from_expression(table, ctes) step.name = alias_ return step diff --git a/sqlglot/schema.py b/sqlglot/schema.py new file mode 100644 index 0000000..c916330 --- /dev/null +++ b/sqlglot/schema.py @@ -0,0 +1,297 @@ +import abc + +from sqlglot import expressions as exp +from sqlglot.errors import OptimizeError +from sqlglot.helper import csv_reader + + +class Schema(abc.ABC): + """Abstract base class for database schemas""" + + @abc.abstractmethod + def add_table(self, table, column_mapping=None): + """ + Register or update a table. Some implementing classes may require column information to also be provided + + Args: + table (sqlglot.expressions.Table|str): Table expression instance or string representing the table + column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + """ + + @abc.abstractmethod + def column_names(self, table, only_visible=False): + """ + Get the column names for a table. + Args: + table (sqlglot.expressions.Table): Table expression instance + only_visible (bool): Whether to include invisible columns + Returns: + list[str]: list of column names + """ + + @abc.abstractmethod + def get_column_type(self, table, column): + """ + Get the exp.DataType type of a column in the schema. + + Args: + table (sqlglot.expressions.Table): The source table. + column (sqlglot.expressions.Column): The target column. + Returns: + sqlglot.expressions.DataType.Type: The resulting column type. + """ + + +class MappingSchema(Schema): + """ + Schema based on a nested mapping. + + Args: + schema (dict): Mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + 4. None - Tables will be added later + visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns + are assumed to be visible. The nesting should mirror that of the schema: + 1. {table: set(*cols)}} + 2. {db: {table: set(*cols)}}} + 3. {catalog: {db: {table: set(*cols)}}}} + dialect (str): The dialect to be used for custom type mappings. + """ + + def __init__(self, schema=None, visible=None, dialect=None): + self.schema = schema or {} + self.visible = visible + self.dialect = dialect + self._type_mapping_cache = {} + self.supported_table_args = [] + self.forbidden_table_args = set() + if self.schema: + self._initialize_supported_args() + + @classmethod + def from_mapping_schema(cls, mapping_schema): + return MappingSchema( + schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect + ) + + def copy(self, **kwargs): + return MappingSchema(**{"schema": self.schema.copy(), **kwargs}) + + def add_table(self, table, column_mapping=None): + """ + Register or update a table. Updates are only performed if a new column mapping is provided. + + Args: + table (sqlglot.expressions.Table|str): Table expression instance or string representing the table + column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table + """ + table = exp.to_table(table) + self._validate_table(table) + column_mapping = ensure_column_mapping(column_mapping) + table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)] + existing_column_mapping = _nested_get( + self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False + ) + if existing_column_mapping and not column_mapping: + return + _nested_set( + self.schema, + [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)], + column_mapping, + ) + self._initialize_supported_args() + + def _get_table_args_from_table(self, table): + if table.args.get("catalog") is not None: + return "catalog", "db", "this" + if table.args.get("db") is not None: + return "db", "this" + return ("this",) + + def _validate_table(self, table): + if not self.supported_table_args and isinstance(table, exp.Table): + return + for forbidden in self.forbidden_table_args: + if table.text(forbidden): + raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") + for expected in self.supported_table_args: + if not table.text(expected): + raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ") + + def column_names(self, table, only_visible=False): + table = exp.to_table(table) + if not isinstance(table.this, exp.Identifier): + return fs_get(table) + + args = tuple(table.text(p) for p in self.supported_table_args) + + for forbidden in self.forbidden_table_args: + if table.text(forbidden): + raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") + + columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + if not only_visible or not self.visible: + return columns + + visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) + return [col for col in columns if col in visible] + + def get_column_type(self, table, column): + try: + schema_type = self.schema.get(table.name, {}).get(column.name).upper() + return self._convert_type(schema_type) + except: + raise OptimizeError(f"Failed to get type for column {column.sql()}") + + def _convert_type(self, schema_type): + """ + Convert a type represented as a string to the corresponding exp.DataType.Type object. + Args: + schema_type (str): The type we want to convert. + Returns: + sqlglot.expressions.DataType.Type: The resulting expression type. + """ + if schema_type not in self._type_mapping_cache: + try: + self._type_mapping_cache[schema_type] = exp.maybe_parse( + schema_type, into=exp.DataType, dialect=self.dialect + ).this + except AttributeError: + raise OptimizeError(f"Failed to convert type {schema_type}") + + return self._type_mapping_cache[schema_type] + + def _initialize_supported_args(self): + if not self.supported_table_args: + depth = _dict_depth(self.schema) + + all_args = ["this", "db", "catalog"] + if not depth or depth == 1: # {} + self.supported_table_args = [] + elif 2 <= depth <= 4: + self.supported_table_args = tuple(reversed(all_args[: depth - 1])) + else: + raise OptimizeError(f"Invalid schema shape. Depth: {depth}") + + self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args) + + +def ensure_schema(schema): + if isinstance(schema, Schema): + return schema + + return MappingSchema(schema) + + +def ensure_column_mapping(mapping): + if isinstance(mapping, dict): + return mapping + elif isinstance(mapping, str): + col_name_type_strs = [x.strip() for x in mapping.split(",")] + return { + name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() + for name_type_str in col_name_type_strs + } + # Check if mapping looks like a DataFrame StructType + elif hasattr(mapping, "simpleString"): + return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} + elif isinstance(mapping, list): + return {x.strip(): None for x in mapping} + elif mapping is None: + return {} + raise ValueError(f"Invalid mapping provided: {type(mapping)}") + + +def fs_get(table): + name = table.this.name + + if name.upper() == "READ_CSV": + with csv_reader(table) as reader: + return next(reader) + + raise ValueError(f"Cannot read schema for {table}") + + +def _nested_get(d, *path, raise_on_missing=True): + """ + Get a value for a nested dictionary. + + Args: + d (dict): dictionary + *path (tuple[str, str]): tuples of (name, key) + `key` is the key in the dictionary to get. + `name` is a string to use in the error if `key` isn't found. + + Returns: + The value or None if it doesn't exist + """ + for name, key in path: + d = d.get(key) + if d is None: + if raise_on_missing: + name = "table" if name == "this" else name + raise ValueError(f"Unknown {name}") + return None + return d + + +def _nested_set(d, keys, value): + """ + In-place set a value for a nested dictionary + + Ex: + >>> _nested_set({}, ["top_key", "second_key"], "value") + {'top_key': {'second_key': 'value'}} + >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") + {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} + + d (dict): dictionary + keys (Iterable[str]): ordered iterable of keys that makeup path to value + value (Any): The value to set in the dictionary for the given key path + """ + if not keys: + return + if len(keys) == 1: + d[keys[0]] = value + return + subd = d + for key in keys[:-1]: + if key not in subd: + subd = subd.setdefault(key, {}) + else: + subd = subd[key] + subd[keys[-1]] = value + return d + + +def _dict_depth(d): + """ + Get the nesting depth of a dictionary. + + For example: + >>> _dict_depth(None) + 0 + >>> _dict_depth({}) + 1 + >>> _dict_depth({"a": "b"}) + 1 + >>> _dict_depth({"a": {}}) + 2 + >>> _dict_depth({"a": {"b": {}}}) + 3 + + Args: + d (dict): dictionary + Returns: + int: depth + """ + try: + return 1 + _dict_depth(next(iter(d.values()))) + except AttributeError: + # d doesn't have attribute "values" + return 0 + except StopIteration: + # d.values() returns an empty sequence + return 1 diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index fc8e6e7..1a9d72e 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -56,6 +56,7 @@ class TokenType(AutoName): VAR = auto() BIT_STRING = auto() HEX_STRING = auto() + BYTE_STRING = auto() # types BOOLEAN = auto() @@ -320,6 +321,7 @@ class _Tokenizer(type): klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES) klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS) klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) + klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS) klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS) klass._COMMENTS = dict( (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS @@ -333,6 +335,7 @@ class _Tokenizer(type): **{quote: TokenType.QUOTE for quote in klass._QUOTES}, **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS}, **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS}, + **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS}, }.items() if " " in key or any(single in key for single in klass.SINGLE_TOKENS) ) @@ -385,6 +388,8 @@ class Tokenizer(metaclass=_Tokenizer): HEX_STRINGS = [] + BYTE_STRINGS = [] + IDENTIFIERS = ['"'] ESCAPE = "'" @@ -799,7 +804,7 @@ class Tokenizer(metaclass=_Tokenizer): if self._scan_string(word): return - if self._scan_numeric_string(word): + if self._scan_formatted_string(word): return if self._scan_comment(word): return @@ -906,7 +911,8 @@ class Tokenizer(metaclass=_Tokenizer): self._add(TokenType.STRING, text) return True - def _scan_numeric_string(self, string_start): + # X'1234, b'0110', E'\\\\\' etc. + def _scan_formatted_string(self, string_start): if string_start in self._HEX_STRINGS: delimiters = self._HEX_STRINGS token_type = TokenType.HEX_STRING @@ -915,6 +921,10 @@ class Tokenizer(metaclass=_Tokenizer): delimiters = self._BIT_STRINGS token_type = TokenType.BIT_STRING base = 2 + elif string_start in self._BYTE_STRINGS: + delimiters = self._BYTE_STRINGS + token_type = TokenType.BYTE_STRING + base = None else: return False @@ -922,10 +932,14 @@ class Tokenizer(metaclass=_Tokenizer): string_end = delimiters.get(string_start) text = self._extract_string(string_end) - try: - self._add(token_type, f"{int(text, base)}") - except ValueError: - raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}") + if base is None: + self._add(token_type, text) + else: + try: + self._add(token_type, f"{int(text, base)}") + except: + raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}") + return True def _scan_identifier(self, identifier_end): diff --git a/tests/dataframe/__init__.py b/tests/dataframe/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/dataframe/__init__.py diff --git a/tests/dataframe/integration/__init__.py b/tests/dataframe/integration/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/dataframe/integration/__init__.py diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py new file mode 100644 index 0000000..6c4642f --- /dev/null +++ b/tests/dataframe/integration/dataframe_validator.py @@ -0,0 +1,149 @@ +import typing as t +import unittest +import warnings + +import sqlglot +from tests.helpers import SKIP_INTEGRATION + +if t.TYPE_CHECKING: + from pyspark.sql import DataFrame as SparkDataFrame + + +@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") +class DataFrameValidator(unittest.TestCase): + spark = None + sqlglot = None + df_employee = None + df_store = None + df_district = None + spark_employee_schema = None + sqlglot_employee_schema = None + spark_store_schema = None + sqlglot_store_schema = None + spark_district_schema = None + sqlglot_district_schema = None + + @classmethod + def setUpClass(cls): + from pyspark import SparkConf + from pyspark.sql import SparkSession, types + + from sqlglot.dataframe.sql import types as sqlglotSparkTypes + from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession + + # This is for test `test_branching_root_dataframes` + config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")]) + cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate() + cls.spark.sparkContext.setLogLevel("ERROR") + cls.sqlglot = SqlglotSparkSession() + cls.spark_employee_schema = types.StructType( + [ + types.StructField("employee_id", types.IntegerType(), False), + types.StructField("fname", types.StringType(), False), + types.StructField("lname", types.StringType(), False), + types.StructField("age", types.IntegerType(), False), + types.StructField("store_id", types.IntegerType(), False), + ] + ) + cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType( + [ + sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), + ] + ) + employee_data = [ + (1, "Jack", "Shephard", 37, 1), + (2, "John", "Locke", 65, 1), + (3, "Kate", "Austen", 37, 2), + (4, "Claire", "Littleton", 27, 2), + (5, "Hugo", "Reyes", 29, 100), + ] + cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema) + cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema) + cls.df_employee.createOrReplaceTempView("employee") + + cls.spark_store_schema = types.StructType( + [ + types.StructField("store_id", types.IntegerType(), False), + types.StructField("store_name", types.StringType(), False), + types.StructField("district_id", types.IntegerType(), False), + types.StructField("num_sales", types.IntegerType(), False), + ] + ) + cls.sqlglot_store_schema = sqlglotSparkTypes.StructType( + [ + sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False), + ] + ) + store_data = [ + (1, "Hydra", 1, 37), + (2, "Arrow", 2, 2000), + ] + cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema) + cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema) + cls.df_store.createOrReplaceTempView("store") + + cls.spark_district_schema = types.StructType( + [ + types.StructField("district_id", types.IntegerType(), False), + types.StructField("district_name", types.StringType(), False), + types.StructField("manager_name", types.StringType(), False), + ] + ) + cls.sqlglot_district_schema = sqlglotSparkTypes.StructType( + [ + sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False), + ] + ) + district_data = [ + (1, "Temple", "Dogen"), + (2, "Lighthouse", "Jacob"), + ] + cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema) + cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema) + cls.df_district.createOrReplaceTempView("district") + sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema) + sqlglot.schema.add_table("store", cls.sqlglot_store_schema) + sqlglot.schema.add_table("district", cls.sqlglot_district_schema) + + def setUp(self) -> None: + warnings.filterwarnings("ignore", category=ResourceWarning) + self.df_spark_store = self.df_store.alias("df_store") # type: ignore + self.df_spark_employee = self.df_employee.alias("df_employee") # type: ignore + self.df_spark_district = self.df_district.alias("df_district") # type: ignore + self.df_sqlglot_store = self.dfs_store.alias("store") # type: ignore + self.df_sqlglot_employee = self.dfs_employee.alias("employee") # type: ignore + self.df_sqlglot_district = self.dfs_district.alias("district") # type: ignore + + def compare_spark_with_sqlglot( + self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False + ) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]: + def compare_schemas(schema_1, schema_2): + for schema in [schema_1, schema_2]: + for struct_field in schema.fields: + struct_field.metadata = {} + self.assertEqual(schema_1, schema_2) + + for statement in df_sqlglot.sql(): + actual_df_sqlglot = self.spark.sql(statement) # type: ignore + df_sqlglot_results = actual_df_sqlglot.collect() + df_spark_results = df_spark.collect() + if not skip_schema_compare: + compare_schemas(df_spark.schema, actual_df_sqlglot.schema) + self.assertEqual(df_spark_results, df_sqlglot_results) + if no_empty: + self.assertNotEqual(len(df_spark_results), 0) + self.assertNotEqual(len(df_sqlglot_results), 0) + return df_spark, actual_df_sqlglot + + @classmethod + def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str: + return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode) # type: ignore diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py new file mode 100644 index 0000000..c740bec --- /dev/null +++ b/tests/dataframe/integration/test_dataframe.py @@ -0,0 +1,1103 @@ +from pyspark.sql import functions as F + +from sqlglot.dataframe.sql import functions as SF +from tests.dataframe.integration.dataframe_validator import DataFrameValidator + + +class TestDataframeFunc(DataFrameValidator): + def test_simple_select(self): + df_employee = self.df_spark_employee.select(F.col("employee_id")) + dfs_employee = self.df_sqlglot_employee.select(SF.col("employee_id")) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_simple_select_from_table(self): + df = self.df_spark_employee + dfs = self.sqlglot.read.table("employee") + self.compare_spark_with_sqlglot(df, dfs) + + def test_simple_select_df_attribute(self): + df_employee = self.df_spark_employee.select(self.df_spark_employee.employee_id) + dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee.employee_id) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_simple_select_df_dict(self): + df_employee = self.df_spark_employee.select(self.df_spark_employee["employee_id"]) + dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee["employee_id"]) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_multiple_selects(self): + df_employee = self.df_spark_employee.select( + self.df_spark_employee["employee_id"], F.col("fname"), self.df_spark_employee.lname + ) + dfs_employee = self.df_sqlglot_employee.select( + self.df_sqlglot_employee["employee_id"], SF.col("fname"), self.df_sqlglot_employee.lname + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_alias_no_op(self): + df_employee = self.df_spark_employee.alias("df_employee") + dfs_employee = self.df_sqlglot_employee.alias("dfs_employee") + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_alias_with_select(self): + df_employee = self.df_spark_employee.alias("df_employee").select( + self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname + ) + dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select( + self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_case_when_otherwise(self): + df = self.df_spark_employee.select( + F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")) + .when(F.col("age") < F.lit(40), "less than 40") + .otherwise("greater than 60") + ) + + dfs = self.df_sqlglot_employee.select( + SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")) + .when(SF.col("age") < SF.lit(40), "less than 40") + .otherwise("greater than 60") + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_case_when_no_otherwise(self): + df = self.df_spark_employee.select( + F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when( + F.col("age") < F.lit(40), "less than 40" + ) + ) + + dfs = self.df_sqlglot_employee.select( + SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when( + SF.col("age") < SF.lit(40), "less than 40" + ) + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_where_clause_single(self): + df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37)) + dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_clause_multiple_and(self): + df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))) + dfs_employee = self.df_sqlglot_employee.where( + (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack")) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_many_and(self): + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) + & (F.col("fname") == F.lit("Jack")) + & (F.col("lname") == F.lit("Shephard")) + & (F.col("employee_id") == F.lit(1)) + ) + dfs_employee = self.df_sqlglot_employee.where( + (SF.col("age") == SF.lit(37)) + & (SF.col("fname") == SF.lit("Jack")) + & (SF.col("lname") == SF.lit("Shephard")) + & (SF.col("employee_id") == SF.lit(1)) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_clause_multiple_or(self): + df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))) + dfs_employee = self.df_sqlglot_employee.where( + (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate")) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_many_or(self): + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) + | (F.col("fname") == F.lit("Kate")) + | (F.col("lname") == F.lit("Littleton")) + | (F.col("employee_id") == F.lit(2)) + ) + dfs_employee = self.df_sqlglot_employee.where( + (SF.col("age") == SF.lit(37)) + | (SF.col("fname") == SF.lit("Kate")) + | (SF.col("lname") == SF.lit("Littleton")) + | (SF.col("employee_id") == SF.lit(2)) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_mixed_and_or(self): + df_employee = self.df_spark_employee.where( + ((F.col("age") == F.lit(65)) & (F.col("fname") == F.lit("John"))) + | ((F.col("lname") == F.lit("Shephard")) & (F.col("age") == F.lit(37))) + ) + dfs_employee = self.df_sqlglot_employee.where( + ((SF.col("age") == SF.lit(65)) & (SF.col("fname") == SF.lit("John"))) + | ((SF.col("lname") == SF.lit("Shephard")) & (SF.col("age") == SF.lit(37))) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_multiple_chained(self): + df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37)).where( + self.df_spark_employee.fname == F.lit("Jack") + ) + dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37)).where( + self.df_sqlglot_employee.fname == SF.lit("Jack") + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_operators(self): + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] < F.lit(50)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] < SF.lit(50)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] <= F.lit(37)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] <= SF.lit(37)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] > F.lit(50)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] > SF.lit(50)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] >= F.lit(37)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] >= SF.lit(37)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] != F.lit(50)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] != SF.lit(50)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] == F.lit(37)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_join_inner(self): + df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + F.col("store_id"), + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + SF.col("store_id"), + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_inner_no_select(self): + df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join( + self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner" + ) + dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner" + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_inner_equality_single(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner" + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner" + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_inner_equality_multiple(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store, + on=[ + self.df_spark_employee.store_id == self.df_spark_store.store_id, + self.df_spark_employee.age == self.df_spark_store.num_sales, + ], + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, + on=[ + self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales, + ], + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_inner_equality_multiple_bitwise_and(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store, + on=(self.df_spark_employee.store_id == self.df_spark_store.store_id) + & (self.df_spark_employee.age == self.df_spark_store.num_sales), + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, + on=(self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id) + & (self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales), + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_left_outer(self): + df_joined = ( + self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="left_outer") + .select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + F.col("store_id"), + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + .orderBy(F.col("employee_id")) + ) + dfs_joined = ( + self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="left_outer") + .select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + SF.col("store_id"), + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + .orderBy(SF.col("employee_id")) + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_full_outer(self): + df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + F.col("store_id"), + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + SF.col("store_id"), + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_triple_join(self): + df = ( + self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id) + .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id) + .select( + self.df_employee.employee_id, + self.df_store.store_id, + self.df_district.district_id, + self.df_employee.fname, + self.df_store.store_name, + self.df_district.district_name, + ) + ) + dfs = ( + self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id) + .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id) + .select( + self.dfs_employee.employee_id, + self.dfs_store.store_id, + self.dfs_district.district_id, + self.dfs_employee.fname, + self.dfs_store.store_name, + self.dfs_district.district_name, + ) + ) + self.compare_spark_with_sqlglot(df, dfs) + + def test_join_select_and_select_start(self): + df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join( + self.df_spark_store, "store_id", "inner" + ) + + dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join( + self.df_sqlglot_store, "store_id", "inner" + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_branching_root_dataframes(self): + """ + Test a pattern that has non-intuitive behavior in spark + + Scenario: You do a self-join in a dataframe using an original dataframe and then a modified version + of it. You then reference the columns by the dataframe name instead of the column function. + Spark will use the root dataframe's column in the result. + """ + df_hydra_employees_only = self.df_spark_employee.where(F.col("store_id") == F.lit(1)) + df_joined = ( + self.df_spark_employee.where(F.col("store_id") == F.lit(2)) + .alias("df_arrow_employees_only") + .join( + df_hydra_employees_only.alias("df_hydra_employees_only"), + on=["store_id"], + how="full_outer", + ) + .select( + self.df_spark_employee.fname, + F.col("df_arrow_employees_only.fname"), + df_hydra_employees_only.fname, + F.col("df_hydra_employees_only.fname"), + ) + ) + + dfs_hydra_employees_only = self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(1)) + dfs_joined = ( + self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(2)) + .alias("dfs_arrow_employees_only") + .join( + dfs_hydra_employees_only.alias("dfs_hydra_employees_only"), + on=["store_id"], + how="full_outer", + ) + .select( + self.df_sqlglot_employee.fname, + SF.col("dfs_arrow_employees_only.fname"), + dfs_hydra_employees_only.fname, + SF.col("dfs_hydra_employees_only.fname"), + ) + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_basic_union(self): + df_unioned = self.df_spark_employee.select(F.col("employee_id"), F.col("age")).union( + self.df_spark_store.select(F.col("store_id"), F.col("num_sales")) + ) + + dfs_unioned = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")).union( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("num_sales")) + ) + self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) + + def test_union_with_join(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store, + on="store_id", + how="inner", + ) + df_unioned = df_joined.select(F.col("store_id"), F.col("store_name")).union( + self.df_spark_district.select(F.col("district_id"), F.col("district_name")) + ) + + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, + on="store_id", + how="inner", + ) + dfs_unioned = dfs_joined.select(SF.col("store_id"), SF.col("store_name")).union( + self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")) + ) + + self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) + + def test_double_union_all(self): + df_unioned = ( + self.df_spark_employee.select(F.col("employee_id"), F.col("fname")) + .unionAll(self.df_spark_store.select(F.col("store_id"), F.col("store_name"))) + .unionAll(self.df_spark_district.select(F.col("district_id"), F.col("district_name"))) + ) + + dfs_unioned = ( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname")) + .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))) + .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))) + ) + + self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) + + def test_union_by_name(self): + df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName( + self.df_spark_store.select( + F.col("store_name").alias("lname"), + F.col("store_id").alias("employee_id"), + F.col("store_name").alias("fname"), + ) + ) + + dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName( + self.df_sqlglot_store.select( + SF.col("store_name").alias("lname"), + SF.col("store_id").alias("employee_id"), + SF.col("store_name").alias("fname"), + ) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_union_by_name_allow_missing(self): + df = self.df_spark_employee.select( + F.col("age"), F.col("employee_id"), F.col("fname"), F.col("lname") + ).unionByName( + self.df_spark_store.select( + F.col("store_name").alias("lname"), + F.col("store_id").alias("employee_id"), + F.col("store_name").alias("fname"), + F.col("num_sales"), + ), + allowMissingColumns=True, + ) + + dfs = self.df_sqlglot_employee.select( + SF.col("age"), SF.col("employee_id"), SF.col("fname"), SF.col("lname") + ).unionByName( + self.df_sqlglot_store.select( + SF.col("store_name").alias("lname"), + SF.col("store_id").alias("employee_id"), + SF.col("store_name").alias("fname"), + SF.col("num_sales"), + ), + allowMissingColumns=True, + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_default(self): + df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id")) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id")) + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_order_by_array_bool(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.col("total_sales"), F.col("district_id"), ascending=[1, 0]) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=[1, 0]) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_single_bool(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.col("total_sales"), F.col("district_id"), ascending=False) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=False) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_column_sort_method(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.col("total_sales").asc(), F.col("district_id").desc()) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.col("total_sales").asc(), SF.col("district_id").desc()) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_column_sort_method_nulls_last(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_column_sort_method_nulls_first(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first()) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_intersect(self): + df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( + self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) + ) + + df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( + self.df_spark_store.select(F.col("store_id"), F.col("district_id")) + ) + + df = df_employee_duplicate.intersect(df_store_duplicate) + + dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) + ) + + dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) + ) + + dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_intersect_all(self): + df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( + self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) + ) + + df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( + self.df_spark_store.select(F.col("store_id"), F.col("district_id")) + ) + + df = df_employee_duplicate.intersectAll(df_store_duplicate) + + dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) + ) + + dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) + ) + + dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_except_all(self): + df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( + self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) + ) + + df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( + self.df_spark_store.select(F.col("store_id"), F.col("district_id")) + ) + + df = df_employee_duplicate.exceptAll(df_store_duplicate) + + dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) + ) + + dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) + ) + + dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_distinct(self): + df = self.df_spark_employee.select(F.col("age")).distinct() + + dfs = self.df_sqlglot_employee.select(SF.col("age")).distinct() + + self.compare_spark_with_sqlglot(df, dfs) + + def test_union_distinct(self): + df_unioned = ( + self.df_spark_employee.select(F.col("employee_id"), F.col("age")) + .union(self.df_spark_employee.select(F.col("employee_id"), F.col("age"))) + .distinct() + ) + + dfs_unioned = ( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")) + .union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age"))) + .distinct() + ) + self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) + + def test_drop_duplicates_no_subset(self): + df = self.df_spark_employee.select("age").dropDuplicates() + dfs = self.df_sqlglot_employee.select("age").dropDuplicates() + self.compare_spark_with_sqlglot(df, dfs) + + def test_drop_duplicates_subset(self): + df = self.df_spark_employee.dropDuplicates(["age"]) + dfs = self.df_sqlglot_employee.dropDuplicates(["age"]) + self.compare_spark_with_sqlglot(df, dfs) + + def test_drop_na_default(self): + df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna() + + dfs = self.df_sqlglot_employee.select( + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).dropna() + + self.compare_spark_with_sqlglot(df, dfs) + + def test_dropna_how(self): + df = self.df_spark_employee.select( + F.lit(None), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna(how="all") + + dfs = self.df_sqlglot_employee.select( + SF.lit(None), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).dropna(how="all") + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_dropna_thresh(self): + df = self.df_spark_employee.select( + F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna(how="any", thresh=2) + + dfs = self.df_sqlglot_employee.select( + SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).dropna(how="any", thresh=2) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_dropna_subset(self): + df = self.df_spark_employee.select( + F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna(thresh=1, subset="the_age") + + dfs = self.df_sqlglot_employee.select( + SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).dropna(thresh=1, subset="the_age") + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_dropna_na_function(self): + df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop() + + dfs = self.df_sqlglot_employee.select( + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).na.drop() + + self.compare_spark_with_sqlglot(df, dfs) + + def test_fillna_default(self): + df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100) + + dfs = self.df_sqlglot_employee.select( + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).fillna(100) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_fillna_dict_replacement(self): + df = self.df_spark_employee.select( + F.col("fname"), + F.when(F.col("lname").startswith("L"), F.col("lname")).alias("l_lname"), + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age"), + ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"}) + + dfs = self.df_sqlglot_employee.select( + SF.col("fname"), + SF.when(SF.col("lname").startswith("L"), SF.col("lname")).alias("l_lname"), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), + ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"}) + + # For some reason the sqlglot results sets a column as nullable when it doesn't need to + # This seems to be a nuance in how spark dataframe from sql works so we can ignore + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_fillna_na_func(self): + df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100) + + dfs = self.df_sqlglot_employee.select( + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).na.fill(100) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_replace_basic(self): + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100) + + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( + to_replace=37, value=100 + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_replace_basic_subset(self): + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + to_replace=37, value=100, subset="age" + ) + + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( + to_replace=37, value=100, subset="age" + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_replace_mapping(self): + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100}) + + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100}) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_replace_mapping_subset(self): + df = self.df_spark_employee.select( + F.col("age"), F.lit(37).alias("test_col"), F.lit(50).alias("test_col_2") + ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"]) + + dfs = self.df_sqlglot_employee.select( + SF.col("age"), SF.lit(37).alias("test_col"), SF.lit(50).alias("test_col_2") + ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"]) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_replace_na_func_basic(self): + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).na.replace( + to_replace=37, value=100 + ) + + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace( + to_replace=37, value=100 + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_with_column(self): + df = self.df_spark_employee.withColumn("test", F.col("age")) + + dfs = self.df_sqlglot_employee.withColumn("test", SF.col("age")) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_with_column_existing_name(self): + df = self.df_spark_employee.withColumn("fname", F.lit("blah")) + + dfs = self.df_sqlglot_employee.withColumn("fname", SF.lit("blah")) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_with_column_renamed(self): + df = self.df_spark_employee.withColumnRenamed("fname", "first_name") + + dfs = self.df_sqlglot_employee.withColumnRenamed("fname", "first_name") + + self.compare_spark_with_sqlglot(df, dfs) + + def test_with_column_renamed_double(self): + df = self.df_spark_employee.select(F.col("fname").alias("first_name")).withColumnRenamed( + "first_name", "first_name_again" + ) + + dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed( + "first_name", "first_name_again" + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_drop_column_single(self): + df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age") + + dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age") + + self.compare_spark_with_sqlglot(df, dfs) + + def test_drop_column_reference_join(self): + df_spark_employee_cols = self.df_spark_employee.select( + F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") + ) + df_spark_store_cols = self.df_spark_store.select(F.col("store_id"), F.col("store_name")) + df = df_spark_employee_cols.join(df_spark_store_cols, on="store_id", how="inner").drop( + df_spark_employee_cols.age, + ) + + df_sqlglot_employee_cols = self.df_sqlglot_employee.select( + SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") + ) + df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")) + dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop( + df_sqlglot_employee_cols.age, + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_limit(self): + df = self.df_spark_employee.limit(1) + + dfs = self.df_sqlglot_employee.limit(1) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_hint_broadcast_alias(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store.alias("store").hint("broadcast", "store"), + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store.alias("store").hint("broadcast", "store"), + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df)) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs)) + + def test_hint_broadcast_no_alias(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store.hint("broadcast"), + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store.hint("broadcast"), + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df)) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs)) + + # TODO: Add test to make sure with and without alias are the same once ids are deterministic + + def test_broadcast_func(self): + df_joined = self.df_spark_employee.join( + F.broadcast(self.df_spark_store), + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + SF.broadcast(self.df_sqlglot_store), + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df)) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs)) + + def test_repartition_by_num(self): + """ + The results are different when doing the repartition on a table created using VALUES in SQL. + So I just use the views instead for these tests + """ + df = self.df_spark_employee.repartition(63) + + dfs = self.sqlglot.read.table("employee").repartition(63) + df, dfs = self.compare_spark_with_sqlglot(df, dfs) + spark_num_partitions = df.rdd.getNumPartitions() + sqlglot_num_partitions = dfs.rdd.getNumPartitions() + self.assertEqual(spark_num_partitions, 63) + self.assertEqual(spark_num_partitions, sqlglot_num_partitions) + + def test_repartition_name_only(self): + """ + We use the view here to help ensure the explain plans are similar enough to compare + """ + df = self.df_spark_employee.repartition("age") + + dfs = self.sqlglot.read.table("employee").repartition("age") + df, dfs = self.compare_spark_with_sqlglot(df, dfs) + self.assertIn("RepartitionByExpression [age", self.get_explain_plan(df)) + self.assertIn("RepartitionByExpression [age", self.get_explain_plan(dfs)) + + def test_repartition_num_and_multiple_names(self): + """ + We use the view here to help ensure the explain plans are similar enough to compare + """ + df = self.df_spark_employee.repartition(53, "age", "fname") + + dfs = self.sqlglot.read.table("employee").repartition(53, "age", "fname") + df, dfs = self.compare_spark_with_sqlglot(df, dfs) + spark_num_partitions = df.rdd.getNumPartitions() + sqlglot_num_partitions = dfs.rdd.getNumPartitions() + self.assertEqual(spark_num_partitions, 53) + self.assertEqual(spark_num_partitions, sqlglot_num_partitions) + self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(df)) + self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(dfs)) + + def test_coalesce(self): + df = self.df_spark_employee.coalesce(1) + dfs = self.df_sqlglot_employee.coalesce(1) + df, dfs = self.compare_spark_with_sqlglot(df, dfs) + spark_num_partitions = df.rdd.getNumPartitions() + sqlglot_num_partitions = dfs.rdd.getNumPartitions() + self.assertEqual(spark_num_partitions, 1) + self.assertEqual(spark_num_partitions, sqlglot_num_partitions) + + def test_cache_select(self): + df_employee = ( + self.df_spark_employee.groupBy("store_id") + .agg(F.countDistinct("employee_id").alias("num_employees")) + .cache() + ) + df_joined = df_employee.join(self.df_spark_store, on="store_id").select( + self.df_spark_store.store_id, df_employee.num_employees + ) + dfs_employee = ( + self.df_sqlglot_employee.groupBy("store_id") + .agg(SF.countDistinct("employee_id").alias("num_employees")) + .cache() + ) + dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select( + self.df_sqlglot_store.store_id, dfs_employee.num_employees + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_persist_select(self): + df_employee = ( + self.df_spark_employee.groupBy("store_id") + .agg(F.countDistinct("employee_id").alias("num_employees")) + .persist() + ) + df_joined = df_employee.join(self.df_spark_store, on="store_id").select( + self.df_spark_store.store_id, df_employee.num_employees + ) + dfs_employee = ( + self.df_sqlglot_employee.groupBy("store_id") + .agg(SF.countDistinct("employee_id").alias("num_employees")) + .persist() + ) + dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select( + self.df_sqlglot_store.store_id, dfs_employee.num_employees + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) diff --git a/tests/dataframe/integration/test_grouped_data.py b/tests/dataframe/integration/test_grouped_data.py new file mode 100644 index 0000000..2768dda --- /dev/null +++ b/tests/dataframe/integration/test_grouped_data.py @@ -0,0 +1,71 @@ +from pyspark.sql import functions as F + +from sqlglot.dataframe.sql import functions as SF +from tests.dataframe.integration.dataframe_validator import DataFrameValidator + + +class TestDataframeFunc(DataFrameValidator): + def test_group_by(self): + df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg( + F.min(self.df_spark_employee.employee_id) + ) + dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg( + SF.min(self.df_sqlglot_employee.employee_id) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True) + + def test_group_by_where_non_aggregate(self): + df_employee = ( + self.df_spark_employee.groupBy(self.df_spark_employee.age) + .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id")) + .where(F.col("age") > F.lit(50)) + ) + dfs_employee = ( + self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age) + .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id")) + .where(SF.col("age") > SF.lit(50)) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_group_by_where_aggregate_like_having(self): + df_employee = ( + self.df_spark_employee.groupBy(self.df_spark_employee.age) + .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id")) + .where(F.col("min_employee_id") > F.lit(1)) + ) + dfs_employee = ( + self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age) + .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id")) + .where(SF.col("min_employee_id") > SF.lit(1)) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_count(self): + df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count() + dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count() + self.compare_spark_with_sqlglot(df, dfs) + + def test_mean(self): + df = self.df_spark_employee.groupBy().mean("age", "store_id") + dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id") + self.compare_spark_with_sqlglot(df, dfs) + + def test_avg(self): + df = self.df_spark_employee.groupBy("age").avg("store_id") + dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id") + self.compare_spark_with_sqlglot(df, dfs) + + def test_max(self): + df = self.df_spark_employee.groupBy("age").max("store_id") + dfs = self.df_sqlglot_employee.groupBy("age").max("store_id") + self.compare_spark_with_sqlglot(df, dfs) + + def test_min(self): + df = self.df_spark_employee.groupBy("age").min("store_id") + dfs = self.df_sqlglot_employee.groupBy("age").min("store_id") + self.compare_spark_with_sqlglot(df, dfs) + + def test_sum(self): + df = self.df_spark_employee.groupBy("age").sum("store_id") + dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id") + self.compare_spark_with_sqlglot(df, dfs) diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py new file mode 100644 index 0000000..ff1477b --- /dev/null +++ b/tests/dataframe/integration/test_session.py @@ -0,0 +1,28 @@ +from pyspark.sql import functions as F + +from sqlglot.dataframe.sql import functions as SF +from tests.dataframe.integration.dataframe_validator import DataFrameValidator + + +class TestSessionFunc(DataFrameValidator): + def test_sql_simple_select(self): + query = "SELECT fname, lname FROM employee" + df = self.spark.sql(query) + dfs = self.sqlglot.sql(query) + self.compare_spark_with_sqlglot(df, dfs) + + def test_sql_with_join(self): + query = """ + SELECT + e.employee_id + , s.store_id + FROM + employee e + INNER JOIN + store s + ON + e.store_id = s.store_id + """ + df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id"))) + dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id"))) + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) diff --git a/tests/dataframe/unit/__init__.py b/tests/dataframe/unit/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/dataframe/unit/__init__.py diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py new file mode 100644 index 0000000..fc56553 --- /dev/null +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -0,0 +1,35 @@ +import typing as t +import unittest + +from sqlglot.dataframe.sql import types +from sqlglot.dataframe.sql.dataframe import DataFrame +from sqlglot.dataframe.sql.session import SparkSession + + +class DataFrameSQLValidator(unittest.TestCase): + def setUp(self) -> None: + self.spark = SparkSession() + self.employee_schema = types.StructType( + [ + types.StructField("employee_id", types.IntegerType(), False), + types.StructField("fname", types.StringType(), False), + types.StructField("lname", types.StringType(), False), + types.StructField("age", types.IntegerType(), False), + types.StructField("store_id", types.IntegerType(), False), + ] + ) + employee_data = [ + (1, "Jack", "Shephard", 37, 1), + (2, "John", "Locke", 65, 1), + (3, "Kate", "Austen", 37, 2), + (4, "Claire", "Littleton", 27, 2), + (5, "Hugo", "Reyes", 29, 100), + ] + self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema) + + def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False): + actual_sqls = df.sql(pretty=pretty) + expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements + self.assertEqual(len(expected_statements), len(actual_sqls)) + for expected, actual in zip(expected_statements, actual_sqls): + self.assertEqual(expected, actual) diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py new file mode 100644 index 0000000..df0ebff --- /dev/null +++ b/tests/dataframe/unit/test_column.py @@ -0,0 +1,167 @@ +import datetime +import unittest + +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.window import Window + + +class TestDataframeColumn(unittest.TestCase): + def test_eq(self): + self.assertEqual("cola = 1", (F.col("cola") == 1).sql()) + + def test_neq(self): + self.assertEqual("cola <> 1", (F.col("cola") != 1).sql()) + + def test_gt(self): + self.assertEqual("cola > 1", (F.col("cola") > 1).sql()) + + def test_lt(self): + self.assertEqual("cola < 1", (F.col("cola") < 1).sql()) + + def test_le(self): + self.assertEqual("cola <= 1", (F.col("cola") <= 1).sql()) + + def test_ge(self): + self.assertEqual("cola >= 1", (F.col("cola") >= 1).sql()) + + def test_and(self): + self.assertEqual( + "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql() + ) + + def test_or(self): + self.assertEqual( + "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql() + ) + + def test_mod(self): + self.assertEqual("cola % 2", (F.col("cola") % 2).sql()) + + def test_add(self): + self.assertEqual("cola + 1", (F.col("cola") + 1).sql()) + + def test_sub(self): + self.assertEqual("cola - 1", (F.col("cola") - 1).sql()) + + def test_mul(self): + self.assertEqual("cola * 2", (F.col("cola") * 2).sql()) + + def test_div(self): + self.assertEqual("cola / 2", (F.col("cola") / 2).sql()) + + def test_radd(self): + self.assertEqual("1 + cola", (1 + F.col("cola")).sql()) + + def test_rsub(self): + self.assertEqual("1 - cola", (1 - F.col("cola")).sql()) + + def test_rmul(self): + self.assertEqual("1 * cola", (1 * F.col("cola")).sql()) + + def test_rdiv(self): + self.assertEqual("1 / cola", (1 / F.col("cola")).sql()) + + def test_pow(self): + self.assertEqual("POWER(cola, 2)", (F.col("cola") ** 2).sql()) + + def test_rpow(self): + self.assertEqual("POWER(2, cola)", (2 ** F.col("cola")).sql()) + + def test_invert(self): + self.assertEqual("NOT cola", (~F.col("cola")).sql()) + + def test_startswith(self): + self.assertEqual("STARTSWITH(cola, 'test')", F.col("cola").startswith("test").sql()) + + def test_endswith(self): + self.assertEqual("ENDSWITH(cola, 'test')", F.col("cola").endswith("test").sql()) + + def test_rlike(self): + self.assertEqual("cola RLIKE 'foo'", F.col("cola").rlike("foo").sql()) + + def test_like(self): + self.assertEqual("cola LIKE 'foo%'", F.col("cola").like("foo%").sql()) + + def test_ilike(self): + self.assertEqual("cola ILIKE 'foo%'", F.col("cola").ilike("foo%").sql()) + + def test_substring(self): + self.assertEqual("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3).sql()) + + def test_isin(self): + self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin([1, 2, 3]).sql()) + self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql()) + + def test_asc(self): + self.assertEqual("cola", F.col("cola").asc().sql()) + + def test_desc(self): + self.assertEqual("cola DESC", F.col("cola").desc().sql()) + + def test_asc_nulls_first(self): + self.assertEqual("cola", F.col("cola").asc_nulls_first().sql()) + + def test_asc_nulls_last(self): + self.assertEqual("cola NULLS LAST", F.col("cola").asc_nulls_last().sql()) + + def test_desc_nulls_first(self): + self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql()) + + def test_desc_nulls_last(self): + self.assertEqual("cola DESC", F.col("cola").desc_nulls_last().sql()) + + def test_when_otherwise(self): + self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql()) + self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END", + (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(), + ) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END", + F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).sql(), + ) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END", + F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4).sql(), + ) + + def test_is_null(self): + self.assertEqual("cola IS NULL", F.col("cola").isNull().sql()) + + def test_is_not_null(self): + self.assertEqual("NOT cola IS NULL", F.col("cola").isNotNull().sql()) + + def test_cast(self): + self.assertEqual("CAST(cola AS INT)", F.col("cola").cast("INT").sql()) + + def test_alias(self): + self.assertEqual("cola AS new_name", F.col("cola").alias("new_name").sql()) + + def test_between(self): + self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql()) + self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql()) + self.assertEqual( + "cola BETWEEN TO_DATE('2022-01-01') AND TO_DATE('2022-03-01')", + F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(), + ) + self.assertEqual( + "cola BETWEEN CAST('2022-01-01 01:01:01' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01' AS TIMESTAMP)", + F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(), + ) + + def test_over(self): + over_rows = F.sum("cola").over( + Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing) + ) + self.assertEqual( + "SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + over_rows.sql(), + ) + over_range = F.sum("cola").over( + Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing) + ) + self.assertEqual( + "SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + over_range.sql(), + ) diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py new file mode 100644 index 0000000..c222cac --- /dev/null +++ b/tests/dataframe/unit/test_dataframe.py @@ -0,0 +1,39 @@ +from sqlglot import expressions as exp +from sqlglot.dataframe.sql.dataframe import DataFrame +from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator + + +class TestDataframe(DataFrameSQLValidator): + def test_hash_select_expression(self): + expression = exp.select("cola").from_("table") + self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression)) + + def test_columns(self): + self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns) + + def test_cache(self): + df = self.df_employee.select("fname").cache() + expected_statements = [ + "DROP VIEW IF EXISTS t11623", + "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + ] + self.compare_sql(df, expected_statements) + + def test_persist_default(self): + df = self.df_employee.select("fname").persist() + expected_statements = [ + "DROP VIEW IF EXISTS t11623", + "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + ] + self.compare_sql(df, expected_statements) + + def test_persist_storagelevel(self): + df = self.df_employee.select("fname").persist("DISK_ONLY_2") + expected_statements = [ + "DROP VIEW IF EXISTS t11623", + "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + ] + self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py new file mode 100644 index 0000000..14b4a0a --- /dev/null +++ b/tests/dataframe/unit/test_dataframe_writer.py @@ -0,0 +1,86 @@ +from unittest import mock + +import sqlglot +from sqlglot.schema import MappingSchema +from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator + + +class TestDataFrameWriter(DataFrameSQLValidator): + def test_insertInto_full_path(self): + df = self.df_employee.write.insertInto("catalog.db.table_name") + expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_insertInto_db_table(self): + df = self.df_employee.write.insertInto("db.table_name") + expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_insertInto_table(self): + df = self.df_employee.write.insertInto("table_name") + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_insertInto_overwrite(self): + df = self.df_employee.write.insertInto("table_name", overwrite=True) + expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_insertInto_byName(self): + sqlglot.schema.add_table("table_name", {"employee_id": "INT"}) + df = self.df_employee.write.byName.insertInto("table_name") + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_insertInto_cache(self): + df = self.df_employee.cache().write.insertInto("table_name") + expected_statements = [ + "DROP VIEW IF EXISTS t35612", + "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", + ] + self.compare_sql(df, expected_statements) + + def test_saveAsTable_format(self): + with self.assertRaises(NotImplementedError): + self.df_employee.write.saveAsTable("table_name", format="parquet").sql(pretty=False)[0] + + def test_saveAsTable_append(self): + df = self.df_employee.write.saveAsTable("table_name", mode="append") + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_saveAsTable_overwrite(self): + df = self.df_employee.write.saveAsTable("table_name", mode="overwrite") + expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_saveAsTable_error(self): + df = self.df_employee.write.saveAsTable("table_name", mode="error") + expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_saveAsTable_ignore(self): + df = self.df_employee.write.saveAsTable("table_name", mode="ignore") + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_mode_standalone(self): + df = self.df_employee.write.mode("ignore").saveAsTable("table_name") + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_mode_override(self): + df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite") + expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_saveAsTable_cache(self): + df = self.df_employee.cache().write.saveAsTable("table_name") + expected_statements = [ + "DROP VIEW IF EXISTS t35612", + "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", + ] + self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py new file mode 100644 index 0000000..10f3b57 --- /dev/null +++ b/tests/dataframe/unit/test_functions.py @@ -0,0 +1,1593 @@ +import datetime +import inspect +import unittest + +from sqlglot import expressions as exp +from sqlglot import parse_one +from sqlglot.dataframe.sql import functions as SF +from sqlglot.errors import ErrorLevel + + +class TestFunctions(unittest.TestCase): + @unittest.skip("not yet fixed.") + def test_invoke_anonymous(self): + for name, func in inspect.getmembers(SF, inspect.isfunction): + with self.subTest(f"{name} should not invoke anonymous_function"): + if "invoke_anonymous_function" in inspect.getsource(func): + func = parse_one(f"{name}()", read="spark", error_level=ErrorLevel.IGNORE) + self.assertIsInstance(func, exp.Anonymous) + + def test_lit(self): + test_str = SF.lit("test") + self.assertEqual("'test'", test_str.sql()) + test_int = SF.lit(30) + self.assertEqual("30", test_int.sql()) + test_float = SF.lit(10.10) + self.assertEqual("10.1", test_float.sql()) + test_bool = SF.lit(False) + self.assertEqual("FALSE", test_bool.sql()) + test_null = SF.lit(None) + self.assertEqual("NULL", test_null.sql()) + test_date = SF.lit(datetime.date(2022, 1, 1)) + self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) + test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1)) + self.assertEqual("CAST('2022-01-01 01:01:01' AS TIMESTAMP)", test_datetime.sql()) + test_dict = SF.lit({"cola": 1, "colb": "test"}) + self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) + + def test_col(self): + test_col = SF.col("cola") + self.assertEqual("cola", test_col.sql()) + test_col_with_table = SF.col("table.cola") + self.assertEqual("table.cola", test_col_with_table.sql()) + test_col_on_col = SF.col(test_col) + self.assertEqual("cola", test_col_on_col.sql()) + test_int = SF.col(10) + self.assertEqual("10", test_int.sql()) + test_float = SF.col(10.10) + self.assertEqual("10.1", test_float.sql()) + test_bool = SF.col(True) + self.assertEqual("TRUE", test_bool.sql()) + test_array = SF.col([1, 2, "3"]) + self.assertEqual("ARRAY(1, 2, '3')", test_array.sql()) + test_date = SF.col(datetime.date(2022, 1, 1)) + self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) + test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1)) + self.assertEqual("CAST('2022-01-01 01:01:01' AS TIMESTAMP)", test_datetime.sql()) + test_dict = SF.col({"cola": 1, "colb": "test"}) + self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) + + def test_asc(self): + asc_str = SF.asc("cola") + # ASC is removed from output since that is default so we can't check sql + self.assertIsInstance(asc_str.expression, exp.Ordered) + asc_col = SF.asc(SF.col("cola")) + self.assertIsInstance(asc_col.expression, exp.Ordered) + + def test_desc(self): + desc_str = SF.desc("cola") + self.assertEqual("cola DESC", desc_str.sql()) + desc_col = SF.desc(SF.col("cola")) + self.assertEqual("cola DESC", desc_col.sql()) + + def test_sqrt(self): + col_str = SF.sqrt("cola") + self.assertEqual("SQRT(cola)", col_str.sql()) + col = SF.sqrt(SF.col("cola")) + self.assertEqual("SQRT(cola)", col.sql()) + + def test_abs(self): + col_str = SF.abs("cola") + self.assertEqual("ABS(cola)", col_str.sql()) + col = SF.abs(SF.col("cola")) + self.assertEqual("ABS(cola)", col.sql()) + + def test_max(self): + col_str = SF.max("cola") + self.assertEqual("MAX(cola)", col_str.sql()) + col = SF.max(SF.col("cola")) + self.assertEqual("MAX(cola)", col.sql()) + + def test_min(self): + col_str = SF.min("cola") + self.assertEqual("MIN(cola)", col_str.sql()) + col = SF.min(SF.col("cola")) + self.assertEqual("MIN(cola)", col.sql()) + + def test_max_by(self): + col_str = SF.max_by("cola", "colb") + self.assertEqual("MAX_BY(cola, colb)", col_str.sql()) + col = SF.max_by(SF.col("cola"), SF.col("colb")) + self.assertEqual("MAX_BY(cola, colb)", col.sql()) + + def test_min_by(self): + col_str = SF.min_by("cola", "colb") + self.assertEqual("MIN_BY(cola, colb)", col_str.sql()) + col = SF.min_by(SF.col("cola"), SF.col("colb")) + self.assertEqual("MIN_BY(cola, colb)", col.sql()) + + def test_count(self): + col_str = SF.count("cola") + self.assertEqual("COUNT(cola)", col_str.sql()) + col = SF.count(SF.col("cola")) + self.assertEqual("COUNT(cola)", col.sql()) + + def test_sum(self): + col_str = SF.sum("cola") + self.assertEqual("SUM(cola)", col_str.sql()) + col = SF.sum(SF.col("cola")) + self.assertEqual("SUM(cola)", col.sql()) + + def test_avg(self): + col_str = SF.avg("cola") + self.assertEqual("AVG(cola)", col_str.sql()) + col = SF.avg(SF.col("cola")) + self.assertEqual("AVG(cola)", col.sql()) + + def test_mean(self): + col_str = SF.mean("cola") + self.assertEqual("MEAN(cola)", col_str.sql()) + col = SF.mean(SF.col("cola")) + self.assertEqual("MEAN(cola)", col.sql()) + + def test_sum_distinct(self): + with self.assertRaises(NotImplementedError): + SF.sum_distinct("cola") + with self.assertRaises(NotImplementedError): + SF.sumDistinct("cola") + + def test_product(self): + with self.assertRaises(NotImplementedError): + SF.product("cola") + with self.assertRaises(NotImplementedError): + SF.product("cola") + + def test_acos(self): + col_str = SF.acos("cola") + self.assertEqual("ACOS(cola)", col_str.sql()) + col = SF.acos(SF.col("cola")) + self.assertEqual("ACOS(cola)", col.sql()) + + def test_acosh(self): + col_str = SF.acosh("cola") + self.assertEqual("ACOSH(cola)", col_str.sql()) + col = SF.acosh(SF.col("cola")) + self.assertEqual("ACOSH(cola)", col.sql()) + + def test_asin(self): + col_str = SF.asin("cola") + self.assertEqual("ASIN(cola)", col_str.sql()) + col = SF.asin(SF.col("cola")) + self.assertEqual("ASIN(cola)", col.sql()) + + def test_asinh(self): + col_str = SF.asinh("cola") + self.assertEqual("ASINH(cola)", col_str.sql()) + col = SF.asinh(SF.col("cola")) + self.assertEqual("ASINH(cola)", col.sql()) + + def test_atan(self): + col_str = SF.atan("cola") + self.assertEqual("ATAN(cola)", col_str.sql()) + col = SF.atan(SF.col("cola")) + self.assertEqual("ATAN(cola)", col.sql()) + + def test_atan2(self): + col_str = SF.atan2("cola", "colb") + self.assertEqual("ATAN2(cola, colb)", col_str.sql()) + col = SF.atan2(SF.col("cola"), SF.col("colb")) + self.assertEqual("ATAN2(cola, colb)", col.sql()) + col_float = SF.atan2(10.10, "colb") + self.assertEqual("ATAN2(10.1, colb)", col_float.sql()) + col_float2 = SF.atan2("cola", 10.10) + self.assertEqual("ATAN2(cola, 10.1)", col_float2.sql()) + + def test_atanh(self): + col_str = SF.atanh("cola") + self.assertEqual("ATANH(cola)", col_str.sql()) + col = SF.atanh(SF.col("cola")) + self.assertEqual("ATANH(cola)", col.sql()) + + def test_cbrt(self): + col_str = SF.cbrt("cola") + self.assertEqual("CBRT(cola)", col_str.sql()) + col = SF.cbrt(SF.col("cola")) + self.assertEqual("CBRT(cola)", col.sql()) + + def test_ceil(self): + col_str = SF.ceil("cola") + self.assertEqual("CEIL(cola)", col_str.sql()) + col = SF.ceil(SF.col("cola")) + self.assertEqual("CEIL(cola)", col.sql()) + + def test_cos(self): + col_str = SF.cos("cola") + self.assertEqual("COS(cola)", col_str.sql()) + col = SF.cos(SF.col("cola")) + self.assertEqual("COS(cola)", col.sql()) + + def test_cosh(self): + col_str = SF.cosh("cola") + self.assertEqual("COSH(cola)", col_str.sql()) + col = SF.cosh(SF.col("cola")) + self.assertEqual("COSH(cola)", col.sql()) + + def test_cot(self): + col_str = SF.cot("cola") + self.assertEqual("COT(cola)", col_str.sql()) + col = SF.cot(SF.col("cola")) + self.assertEqual("COT(cola)", col.sql()) + + def test_csc(self): + col_str = SF.csc("cola") + self.assertEqual("CSC(cola)", col_str.sql()) + col = SF.csc(SF.col("cola")) + self.assertEqual("CSC(cola)", col.sql()) + + def test_exp(self): + col_str = SF.exp("cola") + self.assertEqual("EXP(cola)", col_str.sql()) + col = SF.exp(SF.col("cola")) + self.assertEqual("EXP(cola)", col.sql()) + + def test_expm1(self): + col_str = SF.expm1("cola") + self.assertEqual("EXPM1(cola)", col_str.sql()) + col = SF.expm1(SF.col("cola")) + self.assertEqual("EXPM1(cola)", col.sql()) + + def test_floor(self): + col_str = SF.floor("cola") + self.assertEqual("FLOOR(cola)", col_str.sql()) + col = SF.floor(SF.col("cola")) + self.assertEqual("FLOOR(cola)", col.sql()) + + def test_log(self): + col_str = SF.log("cola") + self.assertEqual("LN(cola)", col_str.sql()) + col = SF.log(SF.col("cola")) + self.assertEqual("LN(cola)", col.sql()) + col_arg = SF.log(10.0, "age") + self.assertEqual("LOG(10.0, age)", col_arg.sql()) + + def test_log10(self): + col_str = SF.log10("cola") + self.assertEqual("LOG10(cola)", col_str.sql()) + col = SF.log10(SF.col("cola")) + self.assertEqual("LOG10(cola)", col.sql()) + + def test_log1p(self): + col_str = SF.log1p("cola") + self.assertEqual("LOG1P(cola)", col_str.sql()) + col = SF.log1p(SF.col("cola")) + self.assertEqual("LOG1P(cola)", col.sql()) + + def test_log2(self): + col_str = SF.log2("cola") + self.assertEqual("LOG2(cola)", col_str.sql()) + col = SF.log2(SF.col("cola")) + self.assertEqual("LOG2(cola)", col.sql()) + + def test_rint(self): + col_str = SF.rint("cola") + self.assertEqual("RINT(cola)", col_str.sql()) + col = SF.rint(SF.col("cola")) + self.assertEqual("RINT(cola)", col.sql()) + + def test_sec(self): + col_str = SF.sec("cola") + self.assertEqual("SEC(cola)", col_str.sql()) + col = SF.sec(SF.col("cola")) + self.assertEqual("SEC(cola)", col.sql()) + + def test_signum(self): + col_str = SF.signum("cola") + self.assertEqual("SIGNUM(cola)", col_str.sql()) + col = SF.signum(SF.col("cola")) + self.assertEqual("SIGNUM(cola)", col.sql()) + + def test_sin(self): + col_str = SF.sin("cola") + self.assertEqual("SIN(cola)", col_str.sql()) + col = SF.sin(SF.col("cola")) + self.assertEqual("SIN(cola)", col.sql()) + + def test_sinh(self): + col_str = SF.sinh("cola") + self.assertEqual("SINH(cola)", col_str.sql()) + col = SF.sinh(SF.col("cola")) + self.assertEqual("SINH(cola)", col.sql()) + + def test_tan(self): + col_str = SF.tan("cola") + self.assertEqual("TAN(cola)", col_str.sql()) + col = SF.tan(SF.col("cola")) + self.assertEqual("TAN(cola)", col.sql()) + + def test_tanh(self): + col_str = SF.tanh("cola") + self.assertEqual("TANH(cola)", col_str.sql()) + col = SF.tanh(SF.col("cola")) + self.assertEqual("TANH(cola)", col.sql()) + + def test_degrees(self): + col_str = SF.degrees("cola") + self.assertEqual("DEGREES(cola)", col_str.sql()) + col = SF.degrees(SF.col("cola")) + self.assertEqual("DEGREES(cola)", col.sql()) + col_legacy = SF.toDegrees(SF.col("cola")) + self.assertEqual("DEGREES(cola)", col_legacy.sql()) + + def test_radians(self): + col_str = SF.radians("cola") + self.assertEqual("RADIANS(cola)", col_str.sql()) + col = SF.radians(SF.col("cola")) + self.assertEqual("RADIANS(cola)", col.sql()) + col_legacy = SF.toRadians(SF.col("cola")) + self.assertEqual("RADIANS(cola)", col_legacy.sql()) + + def test_bitwise_not(self): + col_str = SF.bitwise_not("cola") + self.assertEqual("~cola", col_str.sql()) + col = SF.bitwise_not(SF.col("cola")) + self.assertEqual("~cola", col.sql()) + col_legacy = SF.bitwiseNOT(SF.col("cola")) + self.assertEqual("~cola", col_legacy.sql()) + + def test_asc_nulls_first(self): + col_str = SF.asc_nulls_first("cola") + self.assertIsInstance(col_str.expression, exp.Ordered) + self.assertEqual("cola", col_str.sql()) + col = SF.asc_nulls_first(SF.col("cola")) + self.assertIsInstance(col.expression, exp.Ordered) + self.assertEqual("cola", col.sql()) + + def test_asc_nulls_last(self): + col_str = SF.asc_nulls_last("cola") + self.assertIsInstance(col_str.expression, exp.Ordered) + self.assertEqual("cola NULLS LAST", col_str.sql()) + col = SF.asc_nulls_last(SF.col("cola")) + self.assertIsInstance(col.expression, exp.Ordered) + self.assertEqual("cola NULLS LAST", col.sql()) + + def test_desc_nulls_first(self): + col_str = SF.desc_nulls_first("cola") + self.assertIsInstance(col_str.expression, exp.Ordered) + self.assertEqual("cola DESC NULLS FIRST", col_str.sql()) + col = SF.desc_nulls_first(SF.col("cola")) + self.assertIsInstance(col.expression, exp.Ordered) + self.assertEqual("cola DESC NULLS FIRST", col.sql()) + + def test_desc_nulls_last(self): + col_str = SF.desc_nulls_last("cola") + self.assertIsInstance(col_str.expression, exp.Ordered) + self.assertEqual("cola DESC", col_str.sql()) + col = SF.desc_nulls_last(SF.col("cola")) + self.assertIsInstance(col.expression, exp.Ordered) + self.assertEqual("cola DESC", col.sql()) + + def test_stddev(self): + col_str = SF.stddev("cola") + self.assertEqual("STDDEV(cola)", col_str.sql()) + col = SF.stddev(SF.col("cola")) + self.assertEqual("STDDEV(cola)", col.sql()) + + def test_stddev_samp(self): + col_str = SF.stddev_samp("cola") + self.assertEqual("STDDEV_SAMP(cola)", col_str.sql()) + col = SF.stddev_samp(SF.col("cola")) + self.assertEqual("STDDEV_SAMP(cola)", col.sql()) + + def test_stddev_pop(self): + col_str = SF.stddev_pop("cola") + self.assertEqual("STDDEV_POP(cola)", col_str.sql()) + col = SF.stddev_pop(SF.col("cola")) + self.assertEqual("STDDEV_POP(cola)", col.sql()) + + def test_variance(self): + col_str = SF.variance("cola") + self.assertEqual("VARIANCE(cola)", col_str.sql()) + col = SF.variance(SF.col("cola")) + self.assertEqual("VARIANCE(cola)", col.sql()) + + def test_var_samp(self): + col_str = SF.var_samp("cola") + self.assertEqual("VARIANCE(cola)", col_str.sql()) + col = SF.var_samp(SF.col("cola")) + self.assertEqual("VARIANCE(cola)", col.sql()) + + def test_var_pop(self): + col_str = SF.var_pop("cola") + self.assertEqual("VAR_POP(cola)", col_str.sql()) + col = SF.var_pop(SF.col("cola")) + self.assertEqual("VAR_POP(cola)", col.sql()) + + def test_skewness(self): + col_str = SF.skewness("cola") + self.assertEqual("SKEWNESS(cola)", col_str.sql()) + col = SF.skewness(SF.col("cola")) + self.assertEqual("SKEWNESS(cola)", col.sql()) + + def test_kurtosis(self): + col_str = SF.kurtosis("cola") + self.assertEqual("KURTOSIS(cola)", col_str.sql()) + col = SF.kurtosis(SF.col("cola")) + self.assertEqual("KURTOSIS(cola)", col.sql()) + + def test_collect_list(self): + col_str = SF.collect_list("cola") + self.assertEqual("COLLECT_LIST(cola)", col_str.sql()) + col = SF.collect_list(SF.col("cola")) + self.assertEqual("COLLECT_LIST(cola)", col.sql()) + + def test_collect_set(self): + col_str = SF.collect_set("cola") + self.assertEqual("COLLECT_SET(cola)", col_str.sql()) + col = SF.collect_set(SF.col("cola")) + self.assertEqual("COLLECT_SET(cola)", col.sql()) + + def test_hypot(self): + col_str = SF.hypot("cola", "colb") + self.assertEqual("HYPOT(cola, colb)", col_str.sql()) + col = SF.hypot(SF.col("cola"), SF.col("colb")) + self.assertEqual("HYPOT(cola, colb)", col.sql()) + col_float = SF.hypot(10.10, "colb") + self.assertEqual("HYPOT(10.1, colb)", col_float.sql()) + col_float2 = SF.hypot("cola", 10.10) + self.assertEqual("HYPOT(cola, 10.1)", col_float2.sql()) + + def test_pow(self): + col_str = SF.pow("cola", "colb") + self.assertEqual("POW(cola, colb)", col_str.sql()) + col = SF.pow(SF.col("cola"), SF.col("colb")) + self.assertEqual("POW(cola, colb)", col.sql()) + col_float = SF.pow(10.10, "colb") + self.assertEqual("POW(10.1, colb)", col_float.sql()) + col_float2 = SF.pow("cola", 10.10) + self.assertEqual("POW(cola, 10.1)", col_float2.sql()) + + def test_row_number(self): + col_str = SF.row_number() + self.assertEqual("ROW_NUMBER()", col_str.sql()) + col = SF.row_number() + self.assertEqual("ROW_NUMBER()", col.sql()) + + def test_dense_rank(self): + col_str = SF.dense_rank() + self.assertEqual("DENSE_RANK()", col_str.sql()) + col = SF.dense_rank() + self.assertEqual("DENSE_RANK()", col.sql()) + + def test_rank(self): + col_str = SF.rank() + self.assertEqual("RANK()", col_str.sql()) + col = SF.rank() + self.assertEqual("RANK()", col.sql()) + + def test_cume_dist(self): + col_str = SF.cume_dist() + self.assertEqual("CUME_DIST()", col_str.sql()) + col = SF.cume_dist() + self.assertEqual("CUME_DIST()", col.sql()) + + def test_percent_rank(self): + col_str = SF.percent_rank() + self.assertEqual("PERCENT_RANK()", col_str.sql()) + col = SF.percent_rank() + self.assertEqual("PERCENT_RANK()", col.sql()) + + def test_approx_count_distinct(self): + col_str = SF.approx_count_distinct("cola") + self.assertEqual("APPROX_COUNT_DISTINCT(cola)", col_str.sql()) + col_str_with_accuracy = SF.approx_count_distinct("cola", 0.05) + self.assertEqual("APPROX_COUNT_DISTINCT(cola, 0.05)", col_str_with_accuracy.sql()) + col = SF.approx_count_distinct(SF.col("cola")) + self.assertEqual("APPROX_COUNT_DISTINCT(cola)", col.sql()) + col_with_accuracy = SF.approx_count_distinct(SF.col("cola"), 0.05) + self.assertEqual("APPROX_COUNT_DISTINCT(cola, 0.05)", col_with_accuracy.sql()) + col_legacy = SF.approxCountDistinct(SF.col("cola")) + self.assertEqual("APPROX_COUNT_DISTINCT(cola)", col_legacy.sql()) + + def test_coalesce(self): + col_str = SF.coalesce("cola", "colb", "colc") + self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql()) + col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc")) + self.assertEqual("COALESCE(cola, colb, colc)", col.sql()) + + def test_corr(self): + col_str = SF.corr("cola", "colb") + self.assertEqual("CORR(cola, colb)", col_str.sql()) + col = SF.corr(SF.col("cola"), "colb") + self.assertEqual("CORR(cola, colb)", col.sql()) + + def test_covar_pop(self): + col_str = SF.covar_pop("cola", "colb") + self.assertEqual("COVAR_POP(cola, colb)", col_str.sql()) + col = SF.covar_pop(SF.col("cola"), "colb") + self.assertEqual("COVAR_POP(cola, colb)", col.sql()) + + def test_covar_samp(self): + col_str = SF.covar_samp("cola", "colb") + self.assertEqual("COVAR_SAMP(cola, colb)", col_str.sql()) + col = SF.covar_samp(SF.col("cola"), "colb") + self.assertEqual("COVAR_SAMP(cola, colb)", col.sql()) + + def test_count_distinct(self): + col_str = SF.count_distinct("cola") + self.assertEqual("COUNT(DISTINCT cola)", col_str.sql()) + col = SF.count_distinct(SF.col("cola")) + self.assertEqual("COUNT(DISTINCT cola)", col.sql()) + col_legacy = SF.countDistinct(SF.col("cola")) + self.assertEqual("COUNT(DISTINCT cola)", col_legacy.sql()) + col_multiple = SF.count_distinct(SF.col("cola"), SF.col("colb")) + self.assertEqual("COUNT(DISTINCT cola, colb)", col_multiple.sql()) + + def test_first(self): + col_str = SF.first("cola") + self.assertEqual("FIRST(cola)", col_str.sql()) + col = SF.first(SF.col("cola")) + self.assertEqual("FIRST(cola)", col.sql()) + ignore_nulls = SF.first("cola", True) + self.assertEqual("FIRST(cola, TRUE)", ignore_nulls.sql()) + + def test_grouping_id(self): + col_str = SF.grouping_id("cola", "colb") + self.assertEqual("GROUPING_ID(cola, colb)", col_str.sql()) + col = SF.grouping_id(SF.col("cola"), SF.col("colb")) + self.assertEqual("GROUPING_ID(cola, colb)", col.sql()) + col_grouping_no_arg = SF.grouping_id() + self.assertEqual("GROUPING_ID()", col_grouping_no_arg.sql()) + col_grouping_single_arg = SF.grouping_id("cola") + self.assertEqual("GROUPING_ID(cola)", col_grouping_single_arg.sql()) + + def test_input_file_name(self): + col = SF.input_file_name() + self.assertEqual("INPUT_FILE_NAME()", col.sql()) + + def test_isnan(self): + col_str = SF.isnan("cola") + self.assertEqual("ISNAN(cola)", col_str.sql()) + col = SF.isnan(SF.col("cola")) + self.assertEqual("ISNAN(cola)", col.sql()) + + def test_isnull(self): + col_str = SF.isnull("cola") + self.assertEqual("ISNULL(cola)", col_str.sql()) + col = SF.isnull(SF.col("cola")) + self.assertEqual("ISNULL(cola)", col.sql()) + + def test_last(self): + col_str = SF.last("cola") + self.assertEqual("LAST(cola)", col_str.sql()) + col = SF.last(SF.col("cola")) + self.assertEqual("LAST(cola)", col.sql()) + ignore_nulls = SF.last("cola", True) + self.assertEqual("LAST(cola, TRUE)", ignore_nulls.sql()) + + def test_monotonically_increasing_id(self): + col = SF.monotonically_increasing_id() + self.assertEqual("MONOTONICALLY_INCREASING_ID()", col.sql()) + + def test_nanvl(self): + col_str = SF.nanvl("cola", "colb") + self.assertEqual("NANVL(cola, colb)", col_str.sql()) + col = SF.nanvl(SF.col("cola"), SF.col("colb")) + self.assertEqual("NANVL(cola, colb)", col.sql()) + + def test_percentile_approx(self): + col_str = SF.percentile_approx("cola", [0.5, 0.4, 0.1]) + self.assertEqual("PERCENTILE_APPROX(cola, ARRAY(0.5, 0.4, 0.1))", col_str.sql()) + col = SF.percentile_approx(SF.col("cola"), [0.5, 0.4, 0.1]) + self.assertEqual("PERCENTILE_APPROX(cola, ARRAY(0.5, 0.4, 0.1))", col.sql()) + col_accuracy = SF.percentile_approx("cola", 0.1, 100) + self.assertEqual("PERCENTILE_APPROX(cola, 0.1, 100)", col_accuracy.sql()) + + def test_rand(self): + col_str = SF.rand(SF.lit(0)) + self.assertEqual("RAND(0)", col_str.sql()) + col = SF.rand(SF.lit(0)) + self.assertEqual("RAND(0)", col.sql()) + no_col = SF.rand() + self.assertEqual("RAND()", no_col.sql()) + + def test_randn(self): + col_str = SF.randn(0) + self.assertEqual("RANDN(0)", col_str.sql()) + col = SF.randn(0) + self.assertEqual("RANDN(0)", col.sql()) + no_col = SF.randn() + self.assertEqual("RANDN()", no_col.sql()) + + def test_round(self): + col_str = SF.round("cola", 0) + self.assertEqual("ROUND(cola, 0)", col_str.sql()) + col = SF.round(SF.col("cola"), 0) + self.assertEqual("ROUND(cola, 0)", col.sql()) + col_no_scale = SF.round("cola") + self.assertEqual("ROUND(cola)", col_no_scale.sql()) + + def test_bround(self): + col_str = SF.bround("cola", 0) + self.assertEqual("BROUND(cola, 0)", col_str.sql()) + col = SF.bround(SF.col("cola"), 0) + self.assertEqual("BROUND(cola, 0)", col.sql()) + col_no_scale = SF.bround("cola") + self.assertEqual("BROUND(cola)", col_no_scale.sql()) + + def test_shiftleft(self): + col_str = SF.shiftleft("cola", 1) + self.assertEqual("SHIFTLEFT(cola, 1)", col_str.sql()) + col = SF.shiftleft(SF.col("cola"), 1) + self.assertEqual("SHIFTLEFT(cola, 1)", col.sql()) + col_legacy = SF.shiftLeft(SF.col("cola"), 1) + self.assertEqual("SHIFTLEFT(cola, 1)", col_legacy.sql()) + + def test_shiftright(self): + col_str = SF.shiftright("cola", 1) + self.assertEqual("SHIFTRIGHT(cola, 1)", col_str.sql()) + col = SF.shiftright(SF.col("cola"), 1) + self.assertEqual("SHIFTRIGHT(cola, 1)", col.sql()) + col_legacy = SF.shiftRight(SF.col("cola"), 1) + self.assertEqual("SHIFTRIGHT(cola, 1)", col_legacy.sql()) + + def test_shiftrightunsigned(self): + col_str = SF.shiftrightunsigned("cola", 1) + self.assertEqual("SHIFTRIGHTUNSIGNED(cola, 1)", col_str.sql()) + col = SF.shiftrightunsigned(SF.col("cola"), 1) + self.assertEqual("SHIFTRIGHTUNSIGNED(cola, 1)", col.sql()) + col_legacy = SF.shiftRightUnsigned(SF.col("cola"), 1) + self.assertEqual("SHIFTRIGHTUNSIGNED(cola, 1)", col_legacy.sql()) + + def test_expr(self): + col_str = SF.expr("LENGTH(name)") + self.assertEqual("LENGTH(name)", col_str.sql()) + + def test_struct(self): + col_str = SF.struct("cola", "colb", "colc") + self.assertEqual("STRUCT(cola, colb, colc)", col_str.sql()) + col = SF.struct(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("STRUCT(cola, colb, colc)", col.sql()) + col_single = SF.struct("cola") + self.assertEqual("STRUCT(cola)", col_single.sql()) + col_list = SF.struct(["cola", "colb", "colc"]) + self.assertEqual("STRUCT(cola, colb, colc)", col_list.sql()) + + def test_greatest(self): + single_str = SF.greatest("cola") + self.assertEqual("GREATEST(cola)", single_str.sql()) + single_col = SF.greatest(SF.col("cola")) + self.assertEqual("GREATEST(cola)", single_col.sql()) + multiple_mix = SF.greatest("col1", "col2", SF.col("col3"), SF.col("col4")) + self.assertEqual("GREATEST(col1, col2, col3, col4)", multiple_mix.sql()) + + def test_least(self): + single_str = SF.least("cola") + self.assertEqual("LEAST(cola)", single_str.sql()) + single_col = SF.least(SF.col("cola")) + self.assertEqual("LEAST(cola)", single_col.sql()) + multiple_mix = SF.least("col1", "col2", SF.col("col3"), SF.col("col4")) + self.assertEqual("LEAST(col1, col2, col3, col4)", multiple_mix.sql()) + + def test_when(self): + col_simple = SF.when(SF.col("cola") == 2, 1) + self.assertEqual("CASE WHEN cola = 2 THEN 1 END", col_simple.sql()) + col_complex = SF.when(SF.col("cola") == 2, SF.col("colb") + 2) + self.assertEqual("CASE WHEN cola = 2 THEN colb + 2 END", col_complex.sql()) + + def test_conv(self): + col_str = SF.conv("cola", 2, 16) + self.assertEqual("CONV(cola, 2, 16)", col_str.sql()) + col = SF.conv(SF.col("cola"), 2, 16) + self.assertEqual("CONV(cola, 2, 16)", col.sql()) + + def test_factorial(self): + col_str = SF.factorial("cola") + self.assertEqual("FACTORIAL(cola)", col_str.sql()) + col = SF.factorial(SF.col("cola")) + self.assertEqual("FACTORIAL(cola)", col.sql()) + + def test_lag(self): + col_str = SF.lag("cola", 3, "colc") + self.assertEqual("LAG(cola, 3, colc)", col_str.sql()) + col = SF.lag(SF.col("cola"), 3, "colc") + self.assertEqual("LAG(cola, 3, colc)", col.sql()) + col_no_default = SF.lag("cola", 3) + self.assertEqual("LAG(cola, 3)", col_no_default.sql()) + col_no_offset = SF.lag("cola") + self.assertEqual("LAG(cola)", col_no_offset.sql()) + + def test_lead(self): + col_str = SF.lead("cola", 3, "colc") + self.assertEqual("LEAD(cola, 3, colc)", col_str.sql()) + col = SF.lead(SF.col("cola"), 3, "colc") + self.assertEqual("LEAD(cola, 3, colc)", col.sql()) + col_no_default = SF.lead("cola", 3) + self.assertEqual("LEAD(cola, 3)", col_no_default.sql()) + col_no_offset = SF.lead("cola") + self.assertEqual("LEAD(cola)", col_no_offset.sql()) + + def test_nth_value(self): + col_str = SF.nth_value("cola", 3) + self.assertEqual("NTH_VALUE(cola, 3)", col_str.sql()) + col = SF.nth_value(SF.col("cola"), 3) + self.assertEqual("NTH_VALUE(cola, 3)", col.sql()) + col_no_offset = SF.nth_value("cola") + self.assertEqual("NTH_VALUE(cola)", col_no_offset.sql()) + with self.assertRaises(NotImplementedError): + SF.nth_value("cola", ignoreNulls=True) + + def test_ntile(self): + col = SF.ntile(2) + self.assertEqual("NTILE(2)", col.sql()) + + def test_current_date(self): + col = SF.current_date() + self.assertEqual("CURRENT_DATE", col.sql()) + + def test_current_timestamp(self): + col = SF.current_timestamp() + self.assertEqual("CURRENT_TIMESTAMP()", col.sql()) + + def test_date_format(self): + col_str = SF.date_format("cola", "MM/dd/yyy") + self.assertEqual("DATE_FORMAT(cola, 'MM/dd/yyy')", col_str.sql()) + col = SF.date_format(SF.col("cola"), "MM/dd/yyy") + self.assertEqual("DATE_FORMAT(cola, 'MM/dd/yyy')", col.sql()) + + def test_year(self): + col_str = SF.year("cola") + self.assertEqual("YEAR(cola)", col_str.sql()) + col = SF.year(SF.col("cola")) + self.assertEqual("YEAR(cola)", col.sql()) + + def test_quarter(self): + col_str = SF.quarter("cola") + self.assertEqual("QUARTER(cola)", col_str.sql()) + col = SF.quarter(SF.col("cola")) + self.assertEqual("QUARTER(cola)", col.sql()) + + def test_month(self): + col_str = SF.month("cola") + self.assertEqual("MONTH(cola)", col_str.sql()) + col = SF.month(SF.col("cola")) + self.assertEqual("MONTH(cola)", col.sql()) + + def test_dayofweek(self): + col_str = SF.dayofweek("cola") + self.assertEqual("DAYOFWEEK(cola)", col_str.sql()) + col = SF.dayofweek(SF.col("cola")) + self.assertEqual("DAYOFWEEK(cola)", col.sql()) + + def test_dayofmonth(self): + col_str = SF.dayofmonth("cola") + self.assertEqual("DAYOFMONTH(cola)", col_str.sql()) + col = SF.dayofmonth(SF.col("cola")) + self.assertEqual("DAYOFMONTH(cola)", col.sql()) + + def test_dayofyear(self): + col_str = SF.dayofyear("cola") + self.assertEqual("DAYOFYEAR(cola)", col_str.sql()) + col = SF.dayofyear(SF.col("cola")) + self.assertEqual("DAYOFYEAR(cola)", col.sql()) + + def test_hour(self): + col_str = SF.hour("cola") + self.assertEqual("HOUR(cola)", col_str.sql()) + col = SF.hour(SF.col("cola")) + self.assertEqual("HOUR(cola)", col.sql()) + + def test_minute(self): + col_str = SF.minute("cola") + self.assertEqual("MINUTE(cola)", col_str.sql()) + col = SF.minute(SF.col("cola")) + self.assertEqual("MINUTE(cola)", col.sql()) + + def test_second(self): + col_str = SF.second("cola") + self.assertEqual("SECOND(cola)", col_str.sql()) + col = SF.second(SF.col("cola")) + self.assertEqual("SECOND(cola)", col.sql()) + + def test_weekofyear(self): + col_str = SF.weekofyear("cola") + self.assertEqual("WEEKOFYEAR(cola)", col_str.sql()) + col = SF.weekofyear(SF.col("cola")) + self.assertEqual("WEEKOFYEAR(cola)", col.sql()) + + def test_make_date(self): + col_str = SF.make_date("cola", "colb", "colc") + self.assertEqual("MAKE_DATE(cola, colb, colc)", col_str.sql()) + col = SF.make_date(SF.col("cola"), SF.col("colb"), "colc") + self.assertEqual("MAKE_DATE(cola, colb, colc)", col.sql()) + + def test_date_add(self): + col_str = SF.date_add("cola", 2) + self.assertEqual("DATE_ADD(cola, 2)", col_str.sql()) + col = SF.date_add(SF.col("cola"), 2) + self.assertEqual("DATE_ADD(cola, 2)", col.sql()) + col_col_for_add = SF.date_add("cola", "colb") + self.assertEqual("DATE_ADD(cola, colb)", col_col_for_add.sql()) + + def test_date_sub(self): + col_str = SF.date_sub("cola", 2) + self.assertEqual("DATE_SUB(cola, 2)", col_str.sql()) + col = SF.date_sub(SF.col("cola"), 2) + self.assertEqual("DATE_SUB(cola, 2)", col.sql()) + col_col_for_add = SF.date_sub("cola", "colb") + self.assertEqual("DATE_SUB(cola, colb)", col_col_for_add.sql()) + + def test_date_diff(self): + col_str = SF.date_diff("cola", "colb") + self.assertEqual("DATEDIFF(cola, colb)", col_str.sql()) + col = SF.date_diff(SF.col("cola"), SF.col("colb")) + self.assertEqual("DATEDIFF(cola, colb)", col.sql()) + + def test_add_months(self): + col_str = SF.add_months("cola", 2) + self.assertEqual("ADD_MONTHS(cola, 2)", col_str.sql()) + col = SF.add_months(SF.col("cola"), 2) + self.assertEqual("ADD_MONTHS(cola, 2)", col.sql()) + col_col_for_add = SF.add_months("cola", "colb") + self.assertEqual("ADD_MONTHS(cola, colb)", col_col_for_add.sql()) + + def test_months_between(self): + col_str = SF.months_between("cola", "colb") + self.assertEqual("MONTHS_BETWEEN(cola, colb)", col_str.sql()) + col = SF.months_between(SF.col("cola"), SF.col("colb")) + self.assertEqual("MONTHS_BETWEEN(cola, colb)", col.sql()) + col_round_off = SF.months_between("cola", "colb", True) + self.assertEqual("MONTHS_BETWEEN(cola, colb, TRUE)", col_round_off.sql()) + + def test_to_date(self): + col_str = SF.to_date("cola") + self.assertEqual("TO_DATE(cola)", col_str.sql()) + col = SF.to_date(SF.col("cola")) + self.assertEqual("TO_DATE(cola)", col.sql()) + col_with_format = SF.to_date("cola", "yyyy-MM-dd") + self.assertEqual("TO_DATE(cola, 'yyyy-MM-dd')", col_with_format.sql()) + + def test_to_timestamp(self): + col_str = SF.to_timestamp("cola") + self.assertEqual("TO_TIMESTAMP(cola)", col_str.sql()) + col = SF.to_timestamp(SF.col("cola")) + self.assertEqual("TO_TIMESTAMP(cola)", col.sql()) + col_with_format = SF.to_timestamp("cola", "yyyy-MM-dd") + self.assertEqual("TO_TIMESTAMP(cola, 'yyyy-MM-dd')", col_with_format.sql()) + + def test_trunc(self): + col_str = SF.trunc("cola", "year") + self.assertEqual("TRUNC(cola, 'year')", col_str.sql()) + col = SF.trunc(SF.col("cola"), "year") + self.assertEqual("TRUNC(cola, 'year')", col.sql()) + + def test_date_trunc(self): + col_str = SF.date_trunc("year", "cola") + self.assertEqual("DATE_TRUNC('year', cola)", col_str.sql()) + col = SF.date_trunc("year", SF.col("cola")) + self.assertEqual("DATE_TRUNC('year', cola)", col.sql()) + + def test_next_day(self): + col_str = SF.next_day("cola", "Mon") + self.assertEqual("NEXT_DAY(cola, 'Mon')", col_str.sql()) + col = SF.next_day(SF.col("cola"), "Mon") + self.assertEqual("NEXT_DAY(cola, 'Mon')", col.sql()) + + def test_last_day(self): + col_str = SF.last_day("cola") + self.assertEqual("LAST_DAY(cola)", col_str.sql()) + col = SF.last_day(SF.col("cola")) + self.assertEqual("LAST_DAY(cola)", col.sql()) + + def test_from_unixtime(self): + col_str = SF.from_unixtime("cola") + self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql()) + col = SF.from_unixtime(SF.col("cola")) + self.assertEqual("FROM_UNIXTIME(cola)", col.sql()) + col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm:ss") + self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + + def test_unix_timestamp(self): + col_str = SF.unix_timestamp("cola") + self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql()) + col = SF.unix_timestamp(SF.col("cola")) + self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql()) + col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm:ss") + self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + col_current = SF.unix_timestamp() + self.assertEqual("UNIX_TIMESTAMP()", col_current.sql()) + + def test_from_utc_timestamp(self): + col_str = SF.from_utc_timestamp("cola", "PST") + self.assertEqual("FROM_UTC_TIMESTAMP(cola, 'PST')", col_str.sql()) + col = SF.from_utc_timestamp(SF.col("cola"), "PST") + self.assertEqual("FROM_UTC_TIMESTAMP(cola, 'PST')", col.sql()) + col_col = SF.from_utc_timestamp("cola", SF.col("colb")) + self.assertEqual("FROM_UTC_TIMESTAMP(cola, colb)", col_col.sql()) + + def test_to_utc_timestamp(self): + col_str = SF.to_utc_timestamp("cola", "PST") + self.assertEqual("TO_UTC_TIMESTAMP(cola, 'PST')", col_str.sql()) + col = SF.to_utc_timestamp(SF.col("cola"), "PST") + self.assertEqual("TO_UTC_TIMESTAMP(cola, 'PST')", col.sql()) + col_col = SF.to_utc_timestamp("cola", SF.col("colb")) + self.assertEqual("TO_UTC_TIMESTAMP(cola, colb)", col_col.sql()) + + def test_timestamp_seconds(self): + col_str = SF.timestamp_seconds("cola") + self.assertEqual("TIMESTAMP_SECONDS(cola)", col_str.sql()) + col = SF.timestamp_seconds(SF.col("cola")) + self.assertEqual("TIMESTAMP_SECONDS(cola)", col.sql()) + + def test_window(self): + col_str = SF.window("cola", "10 minutes") + self.assertEqual("WINDOW(cola, '10 minutes')", col_str.sql()) + col = SF.window(SF.col("cola"), "10 minutes") + self.assertEqual("WINDOW(cola, '10 minutes')", col.sql()) + col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds") + self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()) + col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds") + self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()) + col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds") + self.assertEqual( + "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql() + ) + + def test_session_window(self): + col_str = SF.session_window("cola", "5 seconds") + self.assertEqual("SESSION_WINDOW(cola, '5 seconds')", col_str.sql()) + col = SF.session_window(SF.col("cola"), SF.lit("5 seconds")) + self.assertEqual("SESSION_WINDOW(cola, '5 seconds')", col.sql()) + + def test_crc32(self): + col_str = SF.crc32("Spark") + self.assertEqual("CRC32('Spark')", col_str.sql()) + col = SF.crc32(SF.col("cola")) + self.assertEqual("CRC32(cola)", col.sql()) + + def test_md5(self): + col_str = SF.md5("Spark") + self.assertEqual("MD5('Spark')", col_str.sql()) + col = SF.md5(SF.col("cola")) + self.assertEqual("MD5(cola)", col.sql()) + + def test_sha1(self): + col_str = SF.sha1("Spark") + self.assertEqual("SHA1('Spark')", col_str.sql()) + col = SF.sha1(SF.col("cola")) + self.assertEqual("SHA1(cola)", col.sql()) + + def test_sha2(self): + col_str = SF.sha2("Spark", 256) + self.assertEqual("SHA2('Spark', 256)", col_str.sql()) + col = SF.sha2(SF.col("cola"), 256) + self.assertEqual("SHA2(cola, 256)", col.sql()) + + def test_hash(self): + col_str = SF.hash("cola", "colb", "colc") + self.assertEqual("HASH(cola, colb, colc)", col_str.sql()) + col = SF.hash(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("HASH(cola, colb, colc)", col.sql()) + + def test_xxhash64(self): + col_str = SF.xxhash64("cola", "colb", "colc") + self.assertEqual("XXHASH64(cola, colb, colc)", col_str.sql()) + col = SF.xxhash64(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("XXHASH64(cola, colb, colc)", col.sql()) + + def test_assert_true(self): + col = SF.assert_true(SF.col("cola") < SF.col("colb")) + self.assertEqual("ASSERT_TRUE(cola < colb)", col.sql()) + col_error_msg_col = SF.assert_true(SF.col("cola") < SF.col("colb"), SF.col("colc")) + self.assertEqual("ASSERT_TRUE(cola < colb, colc)", col_error_msg_col.sql()) + col_error_msg_lit = SF.assert_true(SF.col("cola") < SF.col("colb"), "error") + self.assertEqual("ASSERT_TRUE(cola < colb, 'error')", col_error_msg_lit.sql()) + + def test_raise_error(self): + col_str = SF.raise_error("custom error message") + self.assertEqual("RAISE_ERROR('custom error message')", col_str.sql()) + col = SF.raise_error(SF.col("cola")) + self.assertEqual("RAISE_ERROR(cola)", col.sql()) + + def test_upper(self): + col_str = SF.upper("cola") + self.assertEqual("UPPER(cola)", col_str.sql()) + col = SF.upper(SF.col("cola")) + self.assertEqual("UPPER(cola)", col.sql()) + + def test_lower(self): + col_str = SF.lower("cola") + self.assertEqual("LOWER(cola)", col_str.sql()) + col = SF.lower(SF.col("cola")) + self.assertEqual("LOWER(cola)", col.sql()) + + def test_ascii(self): + col_str = SF.ascii(SF.lit(2)) + self.assertEqual("ASCII(2)", col_str.sql()) + col = SF.ascii(SF.col("cola")) + self.assertEqual("ASCII(cola)", col.sql()) + + def test_base64(self): + col_str = SF.base64(SF.lit(2)) + self.assertEqual("BASE64(2)", col_str.sql()) + col = SF.base64(SF.col("cola")) + self.assertEqual("BASE64(cola)", col.sql()) + + def test_unbase64(self): + col_str = SF.unbase64(SF.lit(2)) + self.assertEqual("UNBASE64(2)", col_str.sql()) + col = SF.unbase64(SF.col("cola")) + self.assertEqual("UNBASE64(cola)", col.sql()) + + def test_ltrim(self): + col_str = SF.ltrim(SF.lit("Spark")) + self.assertEqual("LTRIM('Spark')", col_str.sql()) + col = SF.ltrim(SF.col("cola")) + self.assertEqual("LTRIM(cola)", col.sql()) + + def test_rtrim(self): + col_str = SF.rtrim(SF.lit("Spark")) + self.assertEqual("RTRIM('Spark')", col_str.sql()) + col = SF.rtrim(SF.col("cola")) + self.assertEqual("RTRIM(cola)", col.sql()) + + def test_trim(self): + col_str = SF.trim(SF.lit("Spark")) + self.assertEqual("TRIM('Spark')", col_str.sql()) + col = SF.trim(SF.col("cola")) + self.assertEqual("TRIM(cola)", col.sql()) + + def test_concat_ws(self): + col_str = SF.concat_ws("-", "cola", "colb") + self.assertEqual("CONCAT_WS('-', cola, colb)", col_str.sql()) + col = SF.concat_ws("-", SF.col("cola"), SF.col("colb")) + self.assertEqual("CONCAT_WS('-', cola, colb)", col.sql()) + + def test_decode(self): + col_str = SF.decode("cola", "US-ASCII") + self.assertEqual("DECODE(cola, 'US-ASCII')", col_str.sql()) + col = SF.decode(SF.col("cola"), "US-ASCII") + self.assertEqual("DECODE(cola, 'US-ASCII')", col.sql()) + + def test_encode(self): + col_str = SF.encode("cola", "US-ASCII") + self.assertEqual("ENCODE(cola, 'US-ASCII')", col_str.sql()) + col = SF.encode(SF.col("cola"), "US-ASCII") + self.assertEqual("ENCODE(cola, 'US-ASCII')", col.sql()) + + def test_format_number(self): + col_str = SF.format_number("cola", 4) + self.assertEqual("FORMAT_NUMBER(cola, 4)", col_str.sql()) + col = SF.format_number(SF.col("cola"), 4) + self.assertEqual("FORMAT_NUMBER(cola, 4)", col.sql()) + + def test_format_string(self): + col_str = SF.format_string("%d %s", "cola", "colb", "colc") + self.assertEqual("FORMAT_STRING('%d %s', cola, colb, colc)", col_str.sql()) + col = SF.format_string("%d %s", SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("FORMAT_STRING('%d %s', cola, colb, colc)", col.sql()) + + def test_instr(self): + col_str = SF.instr("cola", "test") + self.assertEqual("INSTR(cola, 'test')", col_str.sql()) + col = SF.instr(SF.col("cola"), "test") + self.assertEqual("INSTR(cola, 'test')", col.sql()) + + def test_overlay(self): + col_str = SF.overlay("cola", "colb", 3, 7) + self.assertEqual("OVERLAY(cola, colb, 3, 7)", col_str.sql()) + col = SF.overlay(SF.col("cola"), SF.col("colb"), SF.lit(3), SF.lit(7)) + self.assertEqual("OVERLAY(cola, colb, 3, 7)", col.sql()) + col_no_length = SF.overlay("cola", "colb", 3) + self.assertEqual("OVERLAY(cola, colb, 3)", col_no_length.sql()) + + def test_sentences(self): + col_str = SF.sentences("cola", SF.lit("en"), SF.lit("US")) + self.assertEqual("SENTENCES(cola, 'en', 'US')", col_str.sql()) + col = SF.sentences(SF.col("cola"), SF.lit("en"), SF.lit("US")) + self.assertEqual("SENTENCES(cola, 'en', 'US')", col.sql()) + col_no_country = SF.sentences("cola", SF.lit("en")) + self.assertEqual("SENTENCES(cola, 'en')", col_no_country.sql()) + col_no_lang = SF.sentences(SF.col("cola"), country=SF.lit("US")) + self.assertEqual("SENTENCES(cola, 'en', 'US')", col_no_lang.sql()) + col_defaults = SF.sentences(SF.col("cola")) + self.assertEqual("SENTENCES(cola)", col_defaults.sql()) + + def test_substring(self): + col_str = SF.substring("cola", 2, 3) + self.assertEqual("SUBSTRING(cola, 2, 3)", col_str.sql()) + col = SF.substring(SF.col("cola"), 2, 3) + self.assertEqual("SUBSTRING(cola, 2, 3)", col.sql()) + + def test_substring_index(self): + col_str = SF.substring_index("cola", ".", 2) + self.assertEqual("SUBSTRING_INDEX(cola, '.', 2)", col_str.sql()) + col = SF.substring_index(SF.col("cola"), ".", 2) + self.assertEqual("SUBSTRING_INDEX(cola, '.', 2)", col.sql()) + + def test_levenshtein(self): + col_str = SF.levenshtein("cola", "colb") + self.assertEqual("LEVENSHTEIN(cola, colb)", col_str.sql()) + col = SF.levenshtein(SF.col("cola"), SF.col("colb")) + self.assertEqual("LEVENSHTEIN(cola, colb)", col.sql()) + + def test_locate(self): + col_str = SF.locate("test", "cola", 3) + self.assertEqual("LOCATE('test', cola, 3)", col_str.sql()) + col = SF.locate("test", SF.col("cola"), 3) + self.assertEqual("LOCATE('test', cola, 3)", col.sql()) + col_no_pos = SF.locate("test", "cola") + self.assertEqual("LOCATE('test', cola)", col_no_pos.sql()) + + def test_lpad(self): + col_str = SF.lpad("cola", 3, "#") + self.assertEqual("LPAD(cola, 3, '#')", col_str.sql()) + col = SF.lpad(SF.col("cola"), 3, "#") + self.assertEqual("LPAD(cola, 3, '#')", col.sql()) + + def test_rpad(self): + col_str = SF.rpad("cola", 3, "#") + self.assertEqual("RPAD(cola, 3, '#')", col_str.sql()) + col = SF.rpad(SF.col("cola"), 3, "#") + self.assertEqual("RPAD(cola, 3, '#')", col.sql()) + + def test_repeat(self): + col_str = SF.repeat("cola", 3) + self.assertEqual("REPEAT(cola, 3)", col_str.sql()) + col = SF.repeat(SF.col("cola"), 3) + self.assertEqual("REPEAT(cola, 3)", col.sql()) + + def test_split(self): + col_str = SF.split("cola", "[ABC]", 3) + self.assertEqual("SPLIT(cola, '[ABC]', 3)", col_str.sql()) + col = SF.split(SF.col("cola"), "[ABC]", 3) + self.assertEqual("SPLIT(cola, '[ABC]', 3)", col.sql()) + col_no_limit = SF.split("cola", "[ABC]") + self.assertEqual("SPLIT(cola, '[ABC]')", col_no_limit.sql()) + + def test_regexp_extract(self): + col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col_str.sql()) + col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col.sql()) + col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)") + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)')", col_no_idx.sql()) + + def test_regexp_replace(self): + col_str = SF.regexp_replace("cola", r"(\d+)", "--") + self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col_str.sql()) + col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--") + self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col.sql()) + + def test_initcap(self): + col_str = SF.initcap("cola") + self.assertEqual("INITCAP(cola)", col_str.sql()) + col = SF.initcap(SF.col("cola")) + self.assertEqual("INITCAP(cola)", col.sql()) + + def test_soundex(self): + col_str = SF.soundex("cola") + self.assertEqual("SOUNDEX(cola)", col_str.sql()) + col = SF.soundex(SF.col("cola")) + self.assertEqual("SOUNDEX(cola)", col.sql()) + + def test_bin(self): + col_str = SF.bin("cola") + self.assertEqual("BIN(cola)", col_str.sql()) + col = SF.bin(SF.col("cola")) + self.assertEqual("BIN(cola)", col.sql()) + + def test_hex(self): + col_str = SF.hex("cola") + self.assertEqual("HEX(cola)", col_str.sql()) + col = SF.hex(SF.col("cola")) + self.assertEqual("HEX(cola)", col.sql()) + + def test_unhex(self): + col_str = SF.unhex("cola") + self.assertEqual("UNHEX(cola)", col_str.sql()) + col = SF.unhex(SF.col("cola")) + self.assertEqual("UNHEX(cola)", col.sql()) + + def test_length(self): + col_str = SF.length("cola") + self.assertEqual("LENGTH(cola)", col_str.sql()) + col = SF.length(SF.col("cola")) + self.assertEqual("LENGTH(cola)", col.sql()) + + def test_octet_length(self): + col_str = SF.octet_length("cola") + self.assertEqual("OCTET_LENGTH(cola)", col_str.sql()) + col = SF.octet_length(SF.col("cola")) + self.assertEqual("OCTET_LENGTH(cola)", col.sql()) + + def test_bit_length(self): + col_str = SF.bit_length("cola") + self.assertEqual("BIT_LENGTH(cola)", col_str.sql()) + col = SF.bit_length(SF.col("cola")) + self.assertEqual("BIT_LENGTH(cola)", col.sql()) + + def test_translate(self): + col_str = SF.translate("cola", "abc", "xyz") + self.assertEqual("TRANSLATE(cola, 'abc', 'xyz')", col_str.sql()) + col = SF.translate(SF.col("cola"), "abc", "xyz") + self.assertEqual("TRANSLATE(cola, 'abc', 'xyz')", col.sql()) + + def test_array(self): + col_str = SF.array("cola", "colb") + self.assertEqual("ARRAY(cola, colb)", col_str.sql()) + col = SF.array(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY(cola, colb)", col.sql()) + col_array = SF.array(["cola", "colb"]) + self.assertEqual("ARRAY(cola, colb)", col_array.sql()) + + def test_create_map(self): + col_str = SF.create_map("keya", "valuea", "keyb", "valueb") + self.assertEqual("MAP(keya, valuea, keyb, valueb)", col_str.sql()) + col = SF.create_map(SF.col("keya"), SF.col("valuea"), SF.col("keyb"), SF.col("valueb")) + self.assertEqual("MAP(keya, valuea, keyb, valueb)", col.sql()) + col_array = SF.create_map(["keya", "valuea", "keyb", "valueb"]) + self.assertEqual("MAP(keya, valuea, keyb, valueb)", col_array.sql()) + + def test_map_from_arrays(self): + col_str = SF.map_from_arrays("cola", "colb") + self.assertEqual("MAP_FROM_ARRAYS(cola, colb)", col_str.sql()) + col = SF.map_from_arrays(SF.col("cola"), SF.col("colb")) + self.assertEqual("MAP_FROM_ARRAYS(cola, colb)", col.sql()) + + def test_array_contains(self): + col_str = SF.array_contains("cola", "test") + self.assertEqual("ARRAY_CONTAINS(cola, 'test')", col_str.sql()) + col = SF.array_contains(SF.col("cola"), "test") + self.assertEqual("ARRAY_CONTAINS(cola, 'test')", col.sql()) + col_as_value = SF.array_contains("cola", SF.col("colb")) + self.assertEqual("ARRAY_CONTAINS(cola, colb)", col_as_value.sql()) + + def test_arrays_overlap(self): + col_str = SF.arrays_overlap("cola", "colb") + self.assertEqual("ARRAYS_OVERLAP(cola, colb)", col_str.sql()) + col = SF.arrays_overlap(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAYS_OVERLAP(cola, colb)", col.sql()) + + def test_slice(self): + col_str = SF.slice("cola", SF.col("colb"), SF.col("colc")) + self.assertEqual("SLICE(cola, colb, colc)", col_str.sql()) + col = SF.slice(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("SLICE(cola, colb, colc)", col.sql()) + col_ints = SF.slice("cola", 1, 10) + self.assertEqual("SLICE(cola, 1, 10)", col_ints.sql()) + + def test_array_join(self): + col_str = SF.array_join("cola", "-", "NULL_REPLACEMENT") + self.assertEqual("ARRAY_JOIN(cola, '-', 'NULL_REPLACEMENT')", col_str.sql()) + col = SF.array_join(SF.col("cola"), "-", "NULL_REPLACEMENT") + self.assertEqual("ARRAY_JOIN(cola, '-', 'NULL_REPLACEMENT')", col.sql()) + col_no_replacement = SF.array_join("cola", "-") + self.assertEqual("ARRAY_JOIN(cola, '-')", col_no_replacement.sql()) + + def test_concat(self): + col_str = SF.concat("cola", "colb") + self.assertEqual("CONCAT(cola, colb)", col_str.sql()) + col = SF.concat(SF.col("cola"), SF.col("colb")) + self.assertEqual("CONCAT(cola, colb)", col.sql()) + col_single = SF.concat("cola") + self.assertEqual("CONCAT(cola)", col_single.sql()) + + def test_array_position(self): + col_str = SF.array_position("cola", SF.col("colb")) + self.assertEqual("ARRAY_POSITION(cola, colb)", col_str.sql()) + col = SF.array_position(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_POSITION(cola, colb)", col.sql()) + col_lit = SF.array_position("cola", "test") + self.assertEqual("ARRAY_POSITION(cola, 'test')", col_lit) + + def test_element_at(self): + col_str = SF.element_at("cola", SF.col("colb")) + self.assertEqual("ELEMENT_AT(cola, colb)", col_str.sql()) + col = SF.element_at(SF.col("cola"), SF.col("colb")) + self.assertEqual("ELEMENT_AT(cola, colb)", col.sql()) + col_lit = SF.element_at("cola", "test") + self.assertEqual("ELEMENT_AT(cola, 'test')", col_lit) + + def test_array_remove(self): + col_str = SF.array_remove("cola", SF.col("colb")) + self.assertEqual("ARRAY_REMOVE(cola, colb)", col_str.sql()) + col = SF.array_remove(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_REMOVE(cola, colb)", col.sql()) + col_lit = SF.array_remove("cola", "test") + self.assertEqual("ARRAY_REMOVE(cola, 'test')", col_lit) + + def test_array_distinct(self): + col_str = SF.array_distinct("cola") + self.assertEqual("ARRAY_DISTINCT(cola)", col_str.sql()) + col = SF.array_distinct(SF.col("cola")) + self.assertEqual("ARRAY_DISTINCT(cola)", col.sql()) + + def test_array_intersect(self): + col_str = SF.array_intersect("cola", "colb") + self.assertEqual("ARRAY_INTERSECT(cola, colb)", col_str.sql()) + col = SF.array_intersect(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_INTERSECT(cola, colb)", col.sql()) + + def test_array_union(self): + col_str = SF.array_union("cola", "colb") + self.assertEqual("ARRAY_UNION(cola, colb)", col_str.sql()) + col = SF.array_union(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_UNION(cola, colb)", col.sql()) + + def test_array_except(self): + col_str = SF.array_except("cola", "colb") + self.assertEqual("ARRAY_EXCEPT(cola, colb)", col_str.sql()) + col = SF.array_except(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_EXCEPT(cola, colb)", col.sql()) + + def test_explode(self): + col_str = SF.explode("cola") + self.assertEqual("EXPLODE(cola)", col_str.sql()) + col = SF.explode(SF.col("cola")) + self.assertEqual("EXPLODE(cola)", col.sql()) + + def test_pos_explode(self): + col_str = SF.posexplode("cola") + self.assertEqual("POSEXPLODE(cola)", col_str.sql()) + col = SF.posexplode(SF.col("cola")) + self.assertEqual("POSEXPLODE(cola)", col.sql()) + + def test_explode_outer(self): + col_str = SF.explode_outer("cola") + self.assertEqual("EXPLODE_OUTER(cola)", col_str.sql()) + col = SF.explode_outer(SF.col("cola")) + self.assertEqual("EXPLODE_OUTER(cola)", col.sql()) + + def test_posexplode_outer(self): + col_str = SF.posexplode_outer("cola") + self.assertEqual("POSEXPLODE_OUTER(cola)", col_str.sql()) + col = SF.posexplode_outer(SF.col("cola")) + self.assertEqual("POSEXPLODE_OUTER(cola)", col.sql()) + + def test_get_json_object(self): + col_str = SF.get_json_object("cola", "$.f1") + self.assertEqual("GET_JSON_OBJECT(cola, '$.f1')", col_str.sql()) + col = SF.get_json_object(SF.col("cola"), "$.f1") + self.assertEqual("GET_JSON_OBJECT(cola, '$.f1')", col.sql()) + + def test_json_tuple(self): + col_str = SF.json_tuple("cola", "f1", "f2") + self.assertEqual("JSON_TUPLE(cola, 'f1', 'f2')", col_str.sql()) + col = SF.json_tuple(SF.col("cola"), "f1", "f2") + self.assertEqual("JSON_TUPLE(cola, 'f1', 'f2')", col.sql()) + + def test_from_json(self): + col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.from_json("cola", "cola INT") + self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql()) + + def test_to_json(self): + col_str = SF.to_json("cola", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("TO_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.to_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("TO_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.to_json("cola") + self.assertEqual("TO_JSON(cola)", col_no_option.sql()) + + def test_schema_of_json(self): + col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.schema_of_json("cola") + self.assertEqual("SCHEMA_OF_JSON(cola)", col_no_option.sql()) + + def test_schema_of_csv(self): + col_str = SF.schema_of_csv("cola", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("SCHEMA_OF_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.schema_of_csv(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("SCHEMA_OF_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.schema_of_csv("cola") + self.assertEqual("SCHEMA_OF_CSV(cola)", col_no_option.sql()) + + def test_to_csv(self): + col_str = SF.to_csv("cola", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("TO_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.to_csv(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("TO_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.to_csv("cola") + self.assertEqual("TO_CSV(cola)", col_no_option.sql()) + + def test_size(self): + col_str = SF.size("cola") + self.assertEqual("SIZE(cola)", col_str.sql()) + col = SF.size(SF.col("cola")) + self.assertEqual("SIZE(cola)", col.sql()) + + def test_array_min(self): + col_str = SF.array_min("cola") + self.assertEqual("ARRAY_MIN(cola)", col_str.sql()) + col = SF.array_min(SF.col("cola")) + self.assertEqual("ARRAY_MIN(cola)", col.sql()) + + def test_array_max(self): + col_str = SF.array_max("cola") + self.assertEqual("ARRAY_MAX(cola)", col_str.sql()) + col = SF.array_max(SF.col("cola")) + self.assertEqual("ARRAY_MAX(cola)", col.sql()) + + def test_sort_array(self): + col_str = SF.sort_array("cola", False) + self.assertEqual("SORT_ARRAY(cola, FALSE)", col_str.sql()) + col = SF.sort_array(SF.col("cola"), False) + self.assertEqual("SORT_ARRAY(cola, FALSE)", col.sql()) + col_no_sort = SF.sort_array("cola") + self.assertEqual("SORT_ARRAY(cola)", col_no_sort.sql()) + + def test_array_sort(self): + col_str = SF.array_sort("cola") + self.assertEqual("ARRAY_SORT(cola)", col_str.sql()) + col = SF.array_sort(SF.col("cola")) + self.assertEqual("ARRAY_SORT(cola)", col.sql()) + + def test_reverse(self): + col_str = SF.reverse("cola") + self.assertEqual("REVERSE(cola)", col_str.sql()) + col = SF.reverse(SF.col("cola")) + self.assertEqual("REVERSE(cola)", col.sql()) + + def test_flatten(self): + col_str = SF.flatten("cola") + self.assertEqual("FLATTEN(cola)", col_str.sql()) + col = SF.flatten(SF.col("cola")) + self.assertEqual("FLATTEN(cola)", col.sql()) + + def test_map_keys(self): + col_str = SF.map_keys("cola") + self.assertEqual("MAP_KEYS(cola)", col_str.sql()) + col = SF.map_keys(SF.col("cola")) + self.assertEqual("MAP_KEYS(cola)", col.sql()) + + def test_map_values(self): + col_str = SF.map_values("cola") + self.assertEqual("MAP_VALUES(cola)", col_str.sql()) + col = SF.map_values(SF.col("cola")) + self.assertEqual("MAP_VALUES(cola)", col.sql()) + + def test_map_entries(self): + col_str = SF.map_entries("cola") + self.assertEqual("MAP_ENTRIES(cola)", col_str.sql()) + col = SF.map_entries(SF.col("cola")) + self.assertEqual("MAP_ENTRIES(cola)", col.sql()) + + def test_map_from_entries(self): + col_str = SF.map_from_entries("cola") + self.assertEqual("MAP_FROM_ENTRIES(cola)", col_str.sql()) + col = SF.map_from_entries(SF.col("cola")) + self.assertEqual("MAP_FROM_ENTRIES(cola)", col.sql()) + + def test_array_repeat(self): + col_str = SF.array_repeat("cola", 2) + self.assertEqual("ARRAY_REPEAT(cola, 2)", col_str.sql()) + col = SF.array_repeat(SF.col("cola"), 2) + self.assertEqual("ARRAY_REPEAT(cola, 2)", col.sql()) + + def test_array_zip(self): + col_str = SF.array_zip("cola", "colb") + self.assertEqual("ARRAY_ZIP(cola, colb)", col_str.sql()) + col = SF.array_zip(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_ZIP(cola, colb)", col.sql()) + col_single = SF.array_zip("cola") + self.assertEqual("ARRAY_ZIP(cola)", col_single.sql()) + + def test_map_concat(self): + col_str = SF.map_concat("cola", "colb") + self.assertEqual("MAP_CONCAT(cola, colb)", col_str.sql()) + col = SF.map_concat(SF.col("cola"), SF.col("colb")) + self.assertEqual("MAP_CONCAT(cola, colb)", col.sql()) + col_single = SF.map_concat("cola") + self.assertEqual("MAP_CONCAT(cola)", col_single.sql()) + + def test_sequence(self): + col_str = SF.sequence("cola", "colb", "colc") + self.assertEqual("SEQUENCE(cola, colb, colc)", col_str.sql()) + col = SF.sequence(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("SEQUENCE(cola, colb, colc)", col.sql()) + col_no_step = SF.sequence("cola", "colb") + self.assertEqual("SEQUENCE(cola, colb)", col_no_step.sql()) + + def test_from_csv(self): + col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.from_csv("cola", "cola INT") + self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql()) + + def test_aggregate(self): + col_str = SF.aggregate("cola", SF.lit(0), lambda acc, x: acc + x, lambda acc: acc * 2) + self.assertEqual("AGGREGATE(cola, 0, (acc, x) -> acc + x, acc -> acc * 2)", col_str.sql()) + col = SF.aggregate(SF.col("cola"), SF.lit(0), lambda acc, x: acc + x, lambda acc: acc * 2) + self.assertEqual("AGGREGATE(cola, 0, (acc, x) -> acc + x, acc -> acc * 2)", col.sql()) + col_no_finish = SF.aggregate("cola", SF.lit(0), lambda acc, x: acc + x) + self.assertEqual("AGGREGATE(cola, 0, (acc, x) -> acc + x)", col_no_finish.sql()) + col_custom_names = SF.aggregate( + "cola", + SF.lit(0), + lambda accumulator, target: accumulator + target, + lambda accumulator: accumulator * 2, + "accumulator", + "target", + ) + self.assertEqual( + "AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)", + col_custom_names.sql(), + ) + + def test_transform(self): + col_str = SF.transform("cola", lambda x: x * 2) + self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql()) + col = SF.transform(SF.col("cola"), lambda x, i: x * i) + self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql()) + col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count, "target", "row_count") + + self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()) + + def test_exists(self): + col_str = SF.exists("cola", lambda x: x % 2 == 0) + self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql()) + col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0) + self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql()) + col_custom_name = SF.exists("cola", lambda target: target > 0, "target") + self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql()) + + def test_forall(self): + col_str = SF.forall("cola", lambda x: x.rlike("foo")) + self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql()) + col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo")) + self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql()) + col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"), "target") + self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql()) + + def test_filter(self): + col_str = SF.filter("cola", lambda x: SF.month(SF.to_date(x)) > SF.lit(6)) + self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql()) + col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i)) + self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql()) + col_custom_names = SF.filter( + "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count), "target", "row_count" + ) + + self.assertEqual( + "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql() + ) + + def test_zip_with(self): + col_str = SF.zip_with("cola", "colb", lambda x, y: SF.concat_ws("_", x, y)) + self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql()) + col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y)) + self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql()) + col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r), "l", "r") + self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()) + + def test_transform_keys(self): + col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k)) + self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql()) + col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k)) + self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql()) + col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key), "key", "_") + self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql()) + + def test_transform_values(self): + col_str = SF.transform_values("cola", lambda k, v: SF.upper(v)) + self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql()) + col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v)) + self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql()) + col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value), "_", "value") + self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()) + + def test_map_filter(self): + col_str = SF.map_filter("cola", lambda k, v: k > v) + self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql()) + col = SF.map_filter(SF.col("cola"), lambda k, v: k > v) + self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql()) + col_custom_names = SF.map_filter("cola", lambda key, value: key > value, "key", "value") + self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql()) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py new file mode 100644 index 0000000..158dcec --- /dev/null +++ b/tests/dataframe/unit/test_session.py @@ -0,0 +1,114 @@ +from unittest import mock + +import sqlglot +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql import types +from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.schema import MappingSchema +from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator + + +class TestDataframeSession(DataFrameSQLValidator): + def test_cdf_one_row(self): + df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"]) + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_cdf_multiple_rows(self): + df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"]) + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_cdf_no_schema(self): + df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]]) + expected = ( + "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" + ) + self.compare_sql(df, expected) + + def test_cdf_row_mixed_primitives(self): + df = self.spark.createDataFrame([[1, 10.1, "test", False, None]]) + expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)" + self.compare_sql(df, expected) + + def test_cdf_dict_rows(self): + df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}]) + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_cdf_str_schema(self): + df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING") + expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_typed_schema_basic(self): + schema = types.StructType( + [ + types.StructField("cola", types.IntegerType()), + types.StructField("colb", types.StringType()), + ] + ) + df = self.spark.createDataFrame([[1, "test"]], schema) + expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_typed_schema_nested(self): + schema = types.StructType( + [ + types.StructField( + "cola", + types.StructType( + [ + types.StructField("sub_cola", types.IntegerType()), + types.StructField("sub_colb", types.StringType()), + ] + ), + ) + ] + ) + df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema) + expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)" + self.compare_sql(df, expected) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_sql_select_only(self): + # TODO: Do exact matches once CTE names are deterministic + query = "SELECT cola, colb FROM table" + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + df = self.spark.sql(query) + self.assertIn( + "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False) + ) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_sql_with_aggs(self): + # TODO: Do exact matches once CTE names are deterministic + query = "SELECT cola, colb FROM table" + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb")) + result = df.sql(pretty=False, optimize=False)[0] + self.assertIn("SELECT cola, colb FROM table", result) + self.assertIn("SUM(colb)", result) + self.assertIn("GROUP BY cola", result) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_sql_create(self): + query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1" + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + df = self.spark.sql(query) + expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" + self.compare_sql(df, expected) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_sql_insert(self): + query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1" + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + df = self.spark.sql(query) + expected = ( + "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" + ) + self.compare_sql(df, expected) + + def test_session_create_builder_patterns(self): + spark = SparkSession() + self.assertEqual(spark.builder.appName("abc").getOrCreate(), spark) diff --git a/tests/dataframe/unit/test_types.py b/tests/dataframe/unit/test_types.py new file mode 100644 index 0000000..1f6c5dc --- /dev/null +++ b/tests/dataframe/unit/test_types.py @@ -0,0 +1,70 @@ +import unittest + +from sqlglot.dataframe.sql import types + + +class TestDataframeTypes(unittest.TestCase): + def test_string(self): + self.assertEqual("string", types.StringType().simpleString()) + + def test_char(self): + self.assertEqual("char(100)", types.CharType(100).simpleString()) + + def test_varchar(self): + self.assertEqual("varchar(65)", types.VarcharType(65).simpleString()) + + def test_binary(self): + self.assertEqual("binary", types.BinaryType().simpleString()) + + def test_boolean(self): + self.assertEqual("boolean", types.BooleanType().simpleString()) + + def test_date(self): + self.assertEqual("date", types.DateType().simpleString()) + + def test_timestamp(self): + self.assertEqual("timestamp", types.TimestampType().simpleString()) + + def test_timestamp_ntz(self): + self.assertEqual("timestamp_ntz", types.TimestampNTZType().simpleString()) + + def test_decimal(self): + self.assertEqual("decimal(10, 3)", types.DecimalType(10, 3).simpleString()) + + def test_double(self): + self.assertEqual("double", types.DoubleType().simpleString()) + + def test_float(self): + self.assertEqual("float", types.FloatType().simpleString()) + + def test_byte(self): + self.assertEqual("tinyint", types.ByteType().simpleString()) + + def test_integer(self): + self.assertEqual("int", types.IntegerType().simpleString()) + + def test_long(self): + self.assertEqual("bigint", types.LongType().simpleString()) + + def test_short(self): + self.assertEqual("smallint", types.ShortType().simpleString()) + + def test_array(self): + self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString()) + + def test_map(self): + self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString()) + + def test_struct_field(self): + self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString()) + + def test_struct_type(self): + self.assertEqual( + "struct<cola:int, colb:string>", + types.StructType( + [ + types.StructField("cola", types.IntegerType()), + types.StructField("colb", types.StringType()), + ] + ).simpleString(), + ) diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py new file mode 100644 index 0000000..eea4582 --- /dev/null +++ b/tests/dataframe/unit/test_window.py @@ -0,0 +1,60 @@ +import unittest + +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.window import Window, WindowSpec + + +class TestDataframeWindow(unittest.TestCase): + def test_window_spec_partition_by(self): + partition_by = WindowSpec().partitionBy(F.col("cola"), F.col("colb")) + self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql()) + + def test_window_spec_order_by(self): + order_by = WindowSpec().orderBy("cola", "colb") + self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql()) + + def test_window_spec_rows_between(self): + rows_between = WindowSpec().rowsBetween(3, 5) + self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) + + def test_window_spec_range_between(self): + range_between = WindowSpec().rangeBetween(3, 5) + self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) + + def test_window_partition_by(self): + partition_by = Window.partitionBy(F.col("cola"), F.col("colb")) + self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql()) + + def test_window_order_by(self): + order_by = Window.orderBy("cola", "colb") + self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql()) + + def test_window_rows_between(self): + rows_between = Window.rowsBetween(3, 5) + self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) + + def test_window_range_between(self): + range_between = Window.rangeBetween(3, 5) + self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) + + def test_window_rows_unbounded(self): + rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2) + self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql()) + rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing) + self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql()) + rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + self.assertEqual( + "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql() + ) + + def test_window_range_unbounded(self): + range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2) + self.assertEqual( + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql() + ) + range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing) + self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql()) + range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing) + self.assertEqual( + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql() + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index a1e1262..e1524e9 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -694,29 +694,6 @@ class TestDialect(Validator): }, ) - # https://dev.mysql.com/doc/refman/8.0/en/join.html - # https://www.postgresql.org/docs/current/queries-table-expressions.html - def test_joined_tables(self): - self.validate_identity("SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)") - self.validate_identity("SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)") - self.validate_identity("SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)") - self.validate_identity("SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)") - - self.validate_all( - "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)", - write={ - "postgres": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)", - "mysql": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)", - }, - ) - self.validate_all( - "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)", - write={ - "postgres": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)", - "mysql": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)", - }, - ) - def test_lateral_subquery(self): self.validate_identity( "SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art" @@ -856,7 +833,7 @@ class TestDialect(Validator): "postgres": "x ILIKE '%y'", "presto": "LOWER(x) LIKE '%y'", "snowflake": "x ILIKE '%y'", - "spark": "LOWER(x) LIKE '%y'", + "spark": "x ILIKE '%y'", "sqlite": "LOWER(x) LIKE '%y'", "starrocks": "LOWER(x) LIKE '%y'", "trino": "LOWER(x) LIKE '%y'", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 298b3e9..625156b 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -48,7 +48,7 @@ class TestDuckDB(Validator): self.validate_all( "STRPTIME(x, '%y-%-m')", write={ - "bigquery": "STR_TO_TIME(x, '%y-%-m')", + "bigquery": "PARSE_TIMESTAMP('%y-%m', x)", "duckdb": "STRPTIME(x, '%y-%-m')", "presto": "DATE_PARSE(x, '%y-%c')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy-M')) AS TIMESTAMP)", @@ -63,6 +63,16 @@ class TestDuckDB(Validator): "hive": "CAST(x AS TIMESTAMP)", }, ) + self.validate_all( + "STRPTIME(x, '%-m/%-d/%y %-I:%M %p')", + write={ + "bigquery": "PARSE_TIMESTAMP('%m/%d/%y %I:%M %p', x)", + "duckdb": "STRPTIME(x, '%-m/%-d/%y %-I:%M %p')", + "presto": "DATE_PARSE(x, '%c/%e/%y %l:%i %p')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'M/d/yy h:mm a')) AS TIMESTAMP)", + "spark": "TO_TIMESTAMP(x, 'M/d/yy h:mm a')", + }, + ) def test_duckdb(self): self.validate_all( @@ -268,6 +278,17 @@ class TestDuckDB(Validator): "spark": "MONTH('2021-03-01')", }, ) + self.validate_all( + "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))", + write={ + "duckdb": "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))", + "presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])", + "hive": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", + "spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", + "snowflake": "ARRAY_CAT([1, 2], [3, 4])", + "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", + }, + ) with self.assertRaises(UnsupportedError): transpile( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 723e27c..a25871c 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -31,6 +31,24 @@ class TestMySQL(Validator): "mysql": "_utf8mb4 'hola'", }, ) + self.validate_all( + "N 'some text'", + read={ + "mysql": "N'some text'", + }, + write={ + "mysql": "N 'some text'", + }, + ) + self.validate_all( + "_latin1 x'4D7953514C'", + read={ + "mysql": "_latin1 X'4D7953514C'", + }, + write={ + "mysql": "_latin1 x'4D7953514C'", + }, + ) def test_hexadecimal_literal(self): self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 4b8f3c3..35141e2 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -69,6 +69,8 @@ class TestPostgres(Validator): self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')") self.validate_identity("COMMENT ON TABLE mytable IS 'this'") + self.validate_identity("SELECT e'\\xDEADBEEF'") + self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", @@ -204,3 +206,11 @@ class TestPostgres(Validator): """'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""", write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>>'{a,2}'"""}, ) + self.validate_all( + "SELECT $$a$$", + write={"postgres": "SELECT 'a'"}, + ) + self.validate_all( + "SELECT $$Dianne's horse$$", + write={"postgres": "SELECT 'Dianne''s horse'"}, + ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 10c9d35..098ad2b 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -321,7 +321,7 @@ class TestPresto(Validator): "duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", "hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", - "spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo", }, ) self.validate_all( @@ -329,7 +329,7 @@ class TestPresto(Validator): write={ "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", "hive": UnsupportedError, - "spark": UnsupportedError, + "spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 8a33e2d..159b643 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -65,7 +65,7 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP('2013-04-05 01:02:03')", write={ - "bigquery": "SELECT STR_TO_TIME('2013-04-05 01:02:03', '%Y-%m-%d %H:%M:%S')", + "bigquery": "SELECT PARSE_TIMESTAMP('%Y-%m-%d %H:%M:%S', '2013-04-05 01:02:03')", "snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')", "spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')", }, @@ -73,16 +73,17 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", read={ - "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')", "duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", }, write={ - "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')", "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')", }, ) + self.validate_all( "SELECT IFF(TRUE, 'true', 'false')", write={ @@ -240,11 +241,25 @@ class TestSnowflake(Validator): }, ) self.validate_all( - "SELECT DATE_PART(month FROM a::DATETIME)", + "SELECT DATE_PART(month, a::DATETIME)", write={ "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))", }, ) + self.validate_all( + "SELECT DATE_PART(epoch_second, foo) as ddate from table_name", + write={ + "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) AS ddate FROM table_name", + "presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) AS ddate FROM table_name", + }, + ) + self.validate_all( + "SELECT DATE_PART(epoch_milliseconds, foo) as ddate from table_name", + write={ + "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) * 1000 AS ddate FROM table_name", + "presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name", + }, + ) def test_semi_structured_types(self): self.validate_identity("SELECT CAST(a AS VARIANT)") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b061784..9a6bc36 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -45,3 +45,29 @@ class TestTSQL(Validator): "tsql": "CAST(x AS DATETIME2)", }, ) + + def test_charindex(self): + self.validate_all( + "CHARINDEX(x, y, 9)", + write={ + "spark": "LOCATE(x, y, 9)", + }, + ) + self.validate_all( + "CHARINDEX(x, y)", + write={ + "spark": "LOCATE(x, y)", + }, + ) + self.validate_all( + "CHARINDEX('sub', 'testsubstring', 3)", + write={ + "spark": "LOCATE('sub', 'testsubstring', 3)", + }, + ) + self.validate_all( + "CHARINDEX('sub', 'testsubstring')", + write={ + "spark": "LOCATE('sub', 'testsubstring')", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 57e51e0..67e4cab 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -513,6 +513,8 @@ ALTER TYPE electronic_mail RENAME TO email ANALYZE a.y DELETE FROM x WHERE y > 1 DELETE FROM y +DELETE FROM event USING sales WHERE event.eventid = sales.eventid +DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid DROP TABLE a DROP TABLE a.b DROP TABLE IF EXISTS a @@ -563,3 +565,8 @@ WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z SELECT ((SELECT 1) + 1) SELECT * FROM project.dataset.INFORMATION_SCHEMA.TABLES +SELECT * FROM (table1 AS t1 LEFT JOIN table2 AS t2 ON 1 = 1) +SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1) +SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3) +SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo) +SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl) diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index a82e1ed..4a3ad4b 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -287,3 +287,27 @@ SELECT FROM t1; SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x; + +# title: Values Test +# dialect: spark +WITH t1 AS ( + SELECT + a1.cola + FROM + VALUES (1) AS a1(cola) +), t2 AS ( + SELECT + a2.cola + FROM + VALUES (1) AS a2(cola) +) +SELECT /*+ BROADCAST(t2) */ + t1.cola, + t2.cola, +FROM + t1 + JOIN + t2 + ON + t1.cola = t2.cola; +SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola; diff --git a/tests/fixtures/optimizer/pushdown_predicates.sql b/tests/fixtures/optimizer/pushdown_predicates.sql index ef591ec..dd318a2 100644 --- a/tests/fixtures/optimizer/pushdown_predicates.sql +++ b/tests/fixtures/optimizer/pushdown_predicates.sql @@ -33,3 +33,6 @@ SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y. with t1 as (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1; WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1; + +WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a; +WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index b03ffab..ba4bf45 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -22,6 +22,9 @@ SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_ SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x); SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0"; +WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x UNION ALL SELECT z.b AS b, z.c AS c FROM z) SELECT a, b FROM t1; +WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x AS x UNION ALL SELECT z.b AS b, z.c AS c FROM z AS z) SELECT t1.a AS a, t1.b AS b FROM t1; + SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x); SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 83a3bf8..858f232 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -72,6 +72,9 @@ SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a; SELECT a FROM x ORDER BY b; SELECT x.a AS a FROM x AS x ORDER BY x.b; +SELECT SUM(a) AS a FROM x ORDER BY SUM(a); +SELECT SUM(x.a) AS a FROM x AS x ORDER BY SUM(x.a); + # dialect: bigquery SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS row_num FROM x QUALIFY row_num = 1; SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x QUALIFY row_num = 1; diff --git a/tests/helpers.py b/tests/helpers.py index 2d200f6..dabaf1c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -53,6 +53,8 @@ def string_to_bool(string): return string and string.lower() in ("true", "1") +SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower()) + TPCH_SCHEMA = { "lineitem": { "l_orderkey": "uint64", diff --git a/tests/test_executor.py b/tests/test_executor.py index c5841d3..ef1a706 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -7,11 +7,17 @@ from pandas.testing import assert_frame_equal from sqlglot import exp, parse_one from sqlglot.executor import execute from sqlglot.executor.python import Python -from tests.helpers import FIXTURES_DIR, TPCH_SCHEMA, load_sql_fixture_pairs +from tests.helpers import ( + FIXTURES_DIR, + SKIP_INTEGRATION, + TPCH_SCHEMA, + load_sql_fixture_pairs, +) DIR = FIXTURES_DIR + "/optimizer/tpc-h/" +@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") class TestExecutor(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 9ad2bf5..79b4ee5 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -123,13 +123,16 @@ class TestExpressions(unittest.TestCase): self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c") self.assertEqual(exp.table_name("a.b.c"), "a.b.c") + def test_table(self): + self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table)) + def test_replace_tables(self): self.assertEqual( exp.replace_tables( - parse_one("select * from a join b join c.a join d.a join e.a"), + parse_one("select * from a AS a join b join c.a join d.a join e.a"), {"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"}, ).sql(), - 'SELECT * FROM "a1" JOIN "b"."a" JOIN "c"."a2" JOIN "d2" JOIN e.a', + "SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a", ) def test_named_selects(self): @@ -495,11 +498,15 @@ class TestExpressions(unittest.TestCase): self.assertEqual(exp.convert(value).sql(), expected) def test_annotation_alias(self): - expression = parse_one("SELECT a, b AS B, c #comment, d AS D #another_comment FROM foo") + sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo" + expression = parse_one(sql) self.assertEqual( [e.alias_or_name for e in expression.expressions], ["a", "B", "c", "D"], ) + self.assertEqual(expression.sql(), sql) + self.assertEqual(expression.expressions[2].name, "comment") + self.assertEqual(expression.sql(annotations=False), "SELECT a, b AS B, c, d AS D") def test_to_table(self): table_only = exp.to_table("table_name") @@ -514,6 +521,18 @@ class TestExpressions(unittest.TestCase): self.assertEqual(catalog_db_and_table.name, "table_name") self.assertEqual(catalog_db_and_table.args.get("db"), exp.to_identifier("db")) self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog")) + with self.assertRaises(ValueError): + exp.to_table(1) + + def test_to_column(self): + column_only = exp.to_column("column_name") + self.assertEqual(column_only.name, "column_name") + self.assertIsNone(column_only.args.get("table")) + table_and_column = exp.to_column("table_name.column_name") + self.assertEqual(table_and_column.name, "column_name") + self.assertEqual(table_and_column.args.get("table"), exp.to_identifier("table_name")) + with self.assertRaises(ValueError): + exp.to_column(1) def test_union(self): expression = parse_one("SELECT cola, colb UNION SELECT colx, coly") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a67e9db..3b5990f 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,11 +5,11 @@ import duckdb from pandas.testing import assert_frame_equal import sqlglot -from sqlglot import exp, optimizer, parse_one, table +from sqlglot import exp, optimizer, parse_one from sqlglot.errors import OptimizeError from sqlglot.optimizer.annotate_types import annotate_types -from sqlglot.optimizer.schema import MappingSchema, ensure_schema from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope +from sqlglot.schema import MappingSchema from tests.helpers import ( TPCH_SCHEMA, load_sql_fixture_pairs, @@ -29,19 +29,19 @@ class TestOptimizer(unittest.TestCase): CREATE TABLE x (a INT, b INT); CREATE TABLE y (b INT, c INT); CREATE TABLE z (b INT, c INT); - + INSERT INTO x VALUES (1, 1); INSERT INTO x VALUES (2, 2); INSERT INTO x VALUES (2, 2); INSERT INTO x VALUES (3, 3); INSERT INTO x VALUES (null, null); - + INSERT INTO y VALUES (2, 2); INSERT INTO y VALUES (2, 2); INSERT INTO y VALUES (3, 3); INSERT INTO y VALUES (4, 4); INSERT INTO y VALUES (null, null); - + INSERT INTO y VALUES (3, 3); INSERT INTO y VALUES (3, 3); INSERT INTO y VALUES (4, 4); @@ -80,8 +80,8 @@ class TestOptimizer(unittest.TestCase): with self.subTest(title): self.assertEqual( - optimized.sql(pretty=pretty, dialect=dialect), expected, + optimized.sql(pretty=pretty, dialect=dialect), ) should_execute = meta.get("execute") @@ -223,85 +223,6 @@ class TestOptimizer(unittest.TestCase): def test_tpch(self): self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) - def test_schema(self): - schema = ensure_schema( - { - "x": { - "a": "uint64", - } - } - ) - self.assertEqual( - schema.column_names( - table( - "x", - ) - ), - ["a"], - ) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x2")) - - schema = ensure_schema( - { - "db": { - "x": { - "a": "uint64", - } - } - } - ) - self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - - schema = ensure_schema( - { - "c": { - "db": { - "x": { - "a": "uint64", - } - } - } - } - ) - self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c2")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - - schema = ensure_schema( - MappingSchema( - { - "x": { - "a": "uint64", - } - } - ) - ) - self.assertEqual(schema.column_names(table("x")), ["a"]) - - with self.assertRaises(OptimizeError): - ensure_schema({}) - def test_file_schema(self): expression = parse_one( """ @@ -327,6 +248,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') SELECT x.b FROM x ), r AS ( SELECT y.b FROM y + ), z as ( + SELECT cola, colb FROM (VALUES(1, 'test')) AS tab(cola, colb) ) SELECT r.b, @@ -340,19 +263,23 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') """ expression = parse_one(sql) for scopes in traverse_scope(expression), list(build_scope(expression).traverse()): - self.assertEqual(len(scopes), 5) + self.assertEqual(len(scopes), 7) self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") - self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y") - self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") - self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) - - self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) - self.assertEqual(len(scopes[4].columns), 6) - self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) - self.assertEqual(scopes[4].source_columns("q"), []) - self.assertEqual(len(scopes[4].source_columns("r")), 2) - self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)") + self.assertEqual( + scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)" + ) + self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") + self.assertEqual(scopes[6].expression.sql(), parse_one(sql).sql()) + + self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"}) + self.assertEqual(len(scopes[6].columns), 6) + self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"}) + self.assertEqual(scopes[6].source_columns("q"), []) + self.assertEqual(len(scopes[6].source_columns("r")), 2) + self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"}) self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"}) self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b") diff --git a/tests/test_parser.py b/tests/test_parser.py index 4e86516..9afeae6 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -81,7 +81,7 @@ class TestParser(unittest.TestCase): self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint) self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint) - default = Parser() + default = Parser(error_level=ErrorLevel.RAISE) self.assertIsInstance(default.expression(exp.Hint, expressions=[""]), exp.Hint) default.expression(exp.Hint, y="") default.expression(exp.Hint) @@ -139,12 +139,12 @@ class TestParser(unittest.TestCase): ) assert expression.expressions[0].name == "annotation1" - assert expression.expressions[1].name == "annotation2:testing " + assert expression.expressions[1].name == "annotation2:testing" assert expression.expressions[2].name == "test#annotation" assert expression.expressions[3].name == "annotation3" assert expression.expressions[4].name == "annotation4" assert expression.expressions[5].name == "" - assert expression.expressions[6].name == " space" + assert expression.expressions[6].name == "space" def test_pretty_config_override(self): self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x") diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..bab97d8 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,290 @@ +import unittest + +from sqlglot import table +from sqlglot.dataframe.sql import types as df_types +from sqlglot.schema import MappingSchema, ensure_schema + + +class TestSchema(unittest.TestCase): + def test_schema(self): + schema = ensure_schema( + { + "x": { + "a": "uint64", + } + } + ) + self.assertEqual( + schema.column_names( + table( + "x", + ) + ), + ["a"], + ) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x2")) + + with self.assertRaises(ValueError): + schema.add_table(table("y", db="db"), {"b": "string"}) + with self.assertRaises(ValueError): + schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + + schema.add_table(table("y"), {"b": "string"}) + schema_with_y = { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + } + self.assertEqual(schema.schema, schema_with_y) + + new_schema = schema.copy() + new_schema.add_table(table("z"), {"c": "string"}) + self.assertEqual(schema.schema, schema_with_y) + self.assertEqual( + new_schema.schema, + { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + "z": { + "c": "string", + }, + }, + ) + schema.add_table(table("m"), {"d": "string"}) + schema.add_table(table("n"), {"e": "string"}) + schema_with_m_n = { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + "m": { + "d": "string", + }, + "n": { + "e": "string", + }, + } + self.assertEqual(schema.schema, schema_with_m_n) + new_schema = schema.copy() + new_schema.add_table(table("o"), {"f": "string"}) + new_schema.add_table(table("p"), {"g": "string"}) + self.assertEqual(schema.schema, schema_with_m_n) + self.assertEqual( + new_schema.schema, + { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + "m": { + "d": "string", + }, + "n": { + "e": "string", + }, + "o": { + "f": "string", + }, + "p": { + "g": "string", + }, + }, + ) + + schema = ensure_schema( + { + "db": { + "x": { + "a": "uint64", + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + with self.assertRaises(ValueError): + schema.add_table(table("y"), {"b": "string"}) + with self.assertRaises(ValueError): + schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + + schema.add_table(table("y", db="db"), {"b": "string"}) + self.assertEqual( + schema.schema, + { + "db": { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + } + }, + ) + + schema = ensure_schema( + { + "c": { + "db": { + "x": { + "a": "uint64", + } + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c2")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + with self.assertRaises(ValueError): + schema.add_table(table("x"), {"b": "string"}) + with self.assertRaises(ValueError): + schema.add_table(table("x", db="db"), {"b": "string"}) + + schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"}) + self.assertEqual( + schema.schema, + { + "c": { + "db": { + "x": { + "a": "uint64", + }, + "y": { + "a": "string", + "b": "int", + }, + } + } + }, + ) + schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"}) + self.assertEqual( + schema.schema, + { + "c": { + "db": { + "x": { + "a": "uint64", + }, + "y": { + "a": "string", + "b": "int", + }, + }, + "db2": { + "z": { + "c": "string", + "d": "int", + } + }, + } + }, + ) + schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"}) + self.assertEqual( + schema.schema, + { + "c": { + "db": { + "x": { + "a": "uint64", + }, + "y": { + "a": "string", + "b": "int", + }, + }, + "db2": { + "z": { + "c": "string", + "d": "int", + } + }, + }, + "c2": { + "db2": { + "m": { + "e": "string", + "f": "int", + } + } + }, + }, + ) + + schema = ensure_schema( + { + "x": { + "a": "uint64", + } + } + ) + self.assertEqual(schema.column_names(table("x")), ["a"]) + + schema = MappingSchema() + schema.add_table(table("x"), {"a": "string"}) + self.assertEqual( + schema.schema, + { + "x": { + "a": "string", + } + }, + ) + schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())])) + self.assertEqual( + schema.schema, + { + "x": { + "a": "string", + }, + "y": { + "b": "string", + }, + }, + ) + + def test_schema_add_table_with_and_without_mapping(self): + schema = MappingSchema() + schema.add_table("test") + self.assertEqual(schema.column_names("test"), []) + schema.add_table("test", {"x": "string"}) + self.assertEqual(schema.column_names("test"), ["x"]) + schema.add_table("test", {"x": "string", "y": "int"}) + self.assertEqual(schema.column_names("test"), ["x", "y"]) + schema.add_table("test") + self.assertEqual(schema.column_names("test"), ["x", "y"]) |