summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/__init__.py5
-rw-r--r--sqlglot/__main__.py4
-rw-r--r--sqlglot/dataframe/README.md224
-rw-r--r--sqlglot/dataframe/__init__.py0
-rw-r--r--sqlglot/dataframe/sql/__init__.py18
-rw-r--r--sqlglot/dataframe/sql/_typing.pyi20
-rw-r--r--sqlglot/dataframe/sql/column.py295
-rw-r--r--sqlglot/dataframe/sql/dataframe.py730
-rw-r--r--sqlglot/dataframe/sql/functions.py1258
-rw-r--r--sqlglot/dataframe/sql/group.py57
-rw-r--r--sqlglot/dataframe/sql/normalize.py72
-rw-r--r--sqlglot/dataframe/sql/operations.py53
-rw-r--r--sqlglot/dataframe/sql/readwriter.py79
-rw-r--r--sqlglot/dataframe/sql/session.py148
-rw-r--r--sqlglot/dataframe/sql/transforms.py9
-rw-r--r--sqlglot/dataframe/sql/types.py208
-rw-r--r--sqlglot/dataframe/sql/util.py32
-rw-r--r--sqlglot/dataframe/sql/window.py117
-rw-r--r--sqlglot/dialects/bigquery.py12
-rw-r--r--sqlglot/dialects/dialect.py13
-rw-r--r--sqlglot/dialects/hive.py4
-rw-r--r--sqlglot/dialects/mysql.py5
-rw-r--r--sqlglot/dialects/oracle.py11
-rw-r--r--sqlglot/dialects/postgres.py6
-rw-r--r--sqlglot/dialects/presto.py1
-rw-r--r--sqlglot/dialects/snowflake.py33
-rw-r--r--sqlglot/dialects/spark.py13
-rw-r--r--sqlglot/dialects/tsql.py5
-rw-r--r--sqlglot/executor/env.py1
-rw-r--r--sqlglot/executor/python.py8
-rw-r--r--sqlglot/expressions.py81
-rw-r--r--sqlglot/generator.py29
-rw-r--r--sqlglot/helper.py53
-rw-r--r--sqlglot/optimizer/__init__.py1
-rw-r--r--sqlglot/optimizer/annotate_types.py2
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py2
-rw-r--r--sqlglot/optimizer/isolate_table_selects.py10
-rw-r--r--sqlglot/optimizer/merge_subqueries.py16
-rw-r--r--sqlglot/optimizer/optimizer.py5
-rw-r--r--sqlglot/optimizer/pushdown_projections.py35
-rw-r--r--sqlglot/optimizer/qualify_columns.py27
-rw-r--r--sqlglot/optimizer/qualify_tables.py2
-rw-r--r--sqlglot/optimizer/schema.py180
-rw-r--r--sqlglot/optimizer/scope.py11
-rw-r--r--sqlglot/parser.py26
-rw-r--r--sqlglot/planner.py3
-rw-r--r--sqlglot/schema.py297
-rw-r--r--sqlglot/tokens.py26
48 files changed, 3960 insertions, 287 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 247085b..7841c11 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -21,12 +21,15 @@ from sqlglot.expressions import table_ as table
from sqlglot.expressions import union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
+from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "7.1.3"
+__version__ = "9.0.1"
pretty = False
+schema = MappingSchema()
+
def parse(sql, read=None, **opts):
"""
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py
index 4161259..c0fa380 100644
--- a/sqlglot/__main__.py
+++ b/sqlglot/__main__.py
@@ -40,8 +40,8 @@ parser.add_argument(
"--error-level",
dest="error_level",
type=str,
- default="RAISE",
- help="IGNORE, WARN, RAISE (default)",
+ default="IMMEDIATE",
+ help="IGNORE, WARN, RAISE, IMMEDIATE (default)",
)
diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md
new file mode 100644
index 0000000..54d3856
--- /dev/null
+++ b/sqlglot/dataframe/README.md
@@ -0,0 +1,224 @@
+# PySpark DataFrame SQL Generator
+
+This is a drop-in replacement for the PySpark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/).
+
+Currently many of the common operations are covered and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what should be prioritized next and make sure your use case is properly supported.
+
+# How to use
+
+## Instructions
+* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install); that is all that is required to generate SQL. [The examples](#examples) generate SQL and then execute it on a specific engine, which additionally requires that engine's client library.
+* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe.sql`
+* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`
+ * The column structure can be defined the following ways:
+ * Dictionary where the keys are column names and the values are strings naming the Spark SQL type
+ * Ex: `{'cola': 'string', 'colb': 'int'}`
+ * PySpark DataFrame `StructType` similar to when using `createDataFrame`
+ * Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])`
+ * A string of names and types similar to what is supported in `createDataFrame`
+ * Ex: `cola: STRING, colb: INT`
+ * [Not Recommended] A list of string column names without type
+ * Ex: `['cola', 'colb']`
+ * The lack of types may limit functionality in future releases
+ * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally
+* Add `.sql(pretty=True)` to your final DataFrame command to return a list of SQL statements to run that command
+ * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames, which isn't supported in other dialects.
+ * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects
+ * Ex: `.sql(pretty=True, dialect='bigquery')`
+
+## Examples
+
+```python
+import sqlglot
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import functions as F
+
+sqlglot.schema.add_table('employee', {
+ 'employee_id': 'INT',
+ 'fname': 'STRING',
+ 'lname': 'STRING',
+ 'age': 'INT',
+}) # Register the table structure prior to reading from the table
+
+spark = SparkSession()
+
+df = (
+ spark
+ .table('employee')
+ .groupBy(F.col("age"))
+ .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
+)
+
+print(df.sql(pretty=True)) # Spark will be the dialect used by default
+```
+Output:
+```sparksql
+SELECT
+ `employee`.`age` AS `age`,
+ COUNT(DISTINCT `employee`.`employee_id`) AS `num_employees`
+FROM `employee` AS `employee`
+GROUP BY
+ `employee`.`age`
+```
+
+## Registering Custom Schema Class
+
+The step of adding `sqlglot.schema.add_table` can be skipped if you have the column structure stored externally, such as in a file or an external metadata table. This can be done by writing a class that implements the `sqlglot.schema.Schema` abstract class and then assigning an instance of that class to `sqlglot.schema`.
+
+```python
+import sqlglot
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.schema import Schema
+
+
+class ExternalSchema(Schema):
+ ...
+
+sqlglot.schema = ExternalSchema()
+
+spark = SparkSession()
+
+df = (
+ spark
+ .table('employee')
+ .groupBy(F.col("age"))
+ .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
+)
+
+print(df.sql(pretty=True))
+```
+
+## Example Implementations
+
+### BigQuery
+```python
+from google.cloud import bigquery
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import types
+from sqlglot.dataframe.sql import functions as F
+
+client = bigquery.Client()
+
+data = [
+ (1, "Jack", "Shephard", 34),
+ (2, "John", "Locke", 48),
+ (3, "Kate", "Austen", 34),
+ (4, "Claire", "Littleton", 22),
+ (5, "Hugo", "Reyes", 26),
+]
+schema = types.StructType([
+ types.StructField('employee_id', types.IntegerType(), False),
+ types.StructField('fname', types.StringType(), False),
+ types.StructField('lname', types.StringType(), False),
+ types.StructField('age', types.IntegerType(), False),
+])
+
+sql_statements = (
+ SparkSession()
+ .createDataFrame(data, schema)
+ .groupBy(F.col("age"))
+ .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
+ .sql(dialect="bigquery")
+)
+
+result = None
+for sql in sql_statements:
+ result = client.query(sql)
+
+assert result is not None
+for row in result.result():
+ print(f"Age: {row['age']}, Num Employees: {row['num_employees']}")
+```
+
+### Snowflake
+```python
+import os
+
+import snowflake.connector
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import types
+from sqlglot.dataframe.sql import functions as F
+
+ctx = snowflake.connector.connect(
+ user=os.environ["SNOWFLAKE_USER"],
+ password=os.environ["SNOWFLAKE_PASS"],
+ account=os.environ["SNOWFLAKE_ACCOUNT"]
+)
+cs = ctx.cursor()
+
+data = [
+ (1, "Jack", "Shephard", 34),
+ (2, "John", "Locke", 48),
+ (3, "Kate", "Austen", 34),
+ (4, "Claire", "Littleton", 22),
+ (5, "Hugo", "Reyes", 26),
+]
+schema = types.StructType([
+ types.StructField('employee_id', types.IntegerType(), False),
+ types.StructField('fname', types.StringType(), False),
+ types.StructField('lname', types.StringType(), False),
+ types.StructField('age', types.IntegerType(), False),
+])
+
+sql_statements = (
+ SparkSession()
+ .createDataFrame(data, schema)
+ .groupBy(F.col("age"))
+ .agg(F.countDistinct(F.col("lname")).alias("num_employees"))
+ .sql(dialect="snowflake")
+)
+
+try:
+ for sql in sql_statements:
+ cs.execute(sql)
+ results = cs.fetchall()
+ for row in results:
+ print(f"Age: {row[0]}, Num Employees: {row[1]}")
+finally:
+ cs.close()
+ctx.close()
+```
+
+### Spark
+```python
+from pyspark.sql.session import SparkSession as PySparkSession
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql import types
+from sqlglot.dataframe.sql import functions as F
+
+data = [
+ (1, "Jack", "Shephard", 34),
+ (2, "John", "Locke", 48),
+ (3, "Kate", "Austen", 34),
+ (4, "Claire", "Littleton", 22),
+ (5, "Hugo", "Reyes", 26),
+]
+schema = types.StructType([
+ types.StructField('employee_id', types.IntegerType(), False),
+ types.StructField('fname', types.StringType(), False),
+ types.StructField('lname', types.StringType(), False),
+ types.StructField('age', types.IntegerType(), False),
+])
+
+sql_statements = (
+ SparkSession()
+ .createDataFrame(data, schema)
+ .groupBy(F.col("age"))
+ .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
+ .sql(dialect="spark")
+)
+
+pyspark = PySparkSession.builder.master("local[*]").getOrCreate()
+
+df = None
+for sql in sql_statements:
+ df = pyspark.sql(sql)
+
+assert df is not None
+df.show()
+```
+
+# Unsupportable Operations
+
+Any operation that lacks a way to represent it in SQL cannot be supported by this tool. An example of this would be RDD operations. However, since the DataFrame API is mostly modeled around SQL concepts, most operations can be supported.
diff --git a/sqlglot/dataframe/__init__.py b/sqlglot/dataframe/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/sqlglot/dataframe/__init__.py
diff --git a/sqlglot/dataframe/sql/__init__.py b/sqlglot/dataframe/sql/__init__.py
new file mode 100644
index 0000000..3f90802
--- /dev/null
+++ b/sqlglot/dataframe/sql/__init__.py
@@ -0,0 +1,18 @@
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
+from sqlglot.dataframe.sql.group import GroupedData
+from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql.window import Window, WindowSpec
+
+__all__ = [
+ "SparkSession",
+ "DataFrame",
+ "GroupedData",
+ "Column",
+ "DataFrameNaFunctions",
+ "Window",
+ "WindowSpec",
+ "DataFrameReader",
+ "DataFrameWriter",
+]
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
new file mode 100644
index 0000000..f1a03ea
--- /dev/null
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+import datetime
+import typing as t
+
+from sqlglot import expressions as exp
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.column import Column
+ from sqlglot.dataframe.sql.types import StructType
+
+ColumnLiterals = t.TypeVar(
+ "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+)
+ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
+ColumnOrLiteral = t.TypeVar(
+ "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+)
+SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
+OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
new file mode 100644
index 0000000..2391080
--- /dev/null
+++ b/sqlglot/dataframe/sql/column.py
@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql.types import DataType
+from sqlglot.helper import flatten
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnOrLiteral
+ from sqlglot.dataframe.sql.window import WindowSpec
+
+
+class Column:
+ def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+ if isinstance(expression, Column):
+ expression = expression.expression # type: ignore
+ elif expression is None or not isinstance(expression, (str, exp.Expression)):
+ expression = self._lit(expression).expression # type: ignore
+ self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+ def __repr__(self):
+ return repr(self.expression)
+
+ def __hash__(self):
+ return hash(self.expression)
+
+ def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore
+ return self.binary_op(exp.EQ, other)
+
+ def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore
+ return self.binary_op(exp.NEQ, other)
+
+ def __gt__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.GT, other)
+
+ def __ge__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.GTE, other)
+
+ def __lt__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.LT, other)
+
+ def __le__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.LTE, other)
+
+ def __and__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.And, other)
+
+ def __or__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Or, other)
+
+ def __mod__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Mod, other)
+
+ def __add__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Add, other)
+
+ def __sub__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Sub, other)
+
+ def __mul__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Mul, other)
+
+ def __truediv__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Div, other)
+
+ def __div__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Div, other)
+
+ def __neg__(self) -> Column:
+ return self.unary_op(exp.Neg)
+
+ def __radd__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Add, other)
+
+ def __rsub__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Sub, other)
+
+ def __rmul__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Mul, other)
+
+ def __rdiv__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Div, other)
+
+ def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Div, other)
+
+ def __rmod__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Mod, other)
+
+ def __pow__(self, power: ColumnOrLiteral, modulo=None):
+ return Column(exp.Pow(this=self.expression, power=Column(power).expression))
+
+ def __rpow__(self, power: ColumnOrLiteral):
+ return Column(exp.Pow(this=Column(power).expression, power=self.expression))
+
+ def __invert__(self):
+ return self.unary_op(exp.Not)
+
+ def __rand__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.And, other)
+
+ def __ror__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Or, other)
+
+ @classmethod
+ def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+ return cls(value)
+
+ @classmethod
+ def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
+ return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
+
+ @classmethod
+ def _lit(cls, value: ColumnOrLiteral) -> Column:
+ if isinstance(value, dict):
+ columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
+ return cls(exp.Struct(expressions=columns))
+ return cls(exp.convert(value))
+
+ @classmethod
+ def invoke_anonymous_function(
+ cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
+ ) -> Column:
+ columns = [] if column is None else [cls.ensure_col(column)]
+ column_args = [cls.ensure_col(arg) for arg in args]
+ expressions = [x.expression for x in columns + column_args]
+ new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
+ return Column(new_expression)
+
+ @classmethod
+ def invoke_expression_over_column(
+ cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
+ ) -> Column:
+ ensured_column = None if column is None else cls.ensure_col(column)
+ new_expression = (
+ callable_expression(**kwargs)
+ if ensured_column is None
+ else callable_expression(this=ensured_column.column_expression, **kwargs)
+ )
+ return Column(new_expression)
+
+ def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
+ return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+
+ def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
+ return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+
+ def unary_op(self, klass: t.Callable, **kwargs) -> Column:
+ return Column(klass(this=self.column_expression, **kwargs))
+
+ @property
+ def is_alias(self):
+ return isinstance(self.expression, exp.Alias)
+
+ @property
+ def is_column(self):
+ return isinstance(self.expression, exp.Column)
+
+ @property
+ def column_expression(self) -> exp.Column:
+ return self.expression.unalias()
+
+ @property
+ def alias_or_name(self) -> str:
+ return self.expression.alias_or_name
+
+ @classmethod
+ def ensure_literal(cls, value) -> Column:
+ from sqlglot.dataframe.sql.functions import lit
+
+ if isinstance(value, cls):
+ value = value.expression
+ if not isinstance(value, exp.Literal):
+ return lit(value)
+ return Column(value)
+
+ def copy(self) -> Column:
+ return Column(self.expression.copy())
+
+ def set_table_name(self, table_name: str, copy=False) -> Column:
+ expression = self.expression.copy() if copy else self.expression
+ expression.set("table", exp.to_identifier(table_name))
+ return Column(expression)
+
+ def sql(self, **kwargs) -> Column:
+ return self.expression.sql(**{"dialect": "spark", **kwargs})
+
+ def alias(self, name: str) -> Column:
+ new_expression = exp.alias_(self.column_expression, name)
+ return Column(new_expression)
+
+ def asc(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
+ return Column(new_expression)
+
+ def desc(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
+ return Column(new_expression)
+
+ asc_nulls_first = asc
+
+ def asc_nulls_last(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
+ return Column(new_expression)
+
+ def desc_nulls_first(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
+ return Column(new_expression)
+
+ desc_nulls_last = desc
+
+ def when(self, condition: Column, value: t.Any) -> Column:
+ from sqlglot.dataframe.sql.functions import when
+
+ column_with_if = when(condition, value)
+ if not isinstance(self.expression, exp.Case):
+ return column_with_if
+ new_column = self.copy()
+ new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
+ return new_column
+
+ def otherwise(self, value: t.Any) -> Column:
+ from sqlglot.dataframe.sql.functions import lit
+
+ true_value = value if isinstance(value, Column) else lit(value)
+ new_column = self.copy()
+ new_column.expression.set("default", true_value.column_expression)
+ return new_column
+
+ def isNull(self) -> Column:
+ new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
+ return Column(new_expression)
+
+ def isNotNull(self) -> Column:
+ new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
+ return Column(new_expression)
+
+ def cast(self, dataType: t.Union[str, DataType]):
+ """
+ Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
+ Sqlglot doesn't currently replicate this class so it only accepts a string
+ """
+ if isinstance(dataType, DataType):
+ dataType = dataType.simpleString()
+ new_expression = exp.Cast(this=self.column_expression, to=dataType)
+ return Column(new_expression)
+
+ def startswith(self, value: t.Union[str, Column]) -> Column:
+ value = self._lit(value) if not isinstance(value, Column) else value
+ return self.invoke_anonymous_function(self, "STARTSWITH", value)
+
+ def endswith(self, value: t.Union[str, Column]) -> Column:
+ value = self._lit(value) if not isinstance(value, Column) else value
+ return self.invoke_anonymous_function(self, "ENDSWITH", value)
+
+ def rlike(self, regexp: str) -> Column:
+ return self.invoke_expression_over_column(
+ column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
+ )
+
+ def like(self, other: str):
+ return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+
+ def ilike(self, other: str):
+ return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+
+ def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
+ startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
+ length = self._lit(length) if not isinstance(length, Column) else length
+ return Column.invoke_expression_over_column(
+ self, exp.Substring, start=startPos.expression, length=length.expression
+ )
+
+ def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
+ columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
+ expressions = [self._lit(x).expression for x in columns]
+ return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore
+
+ def between(
+ self,
+ lowerBound: t.Union[ColumnOrLiteral],
+ upperBound: t.Union[ColumnOrLiteral],
+ ) -> Column:
+ lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+ upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+ return Column(
+ exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+ )
+
+ def over(self, window: WindowSpec) -> Column:
+ window_expression = window.expression.copy()
+ window_expression.set("this", self.column_expression)
+ return Column(window_expression)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
new file mode 100644
index 0000000..322dcf2
--- /dev/null
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -0,0 +1,730 @@
+from __future__ import annotations
+
+import functools
+import typing as t
+import zlib
+from copy import copy
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.group import GroupedData
+from sqlglot.dataframe.sql.normalize import normalize
+from sqlglot.dataframe.sql.operations import Operation, operation
+from sqlglot.dataframe.sql.readwriter import DataFrameWriter
+from sqlglot.dataframe.sql.transforms import replace_id_value
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.dataframe.sql.window import Window
+from sqlglot.helper import ensure_list, object_to_dict
+from sqlglot.optimizer import optimize as optimize_func
+from sqlglot.optimizer.qualify_columns import qualify_columns
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
+JOIN_HINTS = {
+ "BROADCAST",
+ "BROADCASTJOIN",
+ "MAPJOIN",
+ "MERGE",
+ "SHUFFLEMERGE",
+ "MERGEJOIN",
+ "SHUFFLE_HASH",
+ "SHUFFLE_REPLICATE_NL",
+}
+
+
+class DataFrame:
+ def __init__(
+ self,
+ spark: SparkSession,
+ expression: exp.Select,
+ branch_id: t.Optional[str] = None,
+ sequence_id: t.Optional[str] = None,
+ last_op: Operation = Operation.INIT,
+ pending_hints: t.Optional[t.List[exp.Expression]] = None,
+ output_expression_container: t.Optional[OutputExpressionContainer] = None,
+ **kwargs,
+ ):
+ self.spark = spark
+ self.expression = expression
+ self.branch_id = branch_id or self.spark._random_branch_id
+ self.sequence_id = sequence_id or self.spark._random_sequence_id
+ self.last_op = last_op
+ self.pending_hints = pending_hints or []
+ self.output_expression_container = output_expression_container or exp.Select()
+
+ def __getattr__(self, column_name: str) -> Column:
+ return self[column_name]
+
+ def __getitem__(self, column_name: str) -> Column:
+ column_name = f"{self.branch_id}.{column_name}"
+ return Column(column_name)
+
+ def __copy__(self):
+ return self.copy()
+
+ @property
+ def sparkSession(self):
+ return self.spark
+
+ @property
+ def write(self):
+ return DataFrameWriter(self)
+
+ @property
+ def latest_cte_name(self) -> str:
+ if not self.expression.ctes:
+ from_exp = self.expression.args["from"]
+ if from_exp.alias_or_name:
+ return from_exp.alias_or_name
+ table_alias = from_exp.find(exp.TableAlias)
+ if not table_alias:
+ raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+ return table_alias.alias_or_name
+ return self.expression.ctes[-1].alias
+
+ @property
+ def pending_join_hints(self):
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
+
+ @property
+ def pending_partition_hints(self):
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
+
+ @property
+ def columns(self) -> t.List[str]:
+ return self.expression.named_selects
+
+ @property
+ def na(self) -> DataFrameNaFunctions:
+ return DataFrameNaFunctions(self)
+
+ def _replace_cte_names_with_hashes(self, expression: exp.Select):
+ expression = expression.copy()
+ ctes = expression.ctes
+ replacement_mapping = {}
+ for cte in ctes:
+ old_name_id = cte.args["alias"].this
+ new_hashed_id = exp.to_identifier(
+ self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
+ )
+ replacement_mapping[old_name_id] = new_hashed_id
+ cte.set("alias", exp.TableAlias(this=new_hashed_id))
+ expression = expression.transform(replace_id_value, replacement_mapping)
+ return expression
+
+ def _create_cte_from_expression(
+ self,
+ expression: exp.Expression,
+ branch_id: t.Optional[str] = None,
+ sequence_id: t.Optional[str] = None,
+ **kwargs,
+ ) -> t.Tuple[exp.CTE, str]:
+ name = self.spark._random_name
+ expression_to_cte = expression.copy()
+ expression_to_cte.set("with", None)
+ cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
+ cte.set("branch_id", branch_id or self.branch_id)
+ cte.set("sequence_id", sequence_id or self.sequence_id)
+ return cte, name
+
+ def _ensure_list_of_columns(
+ self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
+ ) -> t.List[Column]:
+ columns = ensure_list(cols)
+ columns = Column.ensure_cols(columns)
+ return columns
+
+ def _ensure_and_normalize_cols(self, cols):
+ cols = self._ensure_list_of_columns(cols)
+ normalize(self.spark, self.expression, cols)
+ return cols
+
+ def _ensure_and_normalize_col(self, col):
+ col = Column.ensure_col(col)
+ normalize(self.spark, self.expression, col)
+ return col
+
+ def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
+ df = self._resolve_pending_hints()
+ sequence_id = sequence_id or df.sequence_id
+ expression = df.expression.copy()
+ cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
+ new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+ sel_columns = df._get_outer_select_columns(cte_expression)
+ new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+ return df.copy(expression=new_expression, sequence_id=sequence_id)
+
+ def _resolve_pending_hints(self) -> DataFrame:
+ df = self.copy()
+ if not self.pending_hints:
+ return df
+ expression = df.expression
+ hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
+ for hint in df.pending_partition_hints:
+ hint_expression.args.get("expressions").append(hint)
+ df.pending_hints.remove(hint)
+
+ join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+ if join_aliases:
+ for hint in df.pending_join_hints:
+ for sequence_id_expression in hint.expressions:
+ sequence_id_or_name = sequence_id_expression.alias_or_name
+ sequence_ids_to_match = [sequence_id_or_name]
+ if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
+ sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+ matching_ctes = [
+ cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+ ]
+ for matching_cte in matching_ctes:
+ if matching_cte.alias_or_name in join_aliases:
+ sequence_id_expression.set("this", matching_cte.args["alias"].this)
+ df.pending_hints.remove(hint)
+ break
+ hint_expression.args.get("expressions").append(hint)
+ if hint_expression.expressions:
+ expression.set("hint", hint_expression)
+ return df
+
+ def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
+ hint_name = hint_name.upper()
+ hint_expression = (
+ exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+ if hint_name in JOIN_HINTS
+ else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+ )
+ new_df = self.copy()
+ new_df.pending_hints.append(hint_expression)
+ return new_df
+
+ def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
+ other_df = other._convert_leaf_to_cte()
+ base_expression = self.expression.copy()
+ base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
+ all_ctes = base_expression.ctes
+ other_df.expression.set("with", None)
+ base_expression.set("with", None)
+ operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
+ operation.set("with", exp.With(expressions=all_ctes))
+ return self.copy(expression=operation)._convert_leaf_to_cte()
+
+ def _cache(self, storage_level: str):
+ df = self._convert_leaf_to_cte()
+ df.expression.ctes[-1].set("cache_storage_level", storage_level)
+ return df
+
@classmethod
def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
    """
    Return a copy of ``expression`` whose WITH clause also contains ``ctes``.

    When the copy already has a WITH clause, only CTEs whose ``alias_or_name`` is
    not already present are appended; otherwise ``ctes`` is used as-is.
    """
    result = expression.copy()
    current_with = result.args.get("with")
    if not current_with:
        merged_ctes = ctes
    else:
        merged_ctes = current_with.expressions
        # Names are snapshotted once, before anything is appended.
        known_names = {existing.alias_or_name for existing in merged_ctes}
        merged_ctes.extend([cte for cte in ctes if cte.alias_or_name not in known_names])
    result.set("with", exp.With(expressions=merged_ctes))
    return result
+
@classmethod
def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
    """Wrap each projection of the first ``exp.Select`` found in ``item`` in a ``Column``."""
    if isinstance(item, DataFrame):
        expression = item.expression
    else:
        expression = item
    outer_select = expression.find(exp.Select)
    return [Column(projection) for projection in outer_select.expressions]
+
@classmethod
def _create_hash_from_expression(cls, expression: exp.Select):
    """
    Derive a short name from the expression's Spark SQL text.

    The result is ``"t"`` followed by the CRC32 of the UTF-8 encoded SQL,
    truncated to six characters in total.
    """
    spark_sql = expression.sql(dialect="spark")
    checksum = zlib.crc32(spark_sql.encode("utf-8"))
    return f"t{checksum}"[:6]
+
def _get_select_expressions(
    self,
) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
    """
    Split this DataFrame's expression into the selects that must be emitted.

    Every CTE tagged with ``cache_storage_level`` becomes its own
    ``(exp.Cache, select)`` entry; the remaining CTEs stay attached to the main
    select, which is appended last and paired with the type of
    ``self.output_expression_container``.
    """
    select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
    main_select_ctes: t.List[exp.CTE] = []
    for cte in self.expression.ctes:
        cache_storage_level = cte.args.get("cache_storage_level")
        if cache_storage_level:
            select_expression = cte.this.copy()
            # Snapshot (shallow copy) of the non-cached CTEs seen so far; later
            # appends to main_select_ctes must not leak into this cached select.
            select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
            # Extra args consumed later by `sql` when building the cache statement.
            select_expression.set("cte_alias_name", cte.alias_or_name)
            select_expression.set("cache_storage_level", cache_storage_level)
            select_expressions.append((exp.Cache, select_expression))
        else:
            main_select_ctes.append(cte)
    main_select = self.expression.copy()
    if main_select_ctes:
        # Replace the WITH clause so it only holds the non-cached CTEs.
        main_select.set("with", exp.With(expressions=main_select_ctes))
    expression_select_pair = (type(self.output_expression_container), main_select)
    select_expressions.append(expression_select_pair)  # type: ignore
    return select_expressions
+
def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
    """
    Render this DataFrame as a list of SQL statements.

    Args:
        dialect: Dialect passed to each expression's ``sql`` call.
        optimize: When truthy, each select is passed through ``optimize_func``.
        **kwargs: Extra options forwarded to each expression's ``sql`` call.

    Returns:
        One SQL string per output expression. A cached CTE contributes a
        ``DROP VIEW IF EXISTS`` followed by its cache statement, ahead of the
        final statement.

    Raises:
        ValueError: If an entry's expression type is not Cache, Create, Insert or Select.
    """
    df = self._resolve_pending_hints()
    select_expressions = df._get_select_expressions()
    output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
    # Filled in as cache tables are created; passed to `replace_id_value` for every
    # later select (presumably to swap original CTE aliases for cache table names —
    # confirm against `replace_id_value`).
    replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
    for expression_type, select_expression in select_expressions:
        select_expression = select_expression.transform(replace_id_value, replacement_mapping)
        if optimize:
            select_expression = optimize_func(select_expression)
        select_expression = df._replace_cte_names_with_hashes(select_expression)
        expression: t.Union[exp.Select, exp.Cache, exp.Drop]
        if expression_type == exp.Cache:
            cache_table_name = df._create_hash_from_expression(select_expression)
            cache_table = exp.to_table(cache_table_name)
            original_alias_name = select_expression.args["cte_alias_name"]
            replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
            # Register the cache table's columns in the global schema.
            sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
            cache_storage_level = select_expression.args["cache_storage_level"]
            options = [
                exp.Literal.string("storageLevel"),
                exp.Literal.string(cache_storage_level),
            ]
            expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
            # We will drop the "view" if it exists before running the cache table
            output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
        elif expression_type == exp.Create:
            expression = df.output_expression_container.copy()
            expression.set("expression", select_expression)
        elif expression_type == exp.Insert:
            expression = df.output_expression_container.copy()
            # The CTEs are moved from the select onto the INSERT node itself.
            select_without_ctes = select_expression.copy()
            select_without_ctes.set("with", None)
            expression.set("expression", select_without_ctes)
            if select_expression.ctes:
                expression.set("with", exp.With(expressions=select_expression.ctes))
        elif expression_type == exp.Select:
            expression = select_expression
        else:
            raise ValueError(f"Invalid expression type: {expression_type}")
        output_expressions.append(expression)

    return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+
def copy(self, **kwargs) -> DataFrame:
    """
    Build a new DataFrame from ``object_to_dict(self, **kwargs)``.

    NOTE(review): presumably ``kwargs`` override the copied attributes; confirm
    against ``object_to_dict``.
    """
    return DataFrame(**object_to_dict(self, **kwargs))
+
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> DataFrame:
    """
    Project ``cols`` onto this DataFrame's expression.

    ``append`` (default ``False``) is forwarded to the underlying expression's
    ``select``. When the expression contains joins, columns without a table are
    qualified in place before the select is built.
    """
    cols = self._ensure_and_normalize_cols(cols)
    kwargs["append"] = kwargs.get("append", False)
    if self.expression.args.get("joins"):
        ambiguous_cols = [col for col in cols if not col.column_expression.table]
        if ambiguous_cols:
            join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
            cte_names_in_join = [x.this for x in join_table_identifiers]
            for ambiguous_col in ambiguous_cols:
                # CTEs that take part in the join and expose a column with this name.
                ctes_with_column = [
                    cte
                    for cte in self.expression.ctes
                    if cte.alias_or_name in cte_names_in_join
                    and ambiguous_col.alias_or_name in cte.this.named_selects
                ]
                # If the select column does not specify a table and there is a join
                # then we assume they are referring to the left table
                if len(ctes_with_column) > 1:
                    table_identifier = self.expression.args["from"].args["expressions"][0].this
                else:
                    # NOTE(review): raises IndexError when no joined CTE exposes the
                    # column (empty list) — confirm whether callers can hit this.
                    table_identifier = ctes_with_column[0].args["alias"].this
                ambiguous_col.expression.set("table", table_identifier)
    expression = self.expression.select(*[x.expression for x in cols], **kwargs)
    qualify_columns(expression, sqlglot.schema)
    return self.copy(expression=expression, **kwargs)
+
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> DataFrame:
    """
    Return a copy of this DataFrame registered under the alias ``name``.

    A new sequence id is taken from the session, pending join hints that refer to
    the current sequence id are re-pointed at it, the alias is recorded in the
    session mapping, and the leaf is converted to a CTE with the new id.
    """
    fresh_sequence_id = self.spark._random_sequence_id
    aliased_df = self.copy()
    hint_targets = (target for hint in aliased_df.pending_join_hints for target in hint.expressions)
    for target in hint_targets:
        if target.alias_or_name != self.sequence_id:
            continue
        target.set("this", Column.ensure_col(fresh_sequence_id).expression)
    aliased_df.spark._add_alias_to_mapping(name, fresh_sequence_id)
    return aliased_df._convert_leaf_to_cte(sequence_id=fresh_sequence_id)
+
@operation(Operation.WHERE)
def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
    """Return a copy whose expression has ``column`` added as a WHERE condition."""
    condition = self._ensure_and_normalize_col(column)
    filtered_expression = self.expression.where(condition.expression)
    return self.copy(expression=filtered_expression)


# `filter` is an alias of `where`.
filter = where
+
@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> GroupedData:
    """Return a ``GroupedData`` over this DataFrame, the normalized ``cols`` and ``self.last_op``."""
    grouping_columns = self._ensure_and_normalize_cols(cols)
    grouped = GroupedData(self, grouping_columns, self.last_op)
    return grouped
+
@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> DataFrame:
    """Aggregate without grouping columns: ``self.groupBy().agg(...)`` over the normalized ``exprs``."""
    aggregations = self._ensure_and_normalize_cols(exprs)
    ungrouped = self.groupBy()
    return ungrouped.agg(*aggregations)
+
@operation(Operation.FROM)
def join(
    self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
) -> DataFrame:
    """
    Join this DataFrame with ``other_df``.

    Args:
        other_df: Right side of the join; its leaf is converted to a CTE first.
        on: Column name(s) shared by both sides, or boolean join condition(s).
        how: Join type; underscores are replaced by spaces before it is passed on.

    Returns:
        A new DataFrame selecting the join columns, then this frame's columns,
        then the other frame's columns, de-duplicated by name.
    """
    other_df = other_df._convert_leaf_to_cte()
    pre_join_self_latest_cte_name = self.latest_cte_name
    columns = self._ensure_and_normalize_cols(on)
    join_type = how.replace("_", " ")
    if isinstance(columns[0].expression, exp.Column):
        # Plain column references: build an equality between the two sides for
        # every column and AND them together.
        join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
        join_clause = functools.reduce(
            lambda x, y: x & y,
            [
                col.copy().set_table_name(pre_join_self_latest_cte_name)
                == col.copy().set_table_name(other_df.latest_cte_name)
                for col in columns
            ],
        )
    else:
        # Arbitrary conditions: AND them into a single clause.
        if len(columns) > 1:
            columns = [functools.reduce(lambda x, y: x & y, columns)]
        join_clause = columns[0]
        # Columns at even positions in the clause get this frame's CTE name, odd
        # positions the other frame's — assumes conditions are written as
        # left/right pairs; TODO confirm.
        join_columns = [
            Column(x).set_table_name(pre_join_self_latest_cte_name)
            if i % 2 == 0
            else Column(x).set_table_name(other_df.latest_cte_name)
            for i, x in enumerate(join_clause.expression.find_all(exp.Column))
        ]
    self_columns = [
        column.set_table_name(pre_join_self_latest_cte_name, copy=True)
        for column in self._get_outer_select_columns(self)
    ]
    other_columns = [
        column.set_table_name(other_df.latest_cte_name, copy=True)
        for column in self._get_outer_select_columns(other_df)
    ]
    # Later entries overwrite earlier ones, so for a shared name the join column
    # wins over this frame's column, which wins over the other frame's column.
    column_value_mapping = {
        column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
        for column in other_columns + self_columns + join_columns
    }
    # The dict comprehension de-duplicates names while keeping first-seen order.
    all_columns = [
        column_value_mapping[name]
        for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
    ]
    new_df = self.copy(
        expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
    )
    new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
    new_df.pending_hints.extend(other_df.pending_hints)
    # Calls the undecorated `select` (bypassing the `operation` wrapper).
    new_df = new_df.select.__wrapped__(new_df, *all_columns)
    return new_df
+
@operation(Operation.ORDER_BY)
def orderBy(
    self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
) -> DataFrame:
    """
    Order the DataFrame by ``cols``.

    Columns that already carry an explicit ordering keep it, regardless of what
    is passed in `ascending`. Spark behaves irregularly when the two are mixed
    and can fail at runtime; users shouldn't mix them, so this is unlikely to
    come up.
    """
    columns = self._ensure_and_normalize_cols(cols)
    pre_ordered_positions = [
        position for position, column in enumerate(columns) if isinstance(column.expression, exp.Ordered)
    ]
    if ascending is None:
        ascending = [True] * len(columns)
    elif not isinstance(ascending, list):
        ascending = [ascending] * len(columns)
    ascending = [bool(flag) for flag in ascending]
    assert len(columns) == len(
        ascending
    ), "The length of items in ascending must equal the number of columns provided"
    order_by_columns = []
    for position, (column, is_ascending) in enumerate(zip(columns, ascending)):
        if position in pre_ordered_positions:
            order_by_columns.append(columns[position].column_expression)
        else:
            order_by_columns.append(exp.Ordered(this=column.expression, desc=not is_ascending))
    ordered_expression = self.expression.order_by(*order_by_columns)
    return self.copy(expression=ordered_expression)


# `sort` is an alias of `orderBy`.
sort = orderBy
+
@operation(Operation.FROM)
def union(self, other: DataFrame) -> DataFrame:
    """Combine with ``other`` through ``exp.Union`` with ``distinct=False`` (duplicates are kept)."""
    return self._set_operation(exp.Union, other, False)


# `unionAll` is an alias of `union`.
unionAll = union
+
@operation(Operation.FROM)
def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
    """
    Union with ``other``, matching columns by name rather than by position.

    Without ``allowMissingColumns`` the right side is projected with this
    frame's column list. With it, a column missing on either side is filled with
    an aliased ``NULL`` and both sides are re-projected.
    """
    left_names = self.columns
    right_names = other.columns
    if allowMissingColumns:
        left_projection = []
        right_projection = []
        unmatched_right = copy(right_names)
        for name in left_names:
            left_projection.append(name)
            if name in right_names:
                right_projection.append(name)
                unmatched_right.remove(name)
            else:
                right_projection.append(exp.alias_(exp.Null(), name))
        # Columns that only exist on the right go last, NULL-filled on the left.
        for name in unmatched_right:
            left_projection.append(exp.alias_(exp.Null(), name))
            right_projection.append(name)
    else:
        left_projection = left_names
        right_projection = left_names
    right_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(right_projection))
    left_df = self.copy()
    if allowMissingColumns:
        left_df = left_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(left_projection))
    return left_df._set_operation(exp.Union, right_df, False)
+
@operation(Operation.FROM)
def intersect(self, other: DataFrame) -> DataFrame:
    """Combine with ``other`` through ``exp.Intersect`` with ``distinct=True``."""
    return self._set_operation(exp.Intersect, other, True)
+
@operation(Operation.FROM)
def intersectAll(self, other: DataFrame) -> DataFrame:
    """Combine with ``other`` through ``exp.Intersect`` with ``distinct=False``."""
    return self._set_operation(exp.Intersect, other, False)
+
@operation(Operation.FROM)
def exceptAll(self, other: DataFrame) -> DataFrame:
    """Combine with ``other`` through ``exp.Except`` with ``distinct=False``."""
    return self._set_operation(exp.Except, other, False)
+
@operation(Operation.SELECT)
def distinct(self) -> DataFrame:
    """Return a copy whose expression has ``distinct()`` applied."""
    return self.copy(expression=self.expression.distinct())
+
@operation(Operation.SELECT)
def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
    """
    Drop duplicate rows.

    With no ``subset`` this is ``distinct()``. Otherwise rows are numbered within
    a window partitioned and ordered by the subset columns, only row number 1 is
    kept, and the helper ``row_num`` column is dropped again.
    """
    if not subset:
        return self.distinct()
    column_names = ensure_list(subset)
    window = Window.partitionBy(*column_names).orderBy(*column_names)
    numbered_df = self.copy().withColumn("row_num", F.row_number().over(window))
    first_rows_df = numbered_df.where(F.col("row_num") == F.lit(1))
    return first_rows_df.drop("row_num")
+
@operation(Operation.FROM)
def dropna(
    self,
    how: str = "any",
    thresh: t.Optional[int] = None,
    subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> DataFrame:
    """
    Drop rows based on how many of the checked columns are null.

    Args:
        how: Used only when ``thresh`` is None. ``"any"`` drops a row with at
            least one null; any other value drops a row only when every checked
            column is null.
        thresh: Minimum number of non-null values a row needs to be kept.
        subset: Columns to check; defaults to all outer select columns.

    Raises:
        RuntimeError: If the computed null limit exceeds the number of checked columns.
    """
    working_df = self.copy()
    all_columns = self._get_outer_select_columns(working_df.expression)
    checked_columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
    num_checked = len(checked_columns)
    if thresh is not None:
        # A row is dropped once it has this many nulls among the checked columns.
        null_limit = num_checked - (thresh or 0) + 1
    elif how == "any":
        null_limit = 1
    else:
        null_limit = num_checked
    if null_limit > num_checked:
        raise RuntimeError(
            f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
            f"Minimum num nulls: {null_limit}, Num Columns: {num_checked}"
        )
    null_flags = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in checked_columns]
    null_count = functools.reduce(lambda total, flag: total + flag, null_flags).alias("num_nulls")
    counted_df = working_df.select(null_count, append=True)
    kept_df = counted_df.where(F.col("num_nulls") < F.lit(null_limit))
    return kept_df.select(*all_columns)
+
@operation(Operation.FROM)
def fillna(
    self,
    value: t.Union[ColumnLiterals],
    subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> DataFrame:
    """
    Replace nulls with ``value``.

    Functionality difference: if the replacement value's type conflicts with the
    column's type, PySpark ignores the replacement. This implementation will try
    to cast them to the same type in some cases, so results won't always match.
    It is best not to mix types: make sure the replacement has the same type as
    the column.

    Possible improvement: use the `typeof` function to get the column's type and
    check that it matches the type of the value provided; if not, make it null.

    Args:
        value: A single replacement value, or a dict whose keys are column names
            and whose values are the per-column replacements.
        subset: Columns to fill when ``value`` is not a dict; defaults to all
            outer select columns.
    """
    from sqlglot.dataframe.sql.functions import lit

    values = None
    columns = None
    new_df = self.copy()
    all_columns = self._get_outer_select_columns(new_df.expression)
    all_column_mapping = {column.alias_or_name: column for column in all_columns}
    if isinstance(value, dict):
        # Dict form: keys name the columns, values are the replacements.
        values = value.values()
        columns = self._ensure_and_normalize_cols(list(value))
    if not columns:
        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
    if not values:
        # Scalar form: the same replacement is used for every target column.
        values = [value] * len(columns)
    value_columns = [lit(value) for value in values]

    null_replacement_mapping = {
        column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
        for column, value in zip(columns, value_columns)
    }
    # Untouched columns pass through; replaced ones override them by name.
    null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
    null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
    new_df = new_df.select(*null_replacement_columns)
    return new_df
+
@operation(Operation.FROM)
def replace(
    self,
    to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
    value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
    subset: t.Optional[t.Union[str, t.List[str]]] = None,
) -> DataFrame:
    """
    Replace values in the selected columns using a CASE WHEN chain per column.

    Args:
        to_replace: A dict mapping old values to new values, a list of old
            values (``value`` must then be a list of the same length), or a
            single old value.
        value: Replacement value(s); ignored when ``to_replace`` is a dict.
        subset: Columns to apply the replacement to; defaults to all outer
            select columns.

    Raises:
        AssertionError: If ``to_replace`` is a list and ``value`` is not a list
            of the same length.
    """
    from sqlglot.dataframe.sql.functions import lit

    old_values = None
    subset = ensure_list(subset)
    new_df = self.copy()
    all_columns = self._get_outer_select_columns(new_df.expression)
    all_column_mapping = {column.alias_or_name: column for column in all_columns}

    columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
    if isinstance(to_replace, dict):
        old_values = list(to_replace)
        new_values = list(to_replace.values())
    elif not old_values and isinstance(to_replace, list):
        assert isinstance(value, list), "value must be a list since the replacements are a list"
        assert len(to_replace) == len(value), "the replacements and values must be the same length"
        old_values = to_replace
        new_values = value
    else:
        # Scalar form: the single pair is repeated once per target column, so each
        # column's CASE below receives len(columns) identical WHEN branches.
        old_values = [to_replace] * len(columns)
        new_values = [value] * len(columns)
    old_values = [lit(value) for value in old_values]
    new_values = [lit(value) for value in new_values]

    replacement_mapping = {}
    for column in columns:
        expression = Column(None)
        for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
            if i == 0:
                # The first pair starts the CASE; later pairs chain onto it.
                expression = F.when(column == old_value, new_value)
            else:
                expression = expression.when(column == old_value, new_value)  # type: ignore
        replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
            column.expression.alias_or_name
        )

    # Untouched columns pass through; replaced ones override them by name.
    replacement_mapping = {**all_column_mapping, **replacement_mapping}
    replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
    new_df = new_df.select(*replacement_columns)
    return new_df
+
@operation(Operation.SELECT)
def withColumn(self, colName: str, col: Column) -> DataFrame:
    """
    Add a column named ``colName``, or replace the existing column of that name.

    Args:
        colName: Name of the column to add or replace.
        col: Expression providing the column's value.

    Returns:
        A new DataFrame. If ``colName`` is already one of the named selects, that
        projection is replaced in place (keeping its position); otherwise the
        column is appended.
    """
    col = self._ensure_and_normalize_col(col)
    existing_col_names = self.expression.named_selects
    # Membership test instead of truthiness of the index: index 0 (the first
    # column) is a valid match and must be replaced, not appended as a duplicate.
    if colName in existing_col_names:
        existing_col_index = existing_col_names.index(colName)
        expression = self.expression.copy()
        # Alias the replacement so the projection keeps the name `colName`.
        expression.expressions[existing_col_index] = col.alias(colName).expression
        return self.copy(expression=expression)
    return self.copy().select(col.alias(colName), append=True)
+
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
    """
    Rename every projection whose ``alias_or_name`` equals ``existing`` to ``new``.

    Raises:
        ValueError: If no projection is named ``existing``.
    """
    renamed_expression = self.expression.copy()
    matches = [
        projection for projection in renamed_expression.expressions if projection.alias_or_name == existing
    ]
    if not matches:
        raise ValueError("Tried to rename a column that doesn't exist")
    for projection in matches:
        if not isinstance(projection, exp.Column):
            # Anything that is not a bare column gets its alias overwritten.
            projection.set("alias", exp.to_identifier(new))
            continue
        # A bare column is swapped for an aliased copy of itself.
        projection.replace(exp.alias_(projection.copy(), new))
    return self.copy(expression=renamed_expression)
+
@operation(Operation.SELECT)
def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
    """Return a copy that selects every outer column whose name is not among ``cols``."""
    all_columns = self._get_outer_select_columns(self.expression)
    dropped_names = [dropped.alias_or_name for dropped in self._ensure_and_normalize_cols(cols)]
    remaining_columns = []
    for column in all_columns:
        if column.alias_or_name in dropped_names:
            continue
        remaining_columns.append(column)
    return self.copy().select(*remaining_columns, append=False)
+
@operation(Operation.LIMIT)
def limit(self, num: int) -> DataFrame:
    """Return a copy whose expression has ``limit(num)`` applied."""
    return self.copy(expression=self.expression.limit(num))
+
@operation(Operation.NO_OP)
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
    """
    Queue the hint ``name`` on a copy of this DataFrame.

    When no parameters are given, the hint's only argument is this DataFrame's
    ``sequence_id``.
    """
    parameter_list = ensure_list(parameters)
    if parameters:
        hint_arguments = self._ensure_list_of_columns(parameter_list)
    else:
        hint_arguments = Column.ensure_cols([self.sequence_id])
    return self._hint(name, hint_arguments)
+
@operation(Operation.NO_OP)
def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
    """Queue a ``repartition`` hint whose arguments are ``numPartitions`` followed by ``cols``."""
    partition_count = Column.ensure_cols(ensure_list(numPartitions))
    partition_columns = self._ensure_and_normalize_cols(cols)
    return self._hint("repartition", partition_count + partition_columns)
+
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> DataFrame:
    """Queue a ``coalesce`` hint whose single argument is ``numPartitions``."""
    hint_arguments = Column.ensure_cols([numPartitions])
    return self._hint("coalesce", hint_arguments)
+
@operation(Operation.NO_OP)
def cache(self) -> DataFrame:
    """Mark this DataFrame for caching with storage level ``MEMORY_AND_DISK``."""
    return self._cache(storage_level="MEMORY_AND_DISK")
+
@operation(Operation.NO_OP)
def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
    """
    Mark this DataFrame for caching with the given storage level.

    Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
    """
    return self._cache(storageLevel)
+
+
class DataFrameNaFunctions:
    """Null-handling helpers that delegate to the wrapped DataFrame's own methods."""

    def __init__(self, df: DataFrame):
        # DataFrame every helper below delegates to.
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """Delegate to ``DataFrame.dropna`` with the same arguments."""
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """Delegate to ``DataFrame.fillna`` with the same arguments."""
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        """Delegate to ``DataFrame.replace`` with the same arguments."""
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
new file mode 100644
index 0000000..4c6de30
--- /dev/null
+++ b/sqlglot/dataframe/sql/functions.py
@@ -0,0 +1,1258 @@
+from __future__ import annotations
+
+import typing as t
+from inspect import signature
+
+from sqlglot import expressions as glotexp
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.helper import ensure_list
+from sqlglot.helper import flatten as _flatten
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+
def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
    """Wrap ``column_name`` in a ``Column``."""
    return Column(column_name)
+
+
def lit(value: t.Optional[t.Any] = None) -> Column:
    """Wrap ``value`` in a ``Column``; strings are first turned into string literals."""
    if not isinstance(value, str):
        return Column(value)
    string_literal = glotexp.Literal.string(str(value))
    return Column(string_literal)
+
+
def greatest(*cols: ColumnOrName) -> Column:
    """Build a ``glotexp.Greatest`` over the first column, passing the rest as ``expressions``."""
    columns = [Column.ensure_col(candidate) for candidate in cols]
    if len(columns) > 1:
        remaining = [column.expression for column in columns[1:]]
    else:
        remaining = None
    return Column.invoke_expression_over_column(columns[0], glotexp.Greatest, expressions=remaining)
+
+
def least(*cols: ColumnOrName) -> Column:
    """Build a ``glotexp.Least`` over the first column, passing the rest as ``expressions``."""
    columns = [Column.ensure_col(candidate) for candidate in cols]
    if len(columns) > 1:
        remaining = [column.expression for column in columns[1:]]
    else:
        remaining = None
    return Column.invoke_expression_over_column(columns[0], glotexp.Least, expressions=remaining)
+
+
def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
    """Build ``Count(Distinct(...))`` over ``col`` and any additional ``cols``."""
    columns = [Column.ensure_col(candidate) for candidate in (col, *cols)]
    distinct_expression = glotexp.Distinct(expressions=[column.expression for column in columns])
    return Column(glotexp.Count(this=distinct_expression))
+
+
def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
    """camelCase alias of ``count_distinct``."""
    return count_distinct(col, *cols)
+
+
def when(condition: Column, value: t.Any) -> Column:
    """Start a ``CASE`` expression with one branch: when ``condition`` then ``value``."""
    if isinstance(value, Column):
        true_value = value
    else:
        # Non-Column values are converted with `lit`.
        true_value = lit(value)
    branch = glotexp.If(this=condition.column_expression, true=true_value.column_expression)
    return Column(glotexp.Case(ifs=[branch]))
+
+
def asc(col: ColumnOrName) -> Column:
    """Return ``Column.ensure_col(col).asc()``."""
    return Column.ensure_col(col).asc()
+
+
def desc(col: ColumnOrName) -> Column:
    """Return ``Column.ensure_col(col).desc()``."""
    return Column.ensure_col(col).desc()
+
+
def broadcast(df: DataFrame) -> DataFrame:
    """Return ``df`` with a ``broadcast`` hint applied."""
    return df.hint("broadcast")
+
+
def sqrt(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Sqrt`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Sqrt)
+
+
def abs(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Abs`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Abs)
+
+
def max(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Max`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Max)
+
+
def min(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Min`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Min)
+
+
def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MAX_BY`` on ``col`` with ``ord`` as extra argument."""
    return Column.invoke_anonymous_function(col, "MAX_BY", ord)
+
+
def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MIN_BY`` on ``col`` with ``ord`` as extra argument."""
    return Column.invoke_anonymous_function(col, "MIN_BY", ord)
+
+
def count(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Count`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Count)
+
+
def sum(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Sum`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Sum)
+
+
def avg(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Avg`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Avg)
+
+
def mean(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MEAN`` on ``col``."""
    return Column.invoke_anonymous_function(col, "MEAN")
+
+
def sumDistinct(col: ColumnOrName) -> Column:
    """camelCase alias of ``sum_distinct``."""
    return sum_distinct(col)
+
+
def sum_distinct(col: ColumnOrName) -> Column:
    """Not implemented; always raises ``NotImplementedError``."""
    raise NotImplementedError("Sum distinct is not currently implemented")
+
+
def product(col: ColumnOrName) -> Column:
    """Not implemented; always raises ``NotImplementedError``."""
    raise NotImplementedError("Product is not currently implemented")
+
+
def acos(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ACOS`` on ``col``."""
    return Column.invoke_anonymous_function(col, "ACOS")
+
+
def acosh(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ACOSH`` on ``col``."""
    return Column.invoke_anonymous_function(col, "ACOSH")
+
+
def asin(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ASIN`` on ``col``."""
    return Column.invoke_anonymous_function(col, "ASIN")
+
+
def asinh(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ASINH`` on ``col``."""
    return Column.invoke_anonymous_function(col, "ASINH")
+
+
def atan(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ATAN`` on ``col``."""
    return Column.invoke_anonymous_function(col, "ATAN")
+
+
def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
    """Invoke the anonymous function ``ATAN2`` on ``col1`` with ``col2`` as extra argument."""
    return Column.invoke_anonymous_function(col1, "ATAN2", col2)
+
+
def atanh(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ATANH`` on ``col``."""
    return Column.invoke_anonymous_function(col, "ATANH")
+
+
def cbrt(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``CBRT`` on ``col``."""
    return Column.invoke_anonymous_function(col, "CBRT")
+
+
def ceil(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Ceil`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Ceil)
+
+
def cos(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``COS`` on ``col``."""
    return Column.invoke_anonymous_function(col, "COS")
+
+
def cosh(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``COSH`` on ``col``."""
    return Column.invoke_anonymous_function(col, "COSH")
+
+
def cot(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``COT`` on ``col``."""
    return Column.invoke_anonymous_function(col, "COT")
+
+
def csc(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``CSC`` on ``col``."""
    return Column.invoke_anonymous_function(col, "CSC")
+
+
def exp(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Exp`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Exp)
+
+
def expm1(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``EXPM1`` on ``col``."""
    return Column.invoke_anonymous_function(col, "EXPM1")
+
+
def floor(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Floor`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Floor)
+
+
def log10(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Log10`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Log10)
+
+
def log1p(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``LOG1P`` on ``col``."""
    return Column.invoke_anonymous_function(col, "LOG1P")
+
+
def log2(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Log2`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Log2)
+
+
def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
    """
    Build a logarithm expression.

    With one argument, ``glotexp.Ln`` is applied to ``arg1``. With two,
    ``glotexp.Log`` is applied to ``arg1`` with ``arg2`` as its ``expression``.
    """
    if arg2 is not None:
        second_operand = Column.ensure_col(arg2).expression
        return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=second_operand)
    return Column.invoke_expression_over_column(arg1, glotexp.Ln)
+
+
def rint(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``RINT`` on ``col``."""
    return Column.invoke_anonymous_function(col, "RINT")
+
+
def sec(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``SEC`` on ``col``."""
    return Column.invoke_anonymous_function(col, "SEC")
+
+
def signum(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``SIGNUM`` on ``col``."""
    return Column.invoke_anonymous_function(col, "SIGNUM")
+
+
def sin(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``SIN`` on ``col``."""
    return Column.invoke_anonymous_function(col, "SIN")
+
+
def sinh(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``SINH`` on ``col``."""
    return Column.invoke_anonymous_function(col, "SINH")
+
+
def tan(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``TAN`` on ``col``."""
    return Column.invoke_anonymous_function(col, "TAN")
+
+
def tanh(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``TANH`` on ``col``."""
    return Column.invoke_anonymous_function(col, "TANH")
+
+
def toDegrees(col: ColumnOrName) -> Column:
    """camelCase alias of ``degrees``."""
    return degrees(col)
+
+
def degrees(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``DEGREES`` on ``col``."""
    return Column.invoke_anonymous_function(col, "DEGREES")
+
+
def toRadians(col: ColumnOrName) -> Column:
    """camelCase alias of ``radians``."""
    return radians(col)
+
+
def radians(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``RADIANS`` on ``col``."""
    return Column.invoke_anonymous_function(col, "RADIANS")
+
+
def bitwiseNOT(col: ColumnOrName) -> Column:
    """camelCase alias of ``bitwise_not``."""
    return bitwise_not(col)
+
+
def bitwise_not(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.BitwiseNot`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.BitwiseNot)
+
+
def asc_nulls_first(col: ColumnOrName) -> Column:
    """Return ``Column.ensure_col(col).asc_nulls_first()``."""
    return Column.ensure_col(col).asc_nulls_first()
+
+
def asc_nulls_last(col: ColumnOrName) -> Column:
    """Return ``Column.ensure_col(col).asc_nulls_last()``."""
    return Column.ensure_col(col).asc_nulls_last()
+
+
def desc_nulls_first(col: ColumnOrName) -> Column:
    """Return ``Column.ensure_col(col).desc_nulls_first()``."""
    return Column.ensure_col(col).desc_nulls_first()
+
+
def desc_nulls_last(col: ColumnOrName) -> Column:
    """Return ``Column.ensure_col(col).desc_nulls_last()``."""
    return Column.ensure_col(col).desc_nulls_last()
+
+
def stddev(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Stddev`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Stddev)
+
+
def stddev_samp(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.StddevSamp`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.StddevSamp)
+
+
def stddev_pop(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.StddevPop`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.StddevPop)
+
+
def variance(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Variance`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.Variance)
+
+
def var_samp(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.Variance`` expression to ``col`` (the same expression ``variance`` uses)."""
    return Column.invoke_expression_over_column(col, glotexp.Variance)
+
+
def var_pop(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.VariancePop`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.VariancePop)
+
+
def skewness(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``SKEWNESS`` on ``col``."""
    return Column.invoke_anonymous_function(col, "SKEWNESS")
+
+
def kurtosis(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``KURTOSIS`` on ``col``."""
    return Column.invoke_anonymous_function(col, "KURTOSIS")
+
+
def collect_list(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.ArrayAgg`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.ArrayAgg)
+
+
def collect_set(col: ColumnOrName) -> Column:
    """Apply the ``glotexp.SetAgg`` expression to ``col``."""
    return Column.invoke_expression_over_column(col, glotexp.SetAgg)
+
+
def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
    """Invoke the anonymous function ``HYPOT`` on ``col1`` with ``col2`` as extra argument."""
    return Column.invoke_anonymous_function(col1, "HYPOT", col2)
+
+
def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
    """Invoke the anonymous function ``POW`` on ``col1`` with ``col2`` as extra argument."""
    return Column.invoke_anonymous_function(col1, "POW", col2)
+
+
def row_number() -> Column:
    """Return a Column wrapping an argument-less anonymous ``ROW_NUMBER`` call."""
    return Column(glotexp.Anonymous(this="ROW_NUMBER"))
+
+
def dense_rank() -> Column:
    """Return a Column wrapping an argument-less anonymous ``DENSE_RANK`` call."""
    return Column(glotexp.Anonymous(this="DENSE_RANK"))
+
+
def rank() -> Column:
    """Return a Column wrapping an argument-less anonymous ``RANK`` call."""
    return Column(glotexp.Anonymous(this="RANK"))
+
+
def cume_dist() -> Column:
    """Return a Column wrapping an argument-less anonymous ``CUME_DIST`` call."""
    return Column(glotexp.Anonymous(this="CUME_DIST"))
+
+
def percent_rank() -> Column:
    """Return a Column wrapping an argument-less anonymous ``PERCENT_RANK`` call."""
    return Column(glotexp.Anonymous(this="PERCENT_RANK"))
+
+
def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
    """camelCase alias of ``approx_count_distinct``."""
    return approx_count_distinct(col, rsd)
+
+
def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
    """Apply ``glotexp.ApproxDistinct`` to ``col``; ``rsd``, when given, is passed as ``accuracy``."""
    if rsd is not None:
        accuracy_expression = Column.ensure_col(rsd).expression
        return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=accuracy_expression)
    return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct)
+
+
def coalesce(*cols: ColumnOrName) -> Column:
    """Build a ``glotexp.Coalesce`` over the first column, passing the rest as ``expressions``."""
    columns = [Column.ensure_col(candidate) for candidate in cols]
    if len(columns) > 1:
        remaining = [column.expression for column in columns[1:]]
    else:
        remaining = None
    return Column.invoke_expression_over_column(columns[0], glotexp.Coalesce, expressions=remaining)
+
+
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Invoke the anonymous function ``CORR`` on ``col1`` with ``col2`` as extra argument."""
    return Column.invoke_anonymous_function(col1, "CORR", col2)
+
+
def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Invoke the anonymous function ``COVAR_POP`` on ``col1`` with ``col2`` as extra argument."""
    return Column.invoke_anonymous_function(col1, "COVAR_POP", col2)
+
+
def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Invoke the anonymous function ``COVAR_SAMP`` on ``col1`` with ``col2`` as extra argument."""
    return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2)
+
+
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
    """Invoke the anonymous function ``FIRST`` on ``col``, adding ``ignorenulls`` only when it is not None."""
    if ignorenulls is None:
        return Column.invoke_anonymous_function(col, "FIRST")
    return Column.invoke_anonymous_function(col, "FIRST", ignorenulls)
+
+
def grouping_id(*cols: ColumnOrName) -> Column:
    """Invoke the anonymous function ``GROUPING_ID`` over ``cols`` (or with no column when empty)."""
    if not cols:
        return Column.invoke_anonymous_function(None, "GROUPING_ID")
    # The first column is the subject of the call; any others are extra arguments.
    leading, *others = cols
    return Column.invoke_anonymous_function(leading, "GROUPING_ID", *others)
+
+
def input_file_name() -> Column:
    """Invoke the anonymous function ``INPUT_FILE_NAME`` with no column (``None``)."""
    return Column.invoke_anonymous_function(None, "INPUT_FILE_NAME")
+
+
def isnan(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ISNAN`` on ``col``."""
    return Column.invoke_anonymous_function(col, "ISNAN")
+
+
def isnull(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ISNULL`` on ``col``."""
    return Column.invoke_anonymous_function(col, "ISNULL")
+
+
def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
    """Invoke ``LAST`` over ``col``; ``ignorenulls`` is appended only when not ``None``."""
    trailing = () if ignorenulls is None else (ignorenulls,)
    return Column.invoke_anonymous_function(col, "LAST", *trailing)
+
+
def monotonically_increasing_id() -> Column:
    """Invoke the argument-less anonymous function ``MONOTONICALLY_INCREASING_ID``."""
    function_name = "MONOTONICALLY_INCREASING_ID"
    return Column.invoke_anonymous_function(None, function_name)
+
+
def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Invoke the anonymous function ``NANVL`` over ``col1`` with ``col2`` as argument."""
    function_name = "NANVL"
    return Column.invoke_anonymous_function(col1, function_name, col2)
+
+
def percentile_approx(
    col: ColumnOrName,
    percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
    accuracy: t.Optional[t.Union[ColumnOrLiteral]] = None,
) -> Column:
    """Invoke ``PERCENTILE_APPROX`` over ``col`` with ``percentage`` and optional ``accuracy``.

    ``accuracy`` is forwarded whenever it is not ``None``. The check is an
    explicit ``is not None`` (matching the other optional-argument helpers in
    this module) rather than a truthiness test, so a falsy value the caller
    supplied explicitly is passed through instead of being silently dropped.
    """
    if accuracy is not None:
        return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage, accuracy)
    return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage)
+
+
def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
    """Invoke the anonymous function ``RAND``, passing ``seed`` in the column position."""
    function_name = "RAND"
    return Column.invoke_anonymous_function(seed, function_name)
+
+
def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
    """Invoke the anonymous function ``RANDN``, passing ``seed`` in the column position."""
    function_name = "RANDN"
    return Column.invoke_anonymous_function(seed, function_name)
+
+
def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
    """Wrap ``col`` in ``glotexp.Round``; ``scale`` (if given) becomes the ``decimals`` argument."""
    if scale is None:
        return Column.invoke_expression_over_column(col, glotexp.Round)
    decimals = glotexp.convert(scale)
    return Column.invoke_expression_over_column(col, glotexp.Round, decimals=decimals)
+
+
def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
    """Invoke ``BROUND`` over ``col``; ``scale`` is appended only when not ``None``."""
    trailing = () if scale is None else (scale,)
    return Column.invoke_anonymous_function(col, "BROUND", *trailing)
+
+
def shiftleft(col: ColumnOrName, numBits: int) -> Column:
    """Wrap ``col`` in ``glotexp.BitwiseLeftShift`` with ``numBits`` as the shift expression."""
    shift_amount = Column.ensure_col(numBits).expression
    return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=shift_amount)
+
+
def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
    """CamelCase alias that forwards to ``shiftleft``."""
    result = shiftleft(col, numBits)
    return result
+
+
def shiftright(col: ColumnOrName, numBits: int) -> Column:
    """Wrap ``col`` in ``glotexp.BitwiseRightShift`` with ``numBits`` as the shift expression."""
    shift_amount = Column.ensure_col(numBits).expression
    return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=shift_amount)
+
+
def shiftRight(col: ColumnOrName, numBits: int) -> Column:
    """CamelCase alias that forwards to ``shiftright``."""
    result = shiftright(col, numBits)
    return result
+
+
def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column:
    """Invoke the anonymous function ``SHIFTRIGHTUNSIGNED`` over ``col`` with ``numBits``."""
    function_name = "SHIFTRIGHTUNSIGNED"
    return Column.invoke_anonymous_function(col, function_name, numBits)
+
+
def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column:
    """CamelCase alias that forwards to ``shiftrightunsigned``."""
    result = shiftrightunsigned(col, numBits)
    return result
+
+
def expr(str: str) -> Column:
    """Construct a ``Column`` directly from the given string."""
    column = Column(str)
    return column
+
+
def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
    """Build a ``glotexp.Struct`` from ``col`` (a single column or an iterable) plus ``cols``."""
    members = ensure_list(col) + list(cols)
    member_expressions = []
    for member in members:
        member_expressions.append(Column.ensure_col(member).expression)
    return Column(glotexp.Struct(expressions=member_expressions))
+
+
def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
    """Invoke the anonymous function ``CONV`` over ``col`` with ``fromBase`` and ``toBase``."""
    bases = (fromBase, toBase)
    return Column.invoke_anonymous_function(col, "CONV", *bases)
+
+
def factorial(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``FACTORIAL`` over ``col``."""
    function_name = "FACTORIAL"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
    """Invoke ``LAG`` over ``col``.

    ``offset`` and ``default`` are both passed when ``default`` is set; only
    ``offset`` is passed when it differs from 1; otherwise no extra arguments.
    """
    if default is not None:
        trailing = (offset, default)
    elif offset != 1:
        trailing = (offset,)
    else:
        trailing = ()
    return Column.invoke_anonymous_function(col, "LAG", *trailing)
+
+
def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
    """Invoke ``LEAD`` over ``col``.

    ``offset`` and ``default`` are both passed when ``default`` is set; only
    ``offset`` is passed when it differs from 1; otherwise no extra arguments.
    """
    if default is not None:
        trailing = (offset, default)
    elif offset != 1:
        trailing = (offset,)
    else:
        trailing = ()
    return Column.invoke_anonymous_function(col, "LEAD", *trailing)
+
+
def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
    """Invoke ``NTH_VALUE`` over ``col``; ``offset`` is passed only when it differs from 1.

    Raises:
        NotImplementedError: if ``ignoreNulls`` is anything other than ``None``.
    """
    if ignoreNulls is not None:
        raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
    trailing = () if offset == 1 else (offset,)
    return Column.invoke_anonymous_function(col, "NTH_VALUE", *trailing)
+
+
def ntile(n: int) -> Column:
    """Invoke the anonymous function ``NTILE`` with ``n`` as its only argument."""
    function_name = "NTILE"
    return Column.invoke_anonymous_function(None, function_name, n)
+
+
def current_date() -> Column:
    """Build a column from ``glotexp.CurrentDate`` with no input column."""
    expression_type = glotexp.CurrentDate
    return Column.invoke_expression_over_column(None, expression_type)
+
+
def current_timestamp() -> Column:
    """Build a column from ``glotexp.CurrentTimestamp`` with no input column."""
    expression_type = glotexp.CurrentTimestamp
    return Column.invoke_expression_over_column(None, expression_type)
+
+
def date_format(col: ColumnOrName, format: str) -> Column:
    """Invoke ``DATE_FORMAT`` over ``col`` with ``format`` wrapped as a literal."""
    format_literal = lit(format)
    return Column.invoke_anonymous_function(col, "DATE_FORMAT", format_literal)
+
+
def year(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.Year`` expression."""
    expression_type = glotexp.Year
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def quarter(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``QUARTER`` over ``col``."""
    function_name = "QUARTER"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def month(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.Month`` expression."""
    expression_type = glotexp.Month
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def dayofweek(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``DAYOFWEEK`` over ``col``."""
    function_name = "DAYOFWEEK"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def dayofmonth(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``DAYOFMONTH`` over ``col``."""
    function_name = "DAYOFMONTH"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def dayofyear(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``DAYOFYEAR`` over ``col``."""
    function_name = "DAYOFYEAR"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def hour(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``HOUR`` over ``col``."""
    function_name = "HOUR"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def minute(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MINUTE`` over ``col``."""
    function_name = "MINUTE"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def second(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``SECOND`` over ``col``."""
    function_name = "SECOND"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def weekofyear(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``WEEKOFYEAR`` over ``col``."""
    function_name = "WEEKOFYEAR"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MAKE_DATE`` over ``year`` with ``month`` and ``day``."""
    remaining_parts = (month, day)
    return Column.invoke_anonymous_function(year, "MAKE_DATE", *remaining_parts)
+
+
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
    """Wrap ``col`` in ``glotexp.DateAdd`` with ``days`` as the ``expression`` argument."""
    days_expression = Column.ensure_col(days).expression
    return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days_expression)
+
+
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
    """Wrap ``col`` in ``glotexp.DateSub`` with ``days`` as the ``expression`` argument."""
    days_expression = Column.ensure_col(days).expression
    return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days_expression)
+
+
def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
    """Wrap ``end`` in ``glotexp.DateDiff`` with ``start`` as the ``expression`` argument."""
    start_expression = Column.ensure_col(start).expression
    return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start_expression)
+
+
def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
    """Invoke the anonymous function ``ADD_MONTHS`` over ``start`` with ``months``."""
    function_name = "ADD_MONTHS"
    return Column.invoke_anonymous_function(start, function_name, months)
+
+
def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
    """Invoke ``MONTHS_BETWEEN`` over ``date1`` and ``date2``; ``roundOff`` is appended when set."""
    trailing = [date2]
    if roundOff is not None:
        trailing.append(roundOff)
    return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", *trailing)
+
+
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
    """Invoke ``TO_DATE`` over ``col``; ``format`` is passed as a literal when provided."""
    trailing = () if format is None else (lit(format),)
    return Column.invoke_anonymous_function(col, "TO_DATE", *trailing)
+
+
def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
    """Invoke ``TO_TIMESTAMP`` over ``col``; ``format`` is passed as a literal when provided."""
    trailing = () if format is None else (lit(format),)
    return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", *trailing)
+
+
def trunc(col: ColumnOrName, format: str) -> Column:
    """Wrap ``col`` in ``glotexp.DateTrunc`` using ``format`` (as a literal) for ``unit``."""
    unit_expression = lit(format).expression
    return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=unit_expression)
+
+
def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
    """Wrap ``timestamp`` in ``glotexp.TimestampTrunc`` using ``format`` (as a literal) for ``unit``."""
    unit_expression = lit(format).expression
    return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=unit_expression)
+
+
def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
    """Invoke ``NEXT_DAY`` over ``col`` with ``dayOfWeek`` wrapped as a literal."""
    day_literal = lit(dayOfWeek)
    return Column.invoke_anonymous_function(col, "NEXT_DAY", day_literal)
+
+
def last_day(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``LAST_DAY`` over ``col``."""
    function_name = "LAST_DAY"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
    """Invoke ``FROM_UNIXTIME`` over ``col``; ``format`` is passed as a literal when provided."""
    trailing = () if format is None else (lit(format),)
    return Column.invoke_anonymous_function(col, "FROM_UNIXTIME", *trailing)
+
+
def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
    """Invoke ``UNIX_TIMESTAMP`` over ``timestamp``; ``format`` is passed as a literal when provided."""
    trailing = () if format is None else (lit(format),)
    return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP", *trailing)
+
+
def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
    """Invoke ``FROM_UTC_TIMESTAMP`` over ``timestamp``; a non-``Column`` ``tz`` becomes a literal."""
    if isinstance(tz, Column):
        zone = tz
    else:
        zone = lit(tz)
    return Column.invoke_anonymous_function(timestamp, "FROM_UTC_TIMESTAMP", zone)
+
+
def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
    """Invoke ``TO_UTC_TIMESTAMP`` over ``timestamp``; a non-``Column`` ``tz`` becomes a literal."""
    if isinstance(tz, Column):
        zone = tz
    else:
        zone = lit(tz)
    return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", zone)
+
+
def timestamp_seconds(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``TIMESTAMP_SECONDS`` over ``col``."""
    function_name = "TIMESTAMP_SECONDS"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def window(
    timeColumn: ColumnOrName,
    windowDuration: str,
    slideDuration: t.Optional[str] = None,
    startTime: t.Optional[str] = None,
) -> Column:
    """Invoke ``WINDOW`` over ``timeColumn`` with literal duration arguments.

    ``windowDuration`` is always passed. ``slideDuration`` and ``startTime``
    follow when given; if only ``startTime`` is given, ``windowDuration`` is
    repeated in the slide position so ``startTime`` lands third.
    """
    literals = [lit(windowDuration)]
    if slideDuration is not None:
        literals.append(lit(slideDuration))
        if startTime is not None:
            literals.append(lit(startTime))
    elif startTime is not None:
        # No slide given: reuse the window duration as the slide argument.
        literals.append(lit(windowDuration))
        literals.append(lit(startTime))
    return Column.invoke_anonymous_function(timeColumn, "WINDOW", *literals)
+
+
def session_window(timeColumn: ColumnOrName, gapDuration: ColumnOrName) -> Column:
    """Invoke ``SESSION_WINDOW`` over ``timeColumn``; a non-``Column`` gap becomes a literal."""
    if isinstance(gapDuration, Column):
        gap = gapDuration
    else:
        gap = lit(gapDuration)
    return Column.invoke_anonymous_function(timeColumn, "SESSION_WINDOW", gap)
+
+
def crc32(col: ColumnOrName) -> Column:
    """Invoke ``CRC32``; a non-``Column`` input is wrapped as a literal first."""
    if isinstance(col, Column):
        target = col
    else:
        target = lit(col)
    return Column.invoke_anonymous_function(target, "CRC32")
+
+
def md5(col: ColumnOrName) -> Column:
    """Invoke ``MD5``; a non-``Column`` input is wrapped as a literal first."""
    if isinstance(col, Column):
        target = col
    else:
        target = lit(col)
    return Column.invoke_anonymous_function(target, "MD5")
+
+
def sha1(col: ColumnOrName) -> Column:
    """Invoke ``SHA1``; a non-``Column`` input is wrapped as a literal first."""
    if isinstance(col, Column):
        target = col
    else:
        target = lit(col)
    return Column.invoke_anonymous_function(target, "SHA1")
+
+
def sha2(col: ColumnOrName, numBits: int) -> Column:
    """Invoke ``SHA2`` with a literal ``numBits``; a non-``Column`` input becomes a literal."""
    if isinstance(col, Column):
        target = col
    else:
        target = lit(col)
    bits_literal = lit(numBits)
    return Column.invoke_anonymous_function(target, "SHA2", bits_literal)
+
+
def hash(*cols: ColumnOrName) -> Column:
    """Invoke ``HASH`` over the first column, passing any remaining columns as arguments."""
    remaining = cols[1:]
    return Column.invoke_anonymous_function(cols[0], "HASH", *remaining)
+
+
def xxhash64(*cols: ColumnOrName) -> Column:
    """Invoke ``XXHASH64`` over the first column, passing any remaining columns as arguments."""
    remaining = cols[1:]
    return Column.invoke_anonymous_function(cols[0], "XXHASH64", *remaining)
+
+
def assert_true(col: ColumnOrName, errorMsg: t.Optional[ColumnOrName] = None) -> Column:
    """Invoke ``ASSERT_TRUE`` over ``col``.

    When ``errorMsg`` is given it is passed as a second argument, wrapped as a
    literal unless it is already a ``Column``.
    """
    if errorMsg is None:
        return Column.invoke_anonymous_function(col, "ASSERT_TRUE")
    message = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg)
    return Column.invoke_anonymous_function(col, "ASSERT_TRUE", message)
+
+
def raise_error(errorMsg: ColumnOrName) -> Column:
    """Invoke ``RAISE_ERROR``; a non-``Column`` message is wrapped as a literal first."""
    if isinstance(errorMsg, Column):
        message = errorMsg
    else:
        message = lit(errorMsg)
    return Column.invoke_anonymous_function(message, "RAISE_ERROR")
+
+
def upper(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.Upper`` expression."""
    expression_type = glotexp.Upper
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def lower(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.Lower`` expression."""
    expression_type = glotexp.Lower
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def ascii(col: ColumnOrLiteral) -> Column:
    """Invoke the anonymous function ``ASCII`` over ``col``."""
    function_name = "ASCII"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def base64(col: ColumnOrLiteral) -> Column:
    """Invoke the anonymous function ``BASE64`` over ``col``."""
    function_name = "BASE64"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def unbase64(col: ColumnOrLiteral) -> Column:
    """Invoke the anonymous function ``UNBASE64`` over ``col``."""
    function_name = "UNBASE64"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def ltrim(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``LTRIM`` over ``col``."""
    function_name = "LTRIM"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def rtrim(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``RTRIM`` over ``col``."""
    function_name = "RTRIM"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def trim(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.Trim`` expression."""
    expression_type = glotexp.Trim
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
    """Build a ``glotexp.ConcatWs`` whose expressions are the literal ``sep`` followed by ``cols``."""
    wrapped = [Column(col) for col in cols]
    parts = [lit(sep).expression]
    for item in wrapped:
        parts.append(item.expression)
    return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=parts)
+
+
def decode(col: ColumnOrName, charset: str) -> Column:
    """Invoke ``DECODE`` over ``col`` with ``charset`` wrapped as a literal."""
    charset_literal = lit(charset)
    return Column.invoke_anonymous_function(col, "DECODE", charset_literal)
+
+
def encode(col: ColumnOrName, charset: str) -> Column:
    """Invoke ``ENCODE`` over ``col`` with ``charset`` wrapped as a literal."""
    charset_literal = lit(charset)
    return Column.invoke_anonymous_function(col, "ENCODE", charset_literal)
+
+
def format_number(col: ColumnOrName, d: int) -> Column:
    """Invoke ``FORMAT_NUMBER`` over ``col`` with ``d`` wrapped as a literal."""
    digits_literal = lit(d)
    return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", digits_literal)
+
+
def format_string(format: str, *cols: ColumnOrName) -> Column:
    """Invoke ``FORMAT_STRING`` with the literal ``format`` first, followed by ``cols``."""
    template = lit(format)
    arguments = []
    for item in cols:
        arguments.append(Column.ensure_col(item))
    return Column.invoke_anonymous_function(template, "FORMAT_STRING", *arguments)
+
+
def instr(col: ColumnOrName, substr: str) -> Column:
    """Invoke ``INSTR`` over ``col`` with ``substr`` wrapped as a literal."""
    needle = lit(substr)
    return Column.invoke_anonymous_function(col, "INSTR", needle)
+
+
def overlay(
    src: ColumnOrName,
    replace: ColumnOrName,
    pos: t.Union[ColumnOrName, int],
    len: t.Optional[t.Union[ColumnOrName, int]] = None,
) -> Column:
    """Invoke ``OVERLAY`` over ``src`` with ``replace`` and ``pos``; ``len`` is appended when set."""
    trailing = (replace, pos) if len is None else (replace, pos, len)
    return Column.invoke_anonymous_function(src, "OVERLAY", *trailing)
+
+
def sentences(
    string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
) -> Column:
    """Invoke ``SENTENCES`` over ``string`` with optional ``language`` and ``country``.

    If only ``country`` is given, the literal ``"en"`` fills the language slot.
    """
    if language is not None:
        trailing = [language] if country is None else [language, country]
    elif country is not None:
        trailing = [lit("en"), country]
    else:
        trailing = []
    return Column.invoke_anonymous_function(string, "SENTENCES", *trailing)
+
+
def substring(str: ColumnOrName, pos: int, len: int) -> Column:
    """Ensure ``str`` is a ``Column`` and delegate to its ``substr(pos, len)`` method."""
    source = Column.ensure_col(str)
    return source.substr(pos, len)
+
+
def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
    """Invoke ``SUBSTRING_INDEX`` over ``str`` with literal ``delim`` and ``count``."""
    delimiter_literal = lit(delim)
    count_literal = lit(count)
    return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", delimiter_literal, count_literal)
+
+
def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
    """Wrap ``left`` in ``glotexp.Levenshtein`` with ``right`` as the ``expression`` argument."""
    right_expression = Column.ensure_col(right).expression
    return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right_expression)
+
+
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
    """Invoke ``LOCATE`` with the literal ``substr`` first, then ``str`` and optionally ``pos``.

    ``pos`` is wrapped as a literal and appended only when it is not ``None``;
    no literal is built for an omitted position.
    """
    substr_col = lit(substr)
    str_column = Column.ensure_col(str)
    if pos is not None:
        return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column, lit(pos))
    return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column)
+
+
def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
    """Invoke ``LPAD`` over ``col`` with literal ``len`` and ``pad``."""
    length_literal = lit(len)
    pad_literal = lit(pad)
    return Column.invoke_anonymous_function(col, "LPAD", length_literal, pad_literal)
+
+
def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
    """Invoke ``RPAD`` over ``col`` with literal ``len`` and ``pad``."""
    length_literal = lit(len)
    pad_literal = lit(pad)
    return Column.invoke_anonymous_function(col, "RPAD", length_literal, pad_literal)
+
+
def repeat(col: ColumnOrName, n: int) -> Column:
    """Invoke the anonymous function ``REPEAT`` over ``col`` with ``n``."""
    function_name = "REPEAT"
    return Column.invoke_anonymous_function(col, function_name, n)
+
+
def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
    """Wrap ``str`` in ``glotexp.RegexpSplit`` with a literal ``pattern`` and optional ``limit``."""
    expression_kwargs = {"expression": lit(pattern).expression}
    if limit is not None:
        expression_kwargs["limit"] = lit(limit).expression
    return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, **expression_kwargs)
+
+
def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
    """Invoke ``REGEXP_EXTRACT`` over ``str`` with a literal ``pattern``; ``idx`` is appended when set."""
    trailing = [lit(pattern)]
    if idx is not None:
        trailing.append(idx)
    return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", *trailing)
+
+
def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
    """Invoke ``REGEXP_REPLACE`` over ``str`` with literal ``pattern`` and ``replacement``."""
    pattern_literal = lit(pattern)
    replacement_literal = lit(replacement)
    return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", pattern_literal, replacement_literal)
+
+
def initcap(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.Initcap`` expression."""
    expression_type = glotexp.Initcap
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def soundex(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``SOUNDEX`` over ``col``."""
    function_name = "SOUNDEX"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def bin(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``BIN`` over ``col``."""
    function_name = "BIN"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def hex(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``HEX`` over ``col``."""
    function_name = "HEX"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def unhex(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``UNHEX`` over ``col``."""
    function_name = "UNHEX"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def length(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.Length`` expression."""
    expression_type = glotexp.Length
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def octet_length(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``OCTET_LENGTH`` over ``col``."""
    function_name = "OCTET_LENGTH"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def bit_length(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``BIT_LENGTH`` over ``col``."""
    function_name = "BIT_LENGTH"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
    """Invoke ``TRANSLATE`` over ``srcCol`` with literal ``matching`` and ``replace``."""
    matching_literal = lit(matching)
    replace_literal = lit(replace)
    return Column.invoke_anonymous_function(srcCol, "TRANSLATE", matching_literal, replace_literal)
+
+
def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
    """Build a ``glotexp.Array`` from the given columns.

    Columns may be passed as separate positional arguments or as a single
    iterable; when the first argument is neither a ``str`` nor a ``Column``
    the arguments are flattened first. Calling with no arguments yields an
    array with an empty expression list instead of raising ``IndexError``.
    """
    # Guard on emptiness before peeking at cols[0] so a zero-argument call works.
    if cols and not isinstance(cols[0], (str, Column)):
        cols = _flatten(cols)  # type: ignore
    element_expressions = [Column.ensure_col(col).expression for col in cols]
    return Column.invoke_expression_over_column(None, glotexp.Array, expressions=element_expressions)
+
+
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
    """Build a ``glotexp.VarMap`` from alternating key/value columns.

    Even-indexed inputs become the ``keys`` array and odd-indexed inputs the
    ``values`` array. Inputs are flattened first when the first argument is
    neither a ``str`` nor a ``Column``.
    """
    if not isinstance(cols[0], (str, Column)):
        cols = list(_flatten(cols))  # type: ignore
    key_array = array(*cols[::2]).expression
    value_array = array(*cols[1::2]).expression
    return Column.invoke_expression_over_column(None, glotexp.VarMap, keys=key_array, values=value_array)
+
+
def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MAP_FROM_ARRAYS`` over ``col1`` with ``col2``."""
    function_name = "MAP_FROM_ARRAYS"
    return Column.invoke_anonymous_function(col1, function_name, col2)
+
+
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Wrap ``col`` in ``glotexp.ArrayContains``; a non-``Column`` ``value`` becomes a literal."""
    if isinstance(value, Column):
        needle = value
    else:
        needle = lit(value)
    return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=needle.expression)
+
+
def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Invoke ``ARRAYS_OVERLAP`` over ``col1`` with ``col2`` ensured as a ``Column``."""
    other = Column.ensure_col(col2)
    return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", other)
+
+
def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
    """Invoke ``SLICE`` over ``x``; non-``Column`` ``start``/``length`` values become literals."""
    if isinstance(start, Column):
        begin = start
    else:
        begin = lit(start)
    if isinstance(length, Column):
        span = length
    else:
        span = lit(length)
    return Column.invoke_anonymous_function(x, "SLICE", begin, span)
+
+
def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
    """Invoke ``ARRAY_JOIN`` over ``col`` with a literal ``delimiter`` and optional literal ``null_replacement``."""
    trailing = [lit(delimiter)]
    if null_replacement is not None:
        trailing.append(lit(null_replacement))
    return Column.invoke_anonymous_function(col, "ARRAY_JOIN", *trailing)
+
+
def concat(*cols: ColumnOrName) -> Column:
    """Invoke ``CONCAT`` over the first column, passing the expressions of the remaining columns."""
    return Column.invoke_anonymous_function(
        cols[0],
        "CONCAT",
        *(Column.ensure_col(other).expression for other in cols[1:]),
    )
+
+
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Invoke ``ARRAY_POSITION`` over ``col``; a non-``Column`` ``value`` becomes a literal."""
    if isinstance(value, Column):
        target = value
    else:
        target = lit(value)
    return Column.invoke_anonymous_function(col, "ARRAY_POSITION", target)
+
+
def element_at(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Invoke ``ELEMENT_AT`` over ``col``; a non-``Column`` ``value`` becomes a literal."""
    if isinstance(value, Column):
        target = value
    else:
        target = lit(value)
    return Column.invoke_anonymous_function(col, "ELEMENT_AT", target)
+
+
def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Invoke ``ARRAY_REMOVE`` over ``col``; a non-``Column`` ``value`` becomes a literal."""
    if isinstance(value, Column):
        target = value
    else:
        target = lit(value)
    return Column.invoke_anonymous_function(col, "ARRAY_REMOVE", target)
+
+
def array_distinct(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ARRAY_DISTINCT`` over ``col``."""
    function_name = "ARRAY_DISTINCT"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Invoke ``ARRAY_INTERSECT`` over ``col1`` with ``col2`` ensured as a ``Column``."""
    other = Column.ensure_col(col2)
    return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", other)
+
+
def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Invoke ``ARRAY_UNION`` over ``col1`` with ``col2`` ensured as a ``Column``."""
    other = Column.ensure_col(col2)
    return Column.invoke_anonymous_function(col1, "ARRAY_UNION", other)
+
+
def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Invoke ``ARRAY_EXCEPT`` over ``col1`` with ``col2`` ensured as a ``Column``."""
    other = Column.ensure_col(col2)
    return Column.invoke_anonymous_function(col1, "ARRAY_EXCEPT", other)
+
+
def explode(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.Explode`` expression."""
    expression_type = glotexp.Explode
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def posexplode(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.Posexplode`` expression."""
    expression_type = glotexp.Posexplode
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def explode_outer(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``EXPLODE_OUTER`` over ``col``."""
    function_name = "EXPLODE_OUTER"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def posexplode_outer(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``POSEXPLODE_OUTER`` over ``col``."""
    function_name = "POSEXPLODE_OUTER"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def get_json_object(col: ColumnOrName, path: str) -> Column:
    """Wrap ``col`` in ``glotexp.JSONExtract`` with the literal ``path``."""
    path_expression = lit(path).expression
    return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=path_expression)
+
+
def json_tuple(col: ColumnOrName, *fields: str) -> Column:
    """Invoke ``JSON_TUPLE`` over ``col`` with each field name wrapped as a literal."""
    field_literals = []
    for field in fields:
        field_literals.append(lit(field))
    return Column.invoke_anonymous_function(col, "JSON_TUPLE", *field_literals)
+
+
def from_json(
    col: ColumnOrName,
    schema: t.Union[Column, str],
    options: t.Optional[t.Dict[str, str]] = None,
) -> Column:
    """Invoke ``FROM_JSON`` over ``col`` with ``schema`` and an optional options map.

    A non-``Column`` schema is wrapped as a literal. When ``options`` is given,
    its flattened key/value items are turned into literals and passed through
    ``create_map`` as the last argument.
    """
    schema_col = schema if isinstance(schema, Column) else lit(schema)
    call_args = [schema_col]
    if options is not None:
        call_args.append(create_map([lit(item) for item in _flatten(options.items())]))
    return Column.invoke_anonymous_function(col, "FROM_JSON", *call_args)
+
+
def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
    """Invoke ``TO_JSON`` over ``col``; ``options`` (if given) is passed as a map built by ``create_map``."""
    if options is None:
        return Column.invoke_anonymous_function(col, "TO_JSON")
    option_literals = [lit(item) for item in _flatten(options.items())]
    return Column.invoke_anonymous_function(col, "TO_JSON", create_map(option_literals))
+
+
def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
    """Invoke ``SCHEMA_OF_JSON`` over ``col``; ``options`` (if given) is passed as a map built by ``create_map``."""
    if options is None:
        return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON")
    option_literals = [lit(item) for item in _flatten(options.items())]
    return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON", create_map(option_literals))
+
+
def schema_of_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
    """Invoke ``SCHEMA_OF_CSV`` over ``col``; ``options`` (if given) is passed as a map built by ``create_map``."""
    if options is None:
        return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV")
    option_literals = [lit(item) for item in _flatten(options.items())]
    return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV", create_map(option_literals))
+
+
def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
    """Invoke ``TO_CSV`` over ``col``; ``options`` (if given) is passed as a map built by ``create_map``."""
    if options is None:
        return Column.invoke_anonymous_function(col, "TO_CSV")
    option_literals = [lit(item) for item in _flatten(options.items())]
    return Column.invoke_anonymous_function(col, "TO_CSV", create_map(option_literals))
+
+
def size(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.ArraySize`` expression."""
    expression_type = glotexp.ArraySize
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def array_min(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ARRAY_MIN`` over ``col``."""
    function_name = "ARRAY_MIN"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def array_max(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``ARRAY_MAX`` over ``col``."""
    function_name = "ARRAY_MAX"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
    """Invoke ``SORT_ARRAY`` over ``col``; ``asc`` is passed as a literal when not ``None``."""
    trailing = () if asc is None else (lit(asc),)
    return Column.invoke_anonymous_function(col, "SORT_ARRAY", *trailing)
+
+
def array_sort(col: ColumnOrName) -> Column:
    """Wrap ``col`` in a ``glotexp.ArraySort`` expression."""
    expression_type = glotexp.ArraySort
    return Column.invoke_expression_over_column(col, expression_type)
+
+
def shuffle(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``SHUFFLE`` over ``col``."""
    function_name = "SHUFFLE"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def reverse(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``REVERSE`` over ``col``."""
    function_name = "REVERSE"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def flatten(col: ColumnOrName) -> Column:
    """Invoke the anonymous SQL function ``FLATTEN`` over ``col``.

    This is a SQL function wrapper, distinct from the ``_flatten`` iterable
    helper used elsewhere in this module.
    """
    function_name = "FLATTEN"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def map_keys(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MAP_KEYS`` over ``col``."""
    function_name = "MAP_KEYS"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def map_values(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MAP_VALUES`` over ``col``."""
    function_name = "MAP_VALUES"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def map_entries(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MAP_ENTRIES`` over ``col``."""
    function_name = "MAP_ENTRIES"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def map_from_entries(col: ColumnOrName) -> Column:
    """Invoke the anonymous function ``MAP_FROM_ENTRIES`` over ``col``."""
    function_name = "MAP_FROM_ENTRIES"
    return Column.invoke_anonymous_function(col, function_name)
+
+
def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column:
    """Invoke ``ARRAY_REPEAT`` over ``col``; a non-``Column`` ``count`` becomes a literal."""
    if isinstance(count, Column):
        times = count
    else:
        times = lit(count)
    return Column.invoke_anonymous_function(col, "ARRAY_REPEAT", times)
+
+
def array_zip(*cols: ColumnOrName) -> Column:
    """Invoke ``ARRAY_ZIP`` over the first column, passing any remaining columns as arguments."""
    remaining = cols[1:]
    return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP", *remaining)
+
+
def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
    """Invoke ``MAP_CONCAT`` over the given map columns.

    Columns may be passed as separate positional arguments or as a single
    iterable. When the first argument is neither a ``str`` nor a ``Column``,
    the arguments are flattened with the ``_flatten`` iterable helper — not
    this module's ``flatten`` function, which builds a SQL ``FLATTEN`` call
    and would turn the whole input into a single column.
    """
    columns = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols  # type: ignore
    if len(columns) == 1:
        return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT")
    return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
+
+
def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
    """Invoke ``SEQUENCE`` over ``start`` with ``stop``; ``step`` is appended when not ``None``."""
    trailing = (stop,) if step is None else (stop, step)
    return Column.invoke_anonymous_function(start, "SEQUENCE", *trailing)
+
+
def from_csv(
    col: ColumnOrName,
    schema: t.Union[Column, str],
    options: t.Optional[t.Dict[str, str]] = None,
) -> Column:
    """Invoke ``FROM_CSV`` over ``col`` with ``schema`` and an optional options map.

    A non-``Column`` schema is wrapped as a literal. When ``options`` is given,
    its flattened key/value items are turned into literals and passed through
    ``create_map`` as the last argument.
    """
    schema_col = schema if isinstance(schema, Column) else lit(schema)
    call_args = [schema_col]
    if options is not None:
        call_args.append(create_map([lit(item) for item in _flatten(options.items())]))
    return Column.invoke_anonymous_function(col, "FROM_CSV", *call_args)
+
+
def aggregate(
    col: ColumnOrName,
    initialValue: ColumnOrName,
    merge: t.Callable[[Column, Column], Column],
    finish: t.Optional[t.Callable[[Column], Column]] = None,
    accumulator_name: str = "acc",
    target_row_name: str = "x",
) -> Column:
    """Invoke ``AGGREGATE`` over ``col`` with lambdas built from Python callables.

    ``merge`` is called with columns named ``accumulator_name`` and
    ``target_row_name`` to produce the merge lambda body. If ``finish`` is
    given it is called with the accumulator column and the resulting lambda is
    passed as an extra trailing argument.
    """
    merge_body = merge(Column(accumulator_name), Column(target_row_name)).expression
    merge_params = [
        glotexp.to_identifier(name, quoted=_lambda_quoted(name)) for name in (accumulator_name, target_row_name)
    ]
    call_args = [initialValue, Column(glotexp.Lambda(this=merge_body, expressions=merge_params))]

    if finish is not None:
        finish_body = finish(Column(accumulator_name)).expression
        finish_params = [glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name))]
        call_args.append(Column(glotexp.Lambda(this=finish_body, expressions=finish_params)))

    return Column.invoke_anonymous_function(col, "AGGREGATE", *call_args)
+
+
def transform(
    col: ColumnOrName,
    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
    target_row_name: str = "x",
    row_count_name: str = "i",
) -> Column:
    """Invoke ``TRANSFORM`` over ``col`` with a lambda built from ``f``.

    ``f`` receives a column named ``target_row_name``; if its signature has
    more than one parameter it also receives a column named ``row_count_name``.
    """
    lambda_names = [target_row_name]
    if len(signature(f).parameters) > 1:
        lambda_names.append(row_count_name)

    lambda_params = [glotexp.to_identifier(name, quoted=_lambda_quoted(name)) for name in lambda_names]
    lambda_body = f(*[Column(name) for name in lambda_names]).expression
    lambda_column = Column(glotexp.Lambda(this=lambda_body, expressions=lambda_params))
    return Column.invoke_anonymous_function(col, "TRANSFORM", lambda_column)
+
+
def exists(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
    """Invoke ``EXISTS`` over ``col`` with a one-parameter lambda built from ``f``."""
    lambda_body = f(Column(target_row_name)).expression
    lambda_params = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
    lambda_column = Column(glotexp.Lambda(this=lambda_body, expressions=lambda_params))
    return Column.invoke_anonymous_function(col, "EXISTS", lambda_column)
+
+
def forall(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
    """Invoke ``FORALL`` over ``col`` with a one-parameter lambda built from ``f``."""
    lambda_body = f(Column(target_row_name)).expression
    lambda_params = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
    lambda_column = Column(glotexp.Lambda(this=lambda_body, expressions=lambda_params))
    return Column.invoke_anonymous_function(col, "FORALL", lambda_column)
+
+
def filter(
    col: ColumnOrName,
    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
    target_row_name: str = "x",
    row_count_name: str = "i",
) -> Column:
    """Invoke ``FILTER`` over ``col`` with a lambda built from ``f``.

    ``f`` receives a column named ``target_row_name``; if its signature has
    more than one parameter it also receives a column named ``row_count_name``.
    """
    lambda_names = [target_row_name]
    if len(signature(f).parameters) > 1:
        lambda_names.append(row_count_name)

    lambda_params = [glotexp.to_identifier(name, quoted=_lambda_quoted(name)) for name in lambda_names]
    lambda_body = f(*[Column(name) for name in lambda_names]).expression
    lambda_column = Column(glotexp.Lambda(this=lambda_body, expressions=lambda_params))
    return Column.invoke_anonymous_function(col, "FILTER", lambda_column)
+
+
def zip_with(
    left: ColumnOrName,
    right: ColumnOrName,
    f: t.Callable[[Column, Column], Column],
    left_name: str = "x",
    right_name: str = "y",
) -> Column:
    """Invoke ``ZIP_WITH`` over ``left`` and ``right`` with a two-parameter lambda built from ``f``."""
    lambda_names = (left_name, right_name)
    lambda_body = f(*[Column(name) for name in lambda_names]).expression
    lambda_params = [glotexp.to_identifier(name, quoted=_lambda_quoted(name)) for name in lambda_names]
    lambda_column = Column(glotexp.Lambda(this=lambda_body, expressions=lambda_params))
    return Column.invoke_anonymous_function(left, "ZIP_WITH", right, lambda_column)
+
+
def transform_keys(
    col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
) -> Column:
    """Invoke ``TRANSFORM_KEYS`` over ``col`` with a (key, value) lambda built from ``f``."""
    lambda_names = (key_name, value_name)
    lambda_body = f(*[Column(name) for name in lambda_names]).expression
    lambda_params = [glotexp.to_identifier(name, quoted=_lambda_quoted(name)) for name in lambda_names]
    lambda_column = Column(glotexp.Lambda(this=lambda_body, expressions=lambda_params))
    return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", lambda_column)
+
+
def transform_values(
    col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
) -> Column:
    """Invoke ``TRANSFORM_VALUES`` over ``col`` with a (key, value) lambda built from ``f``."""
    lambda_names = (key_name, value_name)
    lambda_body = f(*[Column(name) for name in lambda_names]).expression
    lambda_params = [glotexp.to_identifier(name, quoted=_lambda_quoted(name)) for name in lambda_names]
    lambda_column = Column(glotexp.Lambda(this=lambda_body, expressions=lambda_params))
    return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", lambda_column)
+
+
def map_filter(
    col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
) -> Column:
    """Invoke ``MAP_FILTER`` over ``col`` with a (key, value) lambda built from ``f``."""
    lambda_names = (key_name, value_name)
    lambda_body = f(*[Column(name) for name in lambda_names]).expression
    lambda_params = [glotexp.to_identifier(name, quoted=_lambda_quoted(name)) for name in lambda_names]
    lambda_column = Column(glotexp.Lambda(this=lambda_body, expressions=lambda_params))
    return Column.invoke_anonymous_function(col, "MAP_FILTER", lambda_column)
+
+
def map_zip_with(
    col1: ColumnOrName,
    col2: ColumnOrName,
    f: t.Union[t.Callable[[Column, Column, Column], Column]],
    key_name: str = "k",
    value1: str = "v1",
    value2: str = "v2",
) -> Column:
    """Invoke ``MAP_ZIP_WITH`` over ``col1`` and ``col2`` with a (key, value1, value2) lambda built from ``f``."""
    lambda_names = (key_name, value1, value2)
    lambda_body = f(*[Column(name) for name in lambda_names]).expression
    lambda_params = [glotexp.to_identifier(name, quoted=_lambda_quoted(name)) for name in lambda_names]
    lambda_column = Column(glotexp.Lambda(this=lambda_body, expressions=lambda_params))
    return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, lambda_column)
+
+
+def _lambda_quoted(value: str) -> t.Optional[bool]:
+ return False if value == "_" else None
diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py
new file mode 100644
index 0000000..947aace
--- /dev/null
+++ b/sqlglot/dataframe/sql/group.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.operations import Operation, operation
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+
class GroupedData:
    """Holds a DataFrame together with the columns it is grouped by, awaiting an aggregation."""

    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
        # Work on a copy so that aggregating never touches the caller's DataFrame.
        self._df = df.copy()
        self.spark = df.spark
        self.last_op = last_op
        self.group_by_cols = group_by_cols

    def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
        """Apply the function named ``func_name`` from ``F`` to each column, aliased as ``func(name)``."""
        func_name = func_name.lower()
        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]

    @operation(Operation.SELECT)
    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
        """
        Aggregate the grouped DataFrame.

        When the first argument is a dict it is read as ``{column_name: agg_func}`` and each entry
        becomes the column ``agg_func(column_name)``; any further positional arguments are ignored
        in that case. Otherwise every argument is used as an aggregate column.

        The result selects the group-by columns followed by the aggregate columns.
        """
        columns = (
            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
            if isinstance(exprs[0], dict)
            else exprs
        )
        cols = self._df._ensure_and_normalize_cols(columns)

        expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
            *[x.expression for x in self.group_by_cols + cols], append=False
        )
        return self._df.copy(expression=expression)

    def count(self) -> DataFrame:
        """Count the rows of each group into a column aliased ``count``."""
        return self.agg(F.count("*").alias("count"))

    def mean(self, *cols: str) -> DataFrame:
        """Alias of :meth:`avg`."""
        return self.avg(*cols)

    def avg(self, *cols: str) -> DataFrame:
        """Average each of the named columns per group."""
        return self.agg(*self._get_function_applied_columns("avg", cols))

    def max(self, *cols: str) -> DataFrame:
        """Take the maximum of each of the named columns per group."""
        return self.agg(*self._get_function_applied_columns("max", cols))

    def min(self, *cols: str) -> DataFrame:
        """Take the minimum of each of the named columns per group."""
        return self.agg(*self._get_function_applied_columns("min", cols))

    def sum(self, *cols: str) -> DataFrame:
        """Sum each of the named columns per group."""
        return self.agg(*self._get_function_applied_columns("sum", cols))

    def pivot(self, *cols: str) -> DataFrame:
        """Not supported; always raises ``NotImplementedError``."""
        # The message previously said "Sum distinct", copied from an unrelated function.
        raise NotImplementedError("Pivot is not currently implemented")
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
new file mode 100644
index 0000000..1513946
--- /dev/null
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.helper import ensure_list
+
+NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
    """
    Rewrite, in place, the identifiers found in ``expr`` so that they refer to CTE names.

    Every input is first converted to an expression. Each identifier inside it is then passed to the
    alias-name replacement followed by the branch/sequence-id replacement.
    """
    for node in _ensure_expressions(ensure_list(expr)):
        for identifier in node.find_all(exp.Identifier):
            replace_alias_name_with_cte_name(spark, expression_context, identifier)
            replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
+
+
def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
    """
    Rename ``id`` to the latest CTE whose sequence id is registered under ``id``'s current name.

    Nothing happens when the name is not a key of ``spark.name_to_sequence_id_mapping`` or when no
    CTE of ``expression_context`` matches.
    """
    alias = id.alias_or_name
    if alias not in spark.name_to_sequence_id_mapping:
        return

    # Walk the CTEs from last to first so the most recently added match wins.
    for cte in reversed(expression_context.ctes):
        if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[alias]:
            _set_alias_name(id, cte.alias_or_name)
            return
+
+
def replace_branch_and_sequence_ids_with_cte_name(
    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
    """
    Rename ``id`` to a CTE name when its current name is a known branch id or sequence id.

    Identifiers whose name is not in ``spark.known_ids`` are left untouched.
    """
    name = id.alias_or_name
    if name not in spark.known_ids:
        return

    # Check if we have a join and if both the tables in that join share a common branch id.
    # If so this must reference the left table by default, unless the id is a sequence id, in
    # which case it keeps that reference. This handles an odd Spark edge case that should not be
    # common in practice.
    if expression_context.args.get("joins") and name in spark.known_branch_ids:
        join_table_aliases = [
            table.alias_or_name for table in get_tables_from_expression_with_join(expression_context)
        ]
        ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
        if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
            assert len(ctes_in_join) == 2
            _set_alias_name(id, ctes_in_join[0].alias_or_name)
            return

    # Otherwise the latest CTE whose branch id or sequence id equals the name wins.
    for cte in reversed(expression_context.ctes):
        if name in (cte.args["branch_id"], cte.args["sequence_id"]):
            _set_alias_name(id, cte.alias_or_name)
            return
+
+
def _set_alias_name(id: exp.Identifier, name: str):
    """Overwrite the ``this`` argument of ``id`` with ``name``."""
    id.set("this", name)
+
+
def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
    """
    Convert each input into an expression.

    Strings go through ``Column.ensure_col``, columns contribute their ``expression`` and
    expressions are passed through unchanged.

    Raises:
        ValueError: if a value is none of ``str``, ``Column`` or ``exp.Expression``.
    """
    results = []
    for value in ensure_list(values):
        if isinstance(value, str):
            expression = Column.ensure_col(value).expression
        elif isinstance(value, Column):
            expression = value.expression
        elif isinstance(value, exp.Expression):
            expression = value
        else:
            raise ValueError(f"Got an invalid type to normalize: {type(value)}")
        results.append(expression)
    return results
diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py
new file mode 100644
index 0000000..d51335c
--- /dev/null
+++ b/sqlglot/dataframe/sql/operations.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import functools
+import typing as t
+from enum import IntEnum
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+ from sqlglot.dataframe.sql.group import GroupedData
+
+
class Operation(IntEnum):
    """Ordered kinds of DataFrame operation; the ``operation`` decorator compares them numerically."""

    # Sentinel states that sort before every real clause.
    INIT = -1
    NO_OP = 0

    # Clause kinds, in increasing order.
    FROM = 1
    WHERE = 2
    GROUP_BY = 3
    HAVING = 4
    SELECT = 5
    ORDER_BY = 6
    LIMIT = 7
+
+
def operation(op: Operation):
    """
    Decorator that tags a DataFrame method with the kind of operation it performs, taken from the
    ordered Operation enum. The tag decides whether the method's work is applied on top of a new
    CTE or merged into the previous operation.

    Ex: After a join the user should still be able to choose which columns of the joined tables
        are carried forward. Wrapping the join in a CTE too early would take that choice away
        whenever the tables have overlapping column names.
    """

    def decorator(func: t.Callable):
        # functools.wraps also sets ``wrapper.__wrapped__`` to ``func``.
        @functools.wraps(func)
        def wrapper(self: DataFrame, *args, **kwargs):
            # A freshly initialised object is always moved into a CTE first.
            if self.last_op == Operation.INIT:
                self = self._convert_leaf_to_cte()
                self.last_op = Operation.NO_OP

            previous_op = self.last_op
            current_op = previous_op if op == Operation.NO_OP else op

            # Start a new CTE when stepping back in the operation order, or on a second SELECT in a row.
            repeated_select = previous_op == current_op and current_op == Operation.SELECT
            if current_op < previous_op or repeated_select:
                self = self._convert_leaf_to_cte()

            result: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
            result.last_op = current_op  # type: ignore
            return result

        return wrapper

    return decorator
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
new file mode 100644
index 0000000..4830035
--- /dev/null
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.helper import object_to_dict
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
class DataFrameReader:
    """Creates DataFrames from tables registered in ``sqlglot.schema``."""

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        """Register ``tableName`` in ``sqlglot.schema`` and return a DataFrame selecting its columns."""
        from sqlglot.dataframe.sql.dataframe import DataFrame

        sqlglot.schema.add_table(tableName)
        from_table = exp.Select().from_(tableName)
        column_names = sqlglot.schema.column_names(tableName)
        return DataFrame(self.spark, from_table.select(*column_names))
+
+
class DataFrameWriter:
    """Immutable-style builder: every configuring call returns a new writer around a DataFrame."""

    def __init__(
        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        """Return a new writer built from this writer's attributes, overridden by ``kwargs``."""
        init_kwargs = {}
        for key, value in object_to_dict(self, **kwargs).items():
            # Attributes are stored with a leading underscore; the constructor takes the bare name.
            init_kwargs[key[1:] if key.startswith("_") else key] = value
        return DataFrameWriter(**init_kwargs)

    def sql(self, **kwargs) -> t.List[str]:
        """Delegate to the wrapped DataFrame's ``sql``."""
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        """Return a copy of this writer with the save mode set to ``saveMode``."""
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        """A copy of this writer with ``by_name`` switched on."""
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        """Return a writer whose DataFrame is wrapped in an ``INSERT`` into ``tableName``."""
        insert_expression = exp.Insert(this=exp.to_table(tableName), overwrite=overwrite)
        df = self._df.copy(output_expression_container=insert_expression)
        if self._by_name:
            # Select the target table's visible columns by name on top of a new CTE.
            columns = sqlglot.schema.column_names(tableName, only_visible=True)
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        """
        Return a writer whose DataFrame is wrapped in a ``CREATE TABLE`` for ``name``.

        ``mode`` falls back to the writer's own mode. ``"append"`` is handled by ``insertInto``,
        ``"ignore"`` sets ``exists`` and ``"overwrite"`` sets ``replace`` on the create expression.

        Raises:
            NotImplementedError: if ``format`` is given.
        """
        if format is not None:
            raise NotImplementedError("Providing Format in the save as table is not supported")

        mode = mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)

        exists = True if mode == "ignore" else None
        replace = True if mode == "overwrite" else None
        create_expression = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=create_expression))
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
new file mode 100644
index 0000000..1ea86d1
--- /dev/null
+++ b/sqlglot/dataframe/sql/session.py
@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import typing as t
+import uuid
+from collections import defaultdict
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.dataframe import DataFrame
+from sqlglot.dataframe.sql.readwriter import DataFrameReader
+from sqlglot.dataframe.sql.types import StructType
+from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
+
+
class SparkSession:
    """Entry point that builds DataFrames and tracks the ids and aliases they generate."""

    # Registries declared on the class, so every session shares them.
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        # Any attribute that is not otherwise defined resolves to the session itself.
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        # Calling the session returns the session itself.
        return self

    @property
    def read(self) -> DataFrameReader:
        """A new ``DataFrameReader`` bound to this session."""
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        """Shortcut for ``self.read.table(tableName)``."""
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        """
        Build a DataFrame that selects from an inline ``VALUES`` subquery holding ``data``.

        Column names come from ``schema`` when given, else from the keys of the first row when it
        is a dict, else they are ``_1``, ``_2``, ... Columns with a declared type are cast to it.

        Raises:
            NotImplementedError: if ``samplingRatio`` or ``verifySchema`` is used, or ``schema`` is
                not a ``StructType``, a string or a list of strings.
            ValueError: if ``data`` is empty.
        """
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None:
            unsupported_type = not isinstance(schema, (StructType, str, list))
            if unsupported_type or (isinstance(schema, list) and not isinstance(schema[0], str)):
                raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        # One tuple of literals per input row; dict rows contribute their values.
        value_rows = []
        for row in data:
            values = row.values() if isinstance(row, dict) else row
            value_rows.append(exp.Tuple(expressions=[F.lit(value).expression for value in values]))

        # Typed columns are cast and re-aliased; untyped columns are selected as they are.
        projections = []
        for name, data_type in column_mapping.items():
            column = F.col(name)
            if data_type is not None:
                column = column.cast(data_type).alias(name)
            projections.append(column.expression)

        values_subquery = exp.Subquery(
            this=exp.Values(expressions=value_rows),
            alias=exp.TableAlias(
                this=exp.to_identifier(self._auto_incrementing_name),
                columns=[exp.to_identifier(col_name) for col_name in column_mapping],
            ),
        )
        select_kwargs = {
            "expressions": projections,
            "from": exp.From(expressions=[values_subquery]),
        }
        return DataFrame(self, exp.Select(**select_kwargs))

    def sql(self, sqlQuery: str) -> DataFrame:
        """
        Parse ``sqlQuery`` with the spark dialect and wrap it in a DataFrame.

        A ``SELECT`` is wrapped directly. For ``CREATE`` and ``INSERT`` the inner query becomes the
        DataFrame and the statement itself is kept as the output container.

        Raises:
            ValueError: if the statement is none of the above.
        """
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            return DataFrame(self, expression)._convert_leaf_to_cte()

        if isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                # Move the WITH clause from the INSERT onto the inner query.
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            wrapped = DataFrame(self, select_expression, output_expression_container=expression)
            return wrapped._convert_leaf_to_cte()

        raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")

    @property
    def _auto_incrementing_name(self) -> str:
        # Each read hands out the next ``a<N>`` name and advances the counter.
        current = self.incrementing_id
        self.incrementing_id += 1
        return f"a{current}"

    @property
    def _random_name(self) -> str:
        # "a" followed by the first eight characters of a random UUID.
        return f"a{str(uuid.uuid4())[:8]}"

    @property
    def _random_branch_id(self) -> str:
        # A fresh random id that is also recorded as a branch id.
        branch_id = self._random_id
        self.known_branch_ids.add(branch_id)
        return branch_id

    @property
    def _random_sequence_id(self):
        # A fresh random id that is also recorded as a sequence id.
        sequence_id = self._random_id
        self.known_sequence_ids.add(sequence_id)
        return sequence_id

    @property
    def _random_id(self) -> str:
        # A fresh random id, recorded in ``known_ids``.
        new_id = f"a{str(uuid.uuid4())[:8]}"
        self.known_ids.add(new_id)
        return new_id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        """Record ``sequence_id`` under the alias ``name``."""
        self.name_to_sequence_id_mapping[name].append(sequence_id)
diff --git a/sqlglot/dataframe/sql/transforms.py b/sqlglot/dataframe/sql/transforms.py
new file mode 100644
index 0000000..b3dcc12
--- /dev/null
+++ b/sqlglot/dataframe/sql/transforms.py
@@ -0,0 +1,9 @@
+import typing as t
+
+from sqlglot import expressions as exp
+
+
def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
    """
    Swap ``node`` for a copy of its mapped identifier when it is an identifier present in the mapping.

    Returns whatever ``node.replace`` returns in that case, and ``node`` itself otherwise.
    """
    if not isinstance(node, exp.Identifier):
        return node
    if node not in replacement_mapping:
        return node
    return node.replace(replacement_mapping[node].copy())
diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py
new file mode 100644
index 0000000..dc5c05a
--- /dev/null
+++ b/sqlglot/dataframe/sql/types.py
@@ -0,0 +1,208 @@
+import typing as t
+
+
class DataType:
    """Base class of the data types; equality and hashing are value based."""

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def __hash__(self) -> int:
        return hash(str(self))

    def __eq__(self, other: t.Any) -> bool:
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other: t.Any) -> bool:
        return not self.__eq__(other)

    def __str__(self) -> str:
        return self.typeName()

    @classmethod
    def typeName(cls) -> str:
        # Drops the last four characters of the class name (the "Type" suffix) and lower-cases the rest.
        return cls.__name__[:-4].lower()

    def simpleString(self) -> str:
        return str(self)

    def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]:
        return str(self)


class DataTypeWithLength(DataType):
    """A data type that carries a length, rendered as ``name(length)``."""

    def __init__(self, length: int):
        self.length = length

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.length})"

    def __str__(self) -> str:
        return f"{self.typeName()}({self.length})"


class StringType(DataType):
    pass


class CharType(DataTypeWithLength):
    pass


class VarcharType(DataTypeWithLength):
    pass


class BinaryType(DataType):
    pass


class BooleanType(DataType):
    pass


class DateType(DataType):
    pass


class TimestampType(DataType):
    pass


class TimestampNTZType(DataType):
    @classmethod
    def typeName(cls) -> str:
        # The generic suffix-stripping rule would not produce the underscore, so spell it out.
        return "timestamp_ntz"


class DecimalType(DataType):
    """Decimal with a precision and scale, defaulting to ``(10, 0)``."""

    def __init__(self, precision: int = 10, scale: int = 0):
        self.precision = precision
        self.scale = scale

    def simpleString(self) -> str:
        return f"decimal({self.precision}, {self.scale})"

    def jsonValue(self) -> str:
        return f"decimal({self.precision}, {self.scale})"

    def __repr__(self) -> str:
        return f"DecimalType({self.precision}, {self.scale})"


class DoubleType(DataType):
    pass


class FloatType(DataType):
    pass


class ByteType(DataType):
    def __str__(self) -> str:
        return "tinyint"


class IntegerType(DataType):
    def __str__(self) -> str:
        return "int"


class LongType(DataType):
    def __str__(self) -> str:
        return "bigint"


class ShortType(DataType):
    def __str__(self) -> str:
        return "smallint"


class ArrayType(DataType):
    """Array of ``elementType`` values."""

    def __init__(self, elementType: DataType, containsNull: bool = True):
        self.elementType = elementType
        self.containsNull = containsNull

    def __repr__(self) -> str:
        # Previously this formatted a tuple and dropped the closing parenthesis,
        # e.g. "ArrayType((IntegerType(), 'True')". Now it mirrors MapType's repr.
        return f"ArrayType({self.elementType}, {str(self.containsNull)})"

    def simpleString(self) -> str:
        return f"array<{self.elementType.simpleString()}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "type": self.typeName(),
            "elementType": self.elementType.jsonValue(),
            "containsNull": self.containsNull,
        }


class MapType(DataType):
    """Map from ``keyType`` keys to ``valueType`` values."""

    def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull

    def __repr__(self) -> str:
        return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})"

    def simpleString(self) -> str:
        return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "type": self.typeName(),
            "keyType": self.keyType.jsonValue(),
            "valueType": self.valueType.jsonValue(),
            "valueContainsNull": self.valueContainsNull,
        }


class StructField(DataType):
    """A named, typed member of a ``StructType``."""

    def __init__(
        self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
    ):
        self.name = name
        self.dataType = dataType
        self.nullable = nullable
        self.metadata = metadata or {}

    def __repr__(self) -> str:
        return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})"

    def simpleString(self) -> str:
        return f"{self.name}:{self.dataType.simpleString()}"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "name": self.name,
            "type": self.dataType.jsonValue(),
            "nullable": self.nullable,
            "metadata": self.metadata,
        }


class StructType(DataType):
    """Ordered collection of ``StructField`` objects; iterable and sized."""

    def __init__(self, fields: t.Optional[t.List[StructField]] = None):
        if not fields:
            self.fields = []
            self.names = []
        else:
            self.fields = fields
            self.names = [f.name for f in fields]

    def __iter__(self) -> t.Iterator[StructField]:
        return iter(self.fields)

    def __len__(self) -> int:
        return len(self.fields)

    def __repr__(self) -> str:
        # repr(), not str(): str() of a StructField falls back to typeName(), which strips four
        # characters from "StructField" and rendered every field as the meaningless "structf".
        return f"StructType({', '.join(repr(field) for field in self)})"

    def simpleString(self) -> str:
        return f"struct<{', '.join(x.simpleString() for x in self)}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]}

    def fieldNames(self) -> t.List[str]:
        # Return a copy so callers cannot mutate the stored names.
        return list(self.names)
diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py
new file mode 100644
index 0000000..575d18a
--- /dev/null
+++ b/sqlglot/dataframe/sql/util.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import types
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import SchemaInput
+
+
def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]:
    """
    Convert a schema input into a mapping of column name to type string (or ``None``).

    - A dict is returned as is.
    - A string is read as comma-separated ``name: type`` entries. Commas nested inside ``()`` or
      ``<>`` and colons after the first one belong to the type, so ``decimal(10, 2)``,
      ``map<string, int>`` and ``struct<a:int>`` are kept whole.
    - A ``StructType`` maps each field name to the field type's ``simpleString()``.
    - Anything else is iterated as column names, each mapped to ``None``.

    Raises:
        IndexError: if an entry of a string schema has no ``:`` separator.
    """
    if isinstance(schema, dict):
        return schema
    elif isinstance(schema, str):
        # Split on commas at nesting depth zero only.
        col_name_type_strs = []
        depth = 0
        start = 0
        for index, char in enumerate(schema):
            if char in "(<":
                depth += 1
            elif char in ")>":
                depth -= 1
            elif char == "," and depth == 0:
                col_name_type_strs.append(schema[start:index].strip())
                start = index + 1
        col_name_type_strs.append(schema[start:].strip())

        column_mapping: t.Dict[str, t.Optional[str]] = {}
        for name_type_str in col_name_type_strs:
            # Only the first colon separates the name from the type.
            name_and_type = name_type_str.split(":", 1)
            column_mapping[name_and_type[0].strip()] = name_and_type[1].strip()
        return column_mapping
    elif isinstance(schema, types.StructType):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
    return {x.strip(): None for x in schema}  # type: ignore
+
+
def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
    """
    List the tables taking part in the joins of ``expression``.

    Returns an empty list when there are no joins; otherwise the first ``FROM`` expression followed
    by the ``this`` of every join, in order.
    """
    joins = expression.args.get("joins")
    if not joins:
        return []

    left_table = expression.args["from"].args["expressions"][0]
    return [left_table, *(join.this for join in joins)]
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
new file mode 100644
index 0000000..842f366
--- /dev/null
+++ b/sqlglot/dataframe/sql/window.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+import sys
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.helper import flatten
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnOrName
+
+
class Window:
    """Constants for frame boundaries and factory methods that start a ``WindowSpec``."""

    _JAVA_MIN_LONG = -(2**63)  # -9223372036854775808
    _JAVA_MAX_LONG = 2**63 - 1  # 9223372036854775807
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    # Frame boundary markers understood by rowsBetween / rangeBetween.
    unboundedPreceding: int = _JAVA_MIN_LONG

    unboundedFollowing: int = _JAVA_MAX_LONG

    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Start a new ``WindowSpec`` partitioned by ``cols``."""
        spec = WindowSpec()
        return spec.partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Start a new ``WindowSpec`` ordered by ``cols``."""
        spec = WindowSpec()
        return spec.orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        """Start a new ``WindowSpec`` with a ROWS frame from ``start`` to ``end``."""
        spec = WindowSpec()
        return spec.rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        """Start a new ``WindowSpec`` with a RANGE frame from ``start`` to ``end``."""
        spec = WindowSpec()
        return spec.rangeBetween(start, end)
+
+
class WindowSpec:
    """Wrapper around a window expression; every builder method returns a modified copy."""

    def __init__(self, expression: t.Optional[exp.Expression] = None):
        # Build the default per instance. A default of ``exp.Window()`` in the signature would be
        # created once and shared by every WindowSpec constructed without an argument.
        self.expression = exp.Window() if expression is None else expression

    def copy(self):
        """Return a new ``WindowSpec`` around a copy of the expression."""
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        """Render the window expression with the spark dialect."""
        return self.expression.sql(dialect="spark", **kwargs)

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a copy with ``cols`` appended to the partition-by expressions."""
        from sqlglot.dataframe.sql.column import Column

        window_spec = self.copy()
        if not cols:
            # No columns: nothing to add (previously an IndexError on ``cols[0]``).
            return window_spec

        # A single list/set/tuple argument is unpacked into its members.
        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
        partition_by_expressions.extend(expressions)
        window_spec.expression.set("partition_by", partition_by_expressions)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a copy with ``cols`` appended to the order-by expressions."""
        from sqlglot.dataframe.sql.column import Column

        window_spec = self.copy()
        if not cols:
            # No columns: leave the ordering alone (previously an IndexError on ``cols[0]``).
            return window_spec

        # A single list/set/tuple argument is unpacked into its members.
        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order_by = window_spec.expression.args["order"].expressions
        order_by.extend(expressions)
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        """
        Translate frame boundaries into ``start``/``start_side``/``end``/``end_side`` arguments.

        ``Window.currentRow`` becomes ``"CURRENT ROW"`` with no side. Any other start is given the
        side ``PRECEDING`` and any other end the side ``FOLLOWING``; values at or beyond the
        unbounded markers become ``"UNBOUNDED"``, the rest become literals.
        """
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
                },
            }
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
                },
            }
        return kwargs

    def _with_frame(self, kind: str, start: int, end: int) -> WindowSpec:
        """Return a copy whose frame spec is updated with ``kind`` and the given boundaries."""
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = kind
        # Arguments of an existing spec are kept unless the new frame overrides them.
        window_spec.expression.set(
            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
        )
        return window_spec

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        """Return a copy with a ``ROWS`` frame from ``start`` to ``end``."""
        return self._with_frame("ROWS", start, end)

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        """Return a copy with a ``RANGE`` frame from ``start`` to ``end``."""
        return self._with_frame("RANGE", start, end)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 86e46cf..62d042e 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -78,6 +78,16 @@ def _create_sql(self, expression):
class BigQuery(Dialect):
unnest_column_only = True
+ time_mapping = {
+ "%M": "%-M",
+ "%d": "%-d",
+ "%m": "%-m",
+ "%y": "%-y",
+ "%H": "%-H",
+ "%I": "%-I",
+ "%S": "%-S",
+ "%j": "%-j",
+ }
class Tokenizer(Tokenizer):
QUOTES = [
@@ -113,6 +123,7 @@ class BigQuery(Dialect):
"DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
+ "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
}
NO_PAREN_FUNCTIONS = {
@@ -137,6 +148,7 @@ class BigQuery(Dialect):
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.ILike: no_ilike_sql,
+ exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 531c72a..46661cf 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -2,7 +2,7 @@ from enum import Enum
from sqlglot import exp
from sqlglot.generator import Generator
-from sqlglot.helper import list_get
+from sqlglot.helper import flatten, list_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
@@ -67,6 +67,11 @@ class _Dialect(type):
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
+ if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
+ be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
+ klass.generator_class.TRANSFORMS[
+ exp.ByteString
+ ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
return klass
@@ -176,11 +181,7 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
def _rename(self, expression):
- args = (
- expression.expressions
- if isinstance(expression, exp.Func) and expression.is_var_len_args
- else expression.args.values()
- )
+ args = flatten(expression.args.values())
return f"{name}({self.format_args(*args)})"
return _rename
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 8888df8..0810e0c 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -121,6 +121,9 @@ class Hive(Dialect):
"ss": "%S",
"s": "%-S",
"S": "%f",
+ "a": "%p",
+ "DD": "%j",
+ "D": "%-j",
}
date_format = "'yyyy-MM-dd'"
@@ -200,6 +203,7 @@ class Hive(Dialect):
exp.AnonymousProperty: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
+ exp.ArrayConcat: rename_func("CONCAT"),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort,
exp.With: no_recursive_cte_sql,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 8449379..524390f 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -97,6 +97,8 @@ class MySQL(Dialect):
"%s": "%S",
"%S": "%S",
"%u": "%W",
+ "%k": "%-H",
+ "%l": "%-I",
}
class Tokenizer(Tokenizer):
@@ -145,6 +147,9 @@ class MySQL(Dialect):
"_TIS620": TokenType.INTRODUCER,
"_UCS2": TokenType.INTRODUCER,
"_UJIS": TokenType.INTRODUCER,
+ # https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
+ "N": TokenType.INTRODUCER,
+ "n": TokenType.INTRODUCER,
"_UTF8": TokenType.INTRODUCER,
"_UTF16": TokenType.INTRODUCER,
"_UTF16LE": TokenType.INTRODUCER,
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 8041ff0..144dba5 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -80,17 +80,12 @@ class Oracle(Dialect):
sep="",
)
- def alias_sql(self, expression):
- if isinstance(expression.this, exp.Table):
- to_sql = self.sql(expression, "alias")
- # oracle does not allow "AS" between table and alias
- to_sql = f" {to_sql}" if to_sql else ""
- return f"{self.sql(expression, 'this')}{to_sql}"
- return super().alias_sql(expression)
-
def offset_sql(self, expression):
return f"{super().offset_sql(expression)} ROWS"
+ def table_sql(self, expression):
+ return super().table_sql(expression, sep=" ")
+
class Tokenizer(Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index c91ff4b..459e926 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -163,6 +163,7 @@ class Postgres(Dialect):
class Tokenizer(Tokenizer):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
+ BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
@@ -176,6 +177,11 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
}
+ QUOTES = ["'", "$$"]
+ SINGLE_TOKENS = {
+ **Tokenizer.SINGLE_TOKENS,
+ "$": TokenType.PARAMETER,
+ }
class Parser(Parser):
STRICT_CAST = False
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 8dfb2fd..41c0db1 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -172,6 +172,7 @@ class Presto(Dialect):
**transforms.UNALIAS_GROUP,
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
+ exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 19a427c..627258f 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -69,6 +69,35 @@ def _unix_to_time(self, expression):
raise ValueError("Improper scale for timestamp")
+# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
+# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
def _parse_date_part(self):
    """
    Parse the arguments of ``DATE_PART(part, expr)``.

    Parts whose upper-cased name starts with ``EPOCH`` become a ``TimeToUnix`` of ``expr`` cast to
    ``TIMESTAMP``, multiplied by 10**3, 10**6 or 10**9 for the millisecond, microsecond and
    nanosecond variants. Every other part becomes an ``Extract``.
    """
    this = self._parse_var() or self._parse_type()
    self._match(TokenType.COMMA)
    expression = self._parse_bitwise()

    name = this.name.upper()
    if not name.startswith("EPOCH"):
        return self.expression(exp.Extract, this=this, expression=expression)

    # Pick the multiplier for sub-second epoch parts; plain epoch seconds need none.
    scale = None
    for prefix, factor in (
        ("EPOCH_MILLISECOND", 10**3),
        ("EPOCH_MICROSECOND", 10**6),
        ("EPOCH_NANOSECOND", 10**9),
    ):
        if name.startswith(prefix):
            scale = factor
            break

    ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
    to_unix = self.expression(exp.TimeToUnix, this=ts)
    if scale:
        return exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
    return to_unix
+
+
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@@ -115,7 +144,7 @@ class Snowflake(Dialect):
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
- "DATE_PART": lambda self: self._parse_extract(),
+ "DATE_PART": _parse_date_part,
}
FUNC_TOKENS = {
@@ -161,9 +190,11 @@ class Snowflake(Dialect):
class Generator(Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
+ exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
+ exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 95a7ab4..6bf4ff0 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -1,9 +1,5 @@
from sqlglot import exp
-from sqlglot.dialects.dialect import (
- create_with_partitions_sql,
- no_ilike_sql,
- rename_func,
-)
+from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
from sqlglot.helper import list_get
from sqlglot.parser import Parser
@@ -98,13 +94,14 @@ class Spark(Hive):
}
TRANSFORMS = {
- **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
+ **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
+ exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+ exp.DateTrunc: rename_func("TRUNC"),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
- exp.ILike: no_ilike_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
@@ -112,6 +109,8 @@ class Spark(Hive):
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
+ exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
+ exp.VariancePop: rename_func("VAR_POP"),
}
WRAP_DERIVED_VALUES = False
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 73b232e..1f2e50d 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -32,6 +32,11 @@ class TSQL(Dialect):
}
class Parser(Parser):
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "CHARINDEX": exp.StrPosition.from_arg_list,
+ }
+
def _parse_convert(self):
to = self._parse_types()
self._match(TokenType.COMMA)
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 72b0558..9c49dd1 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -19,6 +19,7 @@ ENV = {
"datetime": datetime,
"locals": locals,
"re": re,
+ "bool": bool,
"float": float,
"int": int,
"str": str,
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 8ef6cf0..fcb016b 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -80,9 +80,10 @@ class PythonExecutor:
source = step.source
if isinstance(source, exp.Expression):
- source = source.this.name or source.alias
+ source = source.name or source.alias
else:
source = step.name
+
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
@@ -121,7 +122,7 @@ class PythonExecutor:
source = step.source
alias = source.alias
- with csv_reader(source.this) as reader:
+ with csv_reader(source) as reader:
columns = next(reader)
table = Table(columns)
context = self.context({alias: table})
@@ -308,7 +309,7 @@ def _interval_py(self, expression):
def _like_py(self, expression):
this = self.sql(expression, "this")
expression = self.sql(expression, "expression")
- return f"""re.match({expression}.replace("_", ".").replace("%", ".*"), {this})"""
+ return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))"""
def _ordered_py(self, expression):
@@ -330,6 +331,7 @@ class Python(Dialect):
exp.Cast: _cast_py,
exp.Column: _column_py,
exp.EQ: lambda self, e: self.binary(e, "=="),
+ exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
exp.Interval: _interval_py,
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Like: _like_py,
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 39f4452..f7717c8 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -11,6 +11,7 @@ from sqlglot.helper import (
camel_to_snake_case,
ensure_list,
list_get,
+ split_num_words,
subclasses,
)
@@ -108,6 +109,8 @@ class Expression(metaclass=_Expression):
@property
def alias_or_name(self):
+ if isinstance(self, Null):
+ return "NULL"
return self.alias or self.name
def __deepcopy__(self, memo):
@@ -659,6 +662,10 @@ class HexString(Condition):
pass
+class ByteString(Condition):
+ pass
+
+
class Column(Condition):
arg_types = {"this": True, "table": False}
@@ -725,7 +732,7 @@ class Constraint(Expression):
class Delete(Expression):
- arg_types = {"with": False, "this": True, "where": False}
+ arg_types = {"with": False, "this": True, "using": False, "where": False}
class Drop(Expression):
@@ -1192,6 +1199,7 @@ QUERY_MODIFIERS = {
class Table(Expression):
arg_types = {
"this": True,
+ "alias": False,
"db": False,
"catalog": False,
"laterals": False,
@@ -1323,6 +1331,7 @@ class Select(Subqueryable):
*expressions (str or Expression): the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Group`.
+ If nothing is passed in then a group by is not applied to the expression
append (bool): if `True`, add to any existing expressions.
Otherwise, this flattens all the `Group` expression into a single expression.
dialect (str): the dialect used to parse the input expression.
@@ -1332,6 +1341,8 @@ class Select(Subqueryable):
Returns:
Select: the modified expression.
"""
+ if not expressions:
+ return self if not copy else self.copy()
return _apply_child_list_builder(
*expressions,
instance=self,
@@ -2239,6 +2250,11 @@ class ArrayAny(Func):
arg_types = {"this": True, "expression": True}
+class ArrayConcat(Func):
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
+
+
class ArrayContains(Func):
arg_types = {"this": True, "expression": True}
@@ -2570,7 +2586,7 @@ class SortArray(Func):
class Split(Func):
- arg_types = {"this": True, "expression": True}
+ arg_types = {"this": True, "expression": True, "limit": False}
# Start may be omitted in the case of postgres
@@ -3209,29 +3225,49 @@ def to_identifier(alias, quoted=None):
return identifier
-def to_table(sql_path, **kwargs):
+def to_table(sql_path: str, **kwargs) -> Table:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
- Example:
- >>> to_table('catalog.db.table_name').sql()
- 'catalog.db.table_name'
+
+ If a table is passed in then that table is returned.
Args:
- sql_path(str): `[catalog].[schema].[table]` string
+ sql_path(str|Table): `[catalog].[schema].[table]` string
Returns:
Table: A table expression
"""
- table_parts = sql_path.split(".")
- catalog, db, table_name = [
- to_identifier(x) if x is not None else x for x in [None] * (3 - len(table_parts)) + table_parts
- ]
+ if sql_path is None or isinstance(sql_path, Table):
+ return sql_path
+ if not isinstance(sql_path, str):
+ raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
+
+ catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)]
return Table(this=table_name, db=db, catalog=catalog, **kwargs)
+def to_column(sql_path: str, **kwargs) -> Column:
+ """
+ Create a column from a `[table].[column]` sql path. Table is optional.
+
+ If a column is passed in then that column is returned.
+
+ Args:
+ sql_path: `[table].[column]` string
+ Returns:
+ Column: A column expression
+ """
+ if sql_path is None or isinstance(sql_path, Column):
+ return sql_path
+ if not isinstance(sql_path, str):
+ raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
+ table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)]
+ return Column(this=column_name, table=table_name, **kwargs)
+
+
def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
"""
Create an Alias expression.
- Expample:
+ Example:
>>> alias_('foo', 'bar').sql()
'foo AS bar'
@@ -3249,7 +3285,16 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
"""
exp = maybe_parse(expression, dialect=dialect, **opts)
alias = to_identifier(alias, quoted=quoted)
- alias = TableAlias(this=alias) if table else alias
+
+ if table:
+ expression.set("alias", TableAlias(this=alias))
+ return expression
+
+ # We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in
+ # the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node
+ # for the complete Window expression.
+ #
+ # [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls
if "alias" in exp.arg_types and not isinstance(exp, Window):
exp = exp.copy()
@@ -3295,7 +3340,7 @@ def column(col, table=None, quoted=None):
)
-def table_(table, db=None, catalog=None, quoted=None):
+def table_(table, db=None, catalog=None, quoted=None, alias=None):
"""Build a Table.
Args:
@@ -3310,6 +3355,7 @@ def table_(table, db=None, catalog=None, quoted=None):
this=to_identifier(table, quoted=quoted),
db=to_identifier(db, quoted=quoted),
catalog=to_identifier(catalog, quoted=quoted),
+ alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
@@ -3453,7 +3499,7 @@ def replace_tables(expression, mapping):
Examples:
>>> from sqlglot import exp, parse_one
>>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
- 'SELECT * FROM "c"'
+ 'SELECT * FROM c'
Returns:
The mapped expression
@@ -3463,7 +3509,10 @@ def replace_tables(expression, mapping):
if isinstance(node, Table):
new_name = mapping.get(table_name(node))
if new_name:
- return table_(*reversed(new_name.split(".")), quoted=True)
+ return to_table(
+ new_name,
+ **{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")},
+ )
return node
return expression.transform(_replace_tables)
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index bb7fd71..6decd16 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -47,6 +47,8 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
+ annotations: Whether or not to show annotations in the SQL.
+ Default: True
"""
TRANSFORMS = {
@@ -116,6 +118,7 @@ class Generator:
"_escaped_quote_end",
"_leading_comma",
"_max_text_width",
+ "_annotations",
)
def __init__(
@@ -141,6 +144,7 @@ class Generator:
max_unsupported=3,
leading_comma=False,
max_text_width=80,
+ annotations=True,
):
import sqlglot
@@ -169,6 +173,7 @@ class Generator:
self._escaped_quote_end = self.escape + self.quote_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
+ self._annotations = annotations
def generate(self, expression):
"""
@@ -275,7 +280,9 @@ class Generator:
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
def annotation_sql(self, expression):
- return f"{self.sql(expression, 'expression')} # {expression.name.strip()}"
+ if self._annotations:
+ return f"{self.sql(expression, 'expression')} # {expression.name}"
+ return self.sql(expression, "expression")
def uncache_sql(self, expression):
table = self.sql(expression, "this")
@@ -423,8 +430,11 @@ class Generator:
def delete_sql(self, expression):
this = self.sql(expression, "this")
+ using_sql = (
+ f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
+ )
where_sql = self.sql(expression, "where")
- sql = f"DELETE FROM {this}{where_sql}"
+ sql = f"DELETE FROM {this}{using_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
def drop_sql(self, expression):
@@ -571,7 +581,7 @@ class Generator:
null = f" NULL DEFINED AS {null}" if null else ""
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
- def table_sql(self, expression):
+ def table_sql(self, expression, sep=" AS "):
table = ".".join(
part
for part in [
@@ -582,13 +592,20 @@ class Generator:
if part
)
+ alias = self.sql(expression, "alias")
+ alias = f"{sep}{alias}" if alias else ""
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
- return f"{table}{laterals}{joins}{pivots}"
+
+ if alias and pivots:
+ pivots = f"{pivots}{alias}"
+ alias = ""
+
+ return f"{table}{alias}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression):
- if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
+ if self.alias_post_tablesample and expression.this.alias:
this = self.sql(expression.this, "this")
alias = f" AS {self.sql(expression.this, 'alias')}"
else:
@@ -1188,7 +1205,7 @@ class Generator:
if isinstance(arg_value, list):
for value in arg_value:
args.append(value)
- elif arg_value:
+ else:
args.append(arg_value)
return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index c4dd91e..c3a23d3 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -2,7 +2,9 @@ import inspect
import logging
import re
import sys
+import typing as t
from contextlib import contextmanager
+from copy import copy
from enum import Enum
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
@@ -162,3 +164,54 @@ def find_new_name(taken, base):
i += 1
new = f"{base}_{i}"
return new
+
+
+def object_to_dict(obj, **kwargs):
+ return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
+
+
+def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]:
+ """
+ Perform a split on a value and return at least `min_num_words` words, with None used for words that don't exist.
+
+ Args:
+ value: The value to be split
+ sep: The value to use to split on
+ min_num_words: The minimum number of words that are going to be in the result
+ fill_from_start: Whether None values should be inserted at the start (True) or at the end (False) of the list
+
+ Examples:
+ >>> split_num_words("db.table", ".", 3)
+ [None, 'db', 'table']
+ >>> split_num_words("db.table", ".", 3, fill_from_start=False)
+ ['db', 'table', None]
+ >>> split_num_words("db.table", ".", 1)
+ ['db', 'table']
+ """
+ words = value.split(sep)
+ if fill_from_start:
+ return [None] * (min_num_words - len(words)) + words
+ return words + [None] * (min_num_words - len(words))
+
+
+def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
+ """
+ Flattens a list that can contain both iterables and non-iterable elements
+
+ Examples:
+ >>> list(flatten([[1, 2], 3]))
+ [1, 2, 3]
+ >>> list(flatten([1, 2, 3]))
+ [1, 2, 3]
+
+ Args:
+ values: The value to be flattened
+
+ Returns:
+ Yields non-iterable elements (str and bytes are treated as non-iterable elements)
+ """
+ for value in values:
+ if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
+ yield from flatten(value)
+ else:
+ yield value
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py
index d1146ca..bba0878 100644
--- a/sqlglot/optimizer/__init__.py
+++ b/sqlglot/optimizer/__init__.py
@@ -1,2 +1 @@
from sqlglot.optimizer.optimizer import RULES, optimize
-from sqlglot.optimizer.schema import Schema
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index a2cef37..30055bc 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,7 +1,7 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
-from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import ensure_schema
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 44cdc94..e30c263 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -86,7 +86,7 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_union:
return _eliminate_union(scope, existing_ctes, taken)
- if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
+ if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
return _eliminate_derived_table(scope, existing_ctes, taken)
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
index e060739..652cdef 100644
--- a/sqlglot/optimizer/isolate_table_selects.py
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -12,18 +12,16 @@ def isolate_table_selects(expression):
if not isinstance(source, exp.Table):
continue
- if not isinstance(source.parent, exp.Alias):
+ if not source.alias:
raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")
- parent = source.parent
-
- parent.replace(
+ source.replace(
exp.select("*")
.from_(
- alias(source, source.name or parent.alias, table=True),
+ alias(source.copy(), source.name or source.alias, table=True),
copy=False,
)
- .subquery(parent.alias, copy=False)
+ .subquery(source.alias, copy=False)
)
return expression
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 3c51c18..70e4629 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -70,15 +70,10 @@ def merge_ctes(expression, leave_tables_isolated=False):
inner_select = inner_scope.expression.unnest()
from_or_join = table.find_ancestor(exp.From, exp.Join)
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
- node_to_replace = table
- if isinstance(node_to_replace.parent, exp.Alias):
- node_to_replace = node_to_replace.parent
- alias = node_to_replace.alias
- else:
- alias = table.name
+ alias = table.alias_or_name
_rename_inner_sources(outer_scope, inner_scope, alias)
- _merge_from(outer_scope, inner_scope, node_to_replace, alias)
+ _merge_from(outer_scope, inner_scope, table, alias)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
@@ -179,8 +174,8 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
if isinstance(source, exp.Subquery):
source.set("alias", exp.TableAlias(this=new_alias))
- elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
- source.parent.set("alias", new_alias)
+ elif isinstance(source, exp.Table) and source.alias:
+ source.set("alias", new_alias)
elif isinstance(source, exp.Table):
source.replace(exp.alias_(source.copy(), new_alias))
@@ -206,8 +201,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
tables = join_hint.find_all(exp.Table)
for table in tables:
if table.alias_or_name == node_to_replace.alias_or_name:
- new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
- table.set("this", exp.to_identifier(new_table.alias_or_name))
+ table.set("this", exp.to_identifier(new_subquery.alias_or_name))
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 2c28ab8..5ad8f46 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -1,3 +1,4 @@
+import sqlglot
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
@@ -43,6 +44,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
+ If no schema is provided then the default schema defined at `sqlglot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
rules (list): sequence of optimizer rules to use
@@ -50,13 +52,12 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
Returns:
sqlglot.Expression: optimized expression
"""
- possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
+ possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
expression = expression.copy()
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
-
expression = rule(expression, **rule_kwargs)
return expression
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 5584830..5820851 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -6,6 +6,9 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
+# Selection to use if the selection list is empty
+DEFAULT_SELECTION = alias("1", "_")
+
def pushdown_projections(expression):
"""
@@ -25,7 +28,8 @@ def pushdown_projections(expression):
"""
# Map of Scope to all columns being selected by outer queries.
referenced_columns = defaultdict(set)
-
+ left_union = None
+ right_union = None
# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This should ensure that the set of selected
# columns for a particular scope are completely built by the time we get to it.
@@ -37,12 +41,16 @@ def pushdown_projections(expression):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
- left, right = scope.union_scopes
- referenced_columns[left] = parent_selections
- referenced_columns[right] = parent_selections
+ left_union, right_union = scope.union_scopes
+ referenced_columns[left_union] = parent_selections
+ referenced_columns[right_union] = parent_selections
- if isinstance(scope.expression, exp.Select):
- _remove_unused_selections(scope, parent_selections)
+ if isinstance(scope.expression, exp.Select) and scope != right_union:
+ removed_indexes = _remove_unused_selections(scope, parent_selections)
+ # The left union is used for column names to select and if we remove columns from the left
+ # we need to also remove those same columns in the right that were at the same position
+ if scope is left_union:
+ _remove_indexed_selections(right_union, removed_indexes)
# Group columns by source name
selects = defaultdict(set)
@@ -61,6 +69,7 @@ def pushdown_projections(expression):
def _remove_unused_selections(scope, parent_selections):
+ removed_indexes = []
order = scope.expression.args.get("order")
if order:
@@ -70,16 +79,26 @@ def _remove_unused_selections(scope, parent_selections):
order_refs = set()
new_selections = []
- for selection in scope.selects:
+ for i, selection in enumerate(scope.selects):
if (
SELECT_ALL in parent_selections
or selection.alias_or_name in parent_selections
or selection.alias_or_name in order_refs
):
new_selections.append(selection)
+ else:
+ removed_indexes.append(i)
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(alias("1", "_"))
+ new_selections.append(DEFAULT_SELECTION)
+
+ scope.expression.set("expressions", new_selections)
+ return removed_indexes
+
+def _remove_indexed_selections(scope, indexes_to_remove):
+ new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
+ if not new_selections:
+ new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 7d77ef1..36ba028 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -2,8 +2,8 @@ import itertools
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
-from sqlglot.optimizer.schema import ensure_schema
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import ensure_schema
def qualify_columns(expression, schema):
@@ -48,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
- if isinstance(derived_table, exp.UDTF):
+ if isinstance(derived_table.unnest(), exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
@@ -211,6 +211,22 @@ def _qualify_columns(scope, resolver):
if column_table:
column.set("table", exp.to_identifier(column_table))
+ # Determine whether each reference in the order by clause is to a column or an alias.
+ for ordered in scope.find_all(exp.Ordered):
+ for column in ordered.find_all(exp.Column):
+ column_table = column.table
+ column_name = column.name
+
+ if column_table or column.parent is ordered or column_name not in resolver.all_columns:
+ continue
+
+ column_table = resolver.get_table(column_name)
+
+ if column_table is None:
+ raise OptimizeError(f"Ambiguous column: {column_name}")
+
+ column.set("table", exp.to_identifier(column_table))
+
def _expand_stars(scope, resolver):
"""Expand stars to lists of column selections"""
@@ -346,6 +362,11 @@ class _Resolver:
except Exception as e:
raise OptimizeError(str(e)) from e
+ if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
+ values_alias = source.expression.parent
+ if hasattr(values_alias, "alias_column_names"):
+ return values_alias.alias_column_names
+
# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 30e93ba..0e467d3 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -40,7 +40,7 @@ def qualify_tables(expression, db=None, catalog=None):
if not source.args.get("catalog"):
source.set("catalog", exp.to_identifier(catalog))
- if not isinstance(source.parent, exp.Alias):
+ if not source.alias:
source.replace(
alias(
source.copy(),
diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py
deleted file mode 100644
index d7743c9..0000000
--- a/sqlglot/optimizer/schema.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import abc
-
-from sqlglot import exp
-from sqlglot.errors import OptimizeError
-from sqlglot.helper import csv_reader
-
-
-class Schema(abc.ABC):
- """Abstract base class for database schemas"""
-
- @abc.abstractmethod
- def column_names(self, table, only_visible=False):
- """
- Get the column names for a table.
- Args:
- table (sqlglot.expressions.Table): Table expression instance
- only_visible (bool): Whether to include invisible columns
- Returns:
- list[str]: list of column names
- """
-
- @abc.abstractmethod
- def get_column_type(self, table, column):
- """
- Get the exp.DataType type of a column in the schema.
-
- Args:
- table (sqlglot.expressions.Table): The source table.
- column (sqlglot.expressions.Column): The target column.
- Returns:
- sqlglot.expressions.DataType.Type: The resulting column type.
- """
-
-
-class MappingSchema(Schema):
- """
- Schema based on a nested mapping.
-
- Args:
- schema (dict): Mapping in one of the following forms:
- 1. {table: {col: type}}
- 2. {db: {table: {col: type}}}
- 3. {catalog: {db: {table: {col: type}}}}
- visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
- are assumed to be visible. The nesting should mirror that of the schema:
- 1. {table: set(*cols)}}
- 2. {db: {table: set(*cols)}}}
- 3. {catalog: {db: {table: set(*cols)}}}}
- dialect (str): The dialect to be used for custom type mappings.
- """
-
- def __init__(self, schema, visible=None, dialect=None):
- self.schema = schema
- self.visible = visible
- self.dialect = dialect
- self._type_mapping_cache = {}
-
- depth = _dict_depth(schema)
-
- if not depth: # {}
- self.supported_table_args = []
- elif depth == 2: # {table: {col: type}}
- self.supported_table_args = ("this",)
- elif depth == 3: # {db: {table: {col: type}}}
- self.supported_table_args = ("db", "this")
- elif depth == 4: # {catalog: {db: {table: {col: type}}}}
- self.supported_table_args = ("catalog", "db", "this")
- else:
- raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
-
- self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
-
- def column_names(self, table, only_visible=False):
- if not isinstance(table.this, exp.Identifier):
- return fs_get(table)
-
- args = tuple(table.text(p) for p in self.supported_table_args)
-
- for forbidden in self.forbidden_args:
- if table.text(forbidden):
- raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-
- columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
- if not only_visible or not self.visible:
- return columns
-
- visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
- return [col for col in columns if col in visible]
-
- def get_column_type(self, table, column):
- try:
- schema_type = self.schema.get(table.name, {}).get(column.name).upper()
- return self._convert_type(schema_type)
- except:
- raise OptimizeError(f"Failed to get type for column {column.sql()}")
-
- def _convert_type(self, schema_type):
- """
- Convert a type represented as a string to the corresponding exp.DataType.Type object.
-
- Args:
- schema_type (str): The type we want to convert.
- Returns:
- sqlglot.expressions.DataType.Type: The resulting expression type.
- """
- if schema_type not in self._type_mapping_cache:
- try:
- self._type_mapping_cache[schema_type] = exp.maybe_parse(
- schema_type, into=exp.DataType, dialect=self.dialect
- ).this
- except AttributeError:
- raise OptimizeError(f"Failed to convert type {schema_type}")
-
- return self._type_mapping_cache[schema_type]
-
-
-def ensure_schema(schema):
- if isinstance(schema, Schema):
- return schema
-
- return MappingSchema(schema)
-
-
-def fs_get(table):
- name = table.this.name
-
- if name.upper() == "READ_CSV":
- with csv_reader(table) as reader:
- return next(reader)
-
- raise ValueError(f"Cannot read schema for {table}")
-
-
-def _nested_get(d, *path):
- """
- Get a value for a nested dictionary.
-
- Args:
- d (dict): dictionary
- *path (tuple[str, str]): tuples of (name, key)
- `key` is the key in the dictionary to get.
- `name` is a string to use in the error if `key` isn't found.
- """
- for name, key in path:
- d = d.get(key)
- if d is None:
- name = "table" if name == "this" else name
- raise ValueError(f"Unknown {name}")
- return d
-
-
-def _dict_depth(d):
- """
- Get the nesting depth of a dictionary.
-
- For example:
- >>> _dict_depth(None)
- 0
- >>> _dict_depth({})
- 1
- >>> _dict_depth({"a": "b"})
- 1
- >>> _dict_depth({"a": {}})
- 2
- >>> _dict_depth({"a": {"b": {}}})
- 3
-
- Args:
- d (dict): dictionary
- Returns:
- int: depth
- """
- try:
- return 1 + _dict_depth(next(iter(d.values())))
- except AttributeError:
- # d doesn't have attribute "values"
- return 0
- except StopIteration:
- # d.values() returns an empty sequence
- return 1
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 68298a0..b7eb6c2 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -257,12 +257,7 @@ class Scope:
referenced_names = []
for table in self.tables:
- referenced_names.append(
- (
- table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
- table,
- )
- )
+ referenced_names.append((table.alias_or_name, table))
for derived_table in self.derived_tables:
referenced_names.append((derived_table.alias, derived_table.unnest()))
@@ -538,8 +533,8 @@ def _add_table_sources(scope):
for table in scope.tables:
table_name = table.name
- if isinstance(table.parent, exp.Alias):
- source_name = table.parent.alias
+ if table.alias:
+ source_name = table.alias
else:
source_name = table_name
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index b378f12..47c1c1d 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -329,6 +329,7 @@ class Parser:
exp.DataType: lambda self: self._parse_types(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
+ exp.Identifier: lambda self: self._parse_id_var(),
exp.Lateral: lambda self: self._parse_lateral(),
exp.Join: lambda self: self._parse_join(),
exp.Order: lambda self: self._parse_order(),
@@ -371,11 +372,8 @@ class Parser:
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
- TokenType.INTRODUCER: lambda self, token: self.expression(
- exp.Introducer,
- this=token.text,
- expression=self._parse_var_or_string(),
- ),
+ TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text),
+ TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
}
RANGE_PARSERS = {
@@ -500,7 +498,7 @@ class Parser:
max_errors=3,
null_ordering=None,
):
- self.error_level = error_level or ErrorLevel.RAISE
+ self.error_level = error_level or ErrorLevel.IMMEDIATE
self.error_message_context = error_message_context
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
@@ -928,6 +926,7 @@ class Parser:
return self.expression(
exp.Delete,
this=self._parse_table(schema=True),
+ using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)),
where=self._parse_where(),
)
@@ -1148,7 +1147,7 @@ class Parser:
def _parse_annotation(self, expression):
if self._match(TokenType.ANNOTATION):
- return self.expression(exp.Annotation, this=self._prev.text, expression=expression)
+ return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression)
return expression
@@ -1277,7 +1276,7 @@ class Parser:
alias = self._parse_table_alias()
if alias:
- this = self.expression(exp.Alias, this=this, alias=alias)
+ this.set("alias", alias)
if not self.alias_post_tablesample:
table_sample = self._parse_table_sample()
@@ -1876,6 +1875,17 @@ class Parser:
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+ def _parse_introducer(self, token):
+ literal = self._parse_primary()
+ if literal:
+ return self.expression(
+ exp.Introducer,
+ this=token.text,
+ expression=literal,
+ )
+
+ return self.expression(exp.Identifier, this=token.text)
+
def _parse_udf_kwarg(self):
this = self._parse_id_var()
kind = self._parse_types()
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index efabc15..ea995d8 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -199,13 +199,14 @@ class Step:
class Scan(Step):
@classmethod
def from_expression(cls, expression, ctes=None):
- table = expression.this
+ table = expression
alias_ = expression.alias
if not alias_:
raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")
if isinstance(expression, exp.Subquery):
+ table = expression.this
step = Step.from_expression(table, ctes)
step.name = alias_
return step
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
new file mode 100644
index 0000000..c916330
--- /dev/null
+++ b/sqlglot/schema.py
@@ -0,0 +1,297 @@
+import abc
+
+from sqlglot import expressions as exp
+from sqlglot.errors import OptimizeError
+from sqlglot.helper import csv_reader
+
+
+class Schema(abc.ABC):
+ """Abstract base class for database schemas"""
+
+ @abc.abstractmethod
+ def add_table(self, table, column_mapping=None):
+ """
+ Register or update a table. Some implementing classes may require column information to also be provided
+
+ Args:
+ table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
+ column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
+ """
+
+ @abc.abstractmethod
+ def column_names(self, table, only_visible=False):
+ """
+ Get the column names for a table.
+ Args:
+ table (sqlglot.expressions.Table): Table expression instance
+ only_visible (bool): Whether to include invisible columns
+ Returns:
+ list[str]: list of column names
+ """
+
+ @abc.abstractmethod
+ def get_column_type(self, table, column):
+ """
+ Get the exp.DataType type of a column in the schema.
+
+ Args:
+ table (sqlglot.expressions.Table): The source table.
+ column (sqlglot.expressions.Column): The target column.
+ Returns:
+ sqlglot.expressions.DataType.Type: The resulting column type.
+ """
+
+
+class MappingSchema(Schema):
+ """
+ Schema based on a nested mapping.
+
+ Args:
+ schema (dict): Mapping in one of the following forms:
+ 1. {table: {col: type}}
+ 2. {db: {table: {col: type}}}
+ 3. {catalog: {db: {table: {col: type}}}}
+ 4. None - Tables will be added later
+ visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
+ are assumed to be visible. The nesting should mirror that of the schema:
+ 1. {table: set(*cols)}}
+ 2. {db: {table: set(*cols)}}}
+ 3. {catalog: {db: {table: set(*cols)}}}}
+ dialect (str): The dialect to be used for custom type mappings.
+ """
+
+ def __init__(self, schema=None, visible=None, dialect=None):
+ self.schema = schema or {}
+ self.visible = visible
+ self.dialect = dialect
+ self._type_mapping_cache = {}
+ self.supported_table_args = []
+ self.forbidden_table_args = set()
+ if self.schema:
+ self._initialize_supported_args()
+
+ @classmethod
+ def from_mapping_schema(cls, mapping_schema):
+ return MappingSchema(
+ schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect
+ )
+
+ def copy(self, **kwargs):
+ return MappingSchema(**{"schema": self.schema.copy(), **kwargs})
+
+ def add_table(self, table, column_mapping=None):
+ """
+ Register or update a table. Updates are only performed if a new column mapping is provided.
+
+ Args:
+ table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
+ column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
+ """
+ table = exp.to_table(table)
+ self._validate_table(table)
+ column_mapping = ensure_column_mapping(column_mapping)
+ table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)]
+ existing_column_mapping = _nested_get(
+ self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False
+ )
+ if existing_column_mapping and not column_mapping:
+ return
+ _nested_set(
+ self.schema,
+ [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)],
+ column_mapping,
+ )
+ self._initialize_supported_args()
+
+ def _get_table_args_from_table(self, table):
+ if table.args.get("catalog") is not None:
+ return "catalog", "db", "this"
+ if table.args.get("db") is not None:
+ return "db", "this"
+ return ("this",)
+
+ def _validate_table(self, table):
+ if not self.supported_table_args and isinstance(table, exp.Table):
+ return
+ for forbidden in self.forbidden_table_args:
+ if table.text(forbidden):
+ raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
+ for expected in self.supported_table_args:
+ if not table.text(expected):
+ raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ")
+
+ def column_names(self, table, only_visible=False):
+ table = exp.to_table(table)
+ if not isinstance(table.this, exp.Identifier):
+ return fs_get(table)
+
+ args = tuple(table.text(p) for p in self.supported_table_args)
+
+ for forbidden in self.forbidden_table_args:
+ if table.text(forbidden):
+ raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
+
+ columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+ if not only_visible or not self.visible:
+ return columns
+
+ visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
+ return [col for col in columns if col in visible]
+
+ def get_column_type(self, table, column):
+ try:
+ schema_type = self.schema.get(table.name, {}).get(column.name).upper()
+ return self._convert_type(schema_type)
+ except:
+ raise OptimizeError(f"Failed to get type for column {column.sql()}")
+
+ def _convert_type(self, schema_type):
+ """
+ Convert a type represented as a string to the corresponding exp.DataType.Type object.
+ Args:
+ schema_type (str): The type we want to convert.
+ Returns:
+ sqlglot.expressions.DataType.Type: The resulting expression type.
+ """
+ if schema_type not in self._type_mapping_cache:
+ try:
+ self._type_mapping_cache[schema_type] = exp.maybe_parse(
+ schema_type, into=exp.DataType, dialect=self.dialect
+ ).this
+ except AttributeError:
+ raise OptimizeError(f"Failed to convert type {schema_type}")
+
+ return self._type_mapping_cache[schema_type]
+
+ def _initialize_supported_args(self):
+ if not self.supported_table_args:
+ depth = _dict_depth(self.schema)
+
+ all_args = ["this", "db", "catalog"]
+ if not depth or depth == 1: # {}
+ self.supported_table_args = []
+ elif 2 <= depth <= 4:
+ self.supported_table_args = tuple(reversed(all_args[: depth - 1]))
+ else:
+ raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
+
+ self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args)
+
+
+def ensure_schema(schema):
+ if isinstance(schema, Schema):
+ return schema
+
+ return MappingSchema(schema)
+
+
+def ensure_column_mapping(mapping):
+ if isinstance(mapping, dict):
+ return mapping
+ elif isinstance(mapping, str):
+ col_name_type_strs = [x.strip() for x in mapping.split(",")]
+ return {
+ name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
+ for name_type_str in col_name_type_strs
+ }
+ # Check if mapping looks like a DataFrame StructType
+ elif hasattr(mapping, "simpleString"):
+ return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
+ elif isinstance(mapping, list):
+ return {x.strip(): None for x in mapping}
+ elif mapping is None:
+ return {}
+ raise ValueError(f"Invalid mapping provided: {type(mapping)}")
+
+
+def fs_get(table):
+ name = table.this.name
+
+ if name.upper() == "READ_CSV":
+ with csv_reader(table) as reader:
+ return next(reader)
+
+ raise ValueError(f"Cannot read schema for {table}")
+
+
+def _nested_get(d, *path, raise_on_missing=True):
+ """
+ Get a value for a nested dictionary.
+
+ Args:
+ d (dict): dictionary
+ *path (tuple[str, str]): tuples of (name, key)
+ `key` is the key in the dictionary to get.
+ `name` is a string to use in the error if `key` isn't found.
+
+ Returns:
+ The value or None if it doesn't exist
+ """
+ for name, key in path:
+ d = d.get(key)
+ if d is None:
+ if raise_on_missing:
+ name = "table" if name == "this" else name
+ raise ValueError(f"Unknown {name}")
+ return None
+ return d
+
+
+def _nested_set(d, keys, value):
+ """
+ In-place set a value for a nested dictionary
+
+ Ex:
+ >>> _nested_set({}, ["top_key", "second_key"], "value")
+ {'top_key': {'second_key': 'value'}}
+ >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
+ {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
+
+ d (dict): dictionary
+ keys (Iterable[str]): ordered iterable of keys that makeup path to value
+ value (Any): The value to set in the dictionary for the given key path
+ """
+ if not keys:
+ return
+ if len(keys) == 1:
+ d[keys[0]] = value
+ return
+ subd = d
+ for key in keys[:-1]:
+ if key not in subd:
+ subd = subd.setdefault(key, {})
+ else:
+ subd = subd[key]
+ subd[keys[-1]] = value
+ return d
+
+
+def _dict_depth(d):
+ """
+ Get the nesting depth of a dictionary.
+
+ For example:
+ >>> _dict_depth(None)
+ 0
+ >>> _dict_depth({})
+ 1
+ >>> _dict_depth({"a": "b"})
+ 1
+ >>> _dict_depth({"a": {}})
+ 2
+ >>> _dict_depth({"a": {"b": {}}})
+ 3
+
+ Args:
+ d (dict): dictionary
+ Returns:
+ int: depth
+ """
+ try:
+ return 1 + _dict_depth(next(iter(d.values())))
+ except AttributeError:
+ # d doesn't have attribute "values"
+ return 0
+ except StopIteration:
+ # d.values() returns an empty sequence
+ return 1
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index fc8e6e7..1a9d72e 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -56,6 +56,7 @@ class TokenType(AutoName):
VAR = auto()
BIT_STRING = auto()
HEX_STRING = auto()
+ BYTE_STRING = auto()
# types
BOOLEAN = auto()
@@ -320,6 +321,7 @@ class _Tokenizer(type):
klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
+ klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
klass._COMMENTS = dict(
(comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
@@ -333,6 +335,7 @@ class _Tokenizer(type):
**{quote: TokenType.QUOTE for quote in klass._QUOTES},
**{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
**{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
+ **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
}.items()
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
@@ -385,6 +388,8 @@ class Tokenizer(metaclass=_Tokenizer):
HEX_STRINGS = []
+ BYTE_STRINGS = []
+
IDENTIFIERS = ['"']
ESCAPE = "'"
@@ -799,7 +804,7 @@ class Tokenizer(metaclass=_Tokenizer):
if self._scan_string(word):
return
- if self._scan_numeric_string(word):
+ if self._scan_formatted_string(word):
return
if self._scan_comment(word):
return
@@ -906,7 +911,8 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(TokenType.STRING, text)
return True
- def _scan_numeric_string(self, string_start):
+ # X'1234, b'0110', E'\\\\\' etc.
+ def _scan_formatted_string(self, string_start):
if string_start in self._HEX_STRINGS:
delimiters = self._HEX_STRINGS
token_type = TokenType.HEX_STRING
@@ -915,6 +921,10 @@ class Tokenizer(metaclass=_Tokenizer):
delimiters = self._BIT_STRINGS
token_type = TokenType.BIT_STRING
base = 2
+ elif string_start in self._BYTE_STRINGS:
+ delimiters = self._BYTE_STRINGS
+ token_type = TokenType.BYTE_STRING
+ base = None
else:
return False
@@ -922,10 +932,14 @@ class Tokenizer(metaclass=_Tokenizer):
string_end = delimiters.get(string_start)
text = self._extract_string(string_end)
- try:
- self._add(token_type, f"{int(text, base)}")
- except ValueError:
- raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+ if base is None:
+ self._add(token_type, text)
+ else:
+ try:
+ self._add(token_type, f"{int(text, base)}")
+ except:
+ raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+
return True
def _scan_identifier(self, identifier_end):