diff options
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | sqlglot/__init__.py | 2 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 10 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 252 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 40 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/starrocks.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 197 | ||||
-rw-r--r-- | sqlglot/expressions.py | 50 | ||||
-rw-r--r-- | sqlglot/generator.py | 5 | ||||
-rw-r--r-- | sqlglot/helper.py | 20 | ||||
-rw-r--r-- | sqlglot/parser.py | 9 | ||||
-rw-r--r-- | sqlglot/time.py | 2 | ||||
-rw-r--r-- | tests/dataframe/unit/test_functions.py | 54 | ||||
-rw-r--r-- | tests/dialects/test_bigquery.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 223 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 1 | ||||
-rw-r--r-- | tests/test_build.py | 4 | ||||
-rw-r--r-- | tests/test_expressions.py | 66 | ||||
-rw-r--r-- | tests/test_time.py | 2 |
22 files changed, 743 insertions, 224 deletions
@@ -316,7 +316,7 @@ Dialect["custom"] ## Run Tests and Lint ``` -pip install -r requirements.txt +pip install -r dev-requirements.txt # set `SKIP_INTEGRATION=1` to skip integration tests ./run_checks.sh ``` diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 7841c11..a780f96 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -24,7 +24,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "9.0.1" +__version__ = "9.0.3" pretty = False diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index 2391080..e66aaa8 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -5,7 +5,7 @@ import typing as t import sqlglot from sqlglot import expressions as exp from sqlglot.dataframe.sql.types import DataType -from sqlglot.helper import flatten +from sqlglot.helper import flatten, is_iterable if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ColumnOrLiteral @@ -134,10 +134,14 @@ class Column: cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs ) -> Column: ensured_column = None if column is None else cls.ensure_col(column) + ensure_expression_values = { + k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression + for k, v in kwargs.items() + } new_expression = ( - callable_expression(**kwargs) + callable_expression(**ensure_expression_values) if ensured_column is None - else callable_expression(this=ensured_column.column_expression, **kwargs) + else callable_expression(this=ensured_column.column_expression, **ensure_expression_values) ) return Column(new_expression) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 4c6de30..bc002e5 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -1,7 +1,6 @@ from __future__ import annotations import typing as t -from inspect import signature from sqlglot import expressions as glotexp from sqlglot.dataframe.sql.column import Column @@ -24,17 +23,15 @@ def lit(value: t.Optional[t.Any] = None) -> Column: def greatest(*cols: ColumnOrName) -> Column: - columns = [Column.ensure_col(col) for col in cols] - return Column.invoke_expression_over_column( - columns[0], glotexp.Greatest, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None - ) + if len(cols) > 1: + return Column.invoke_expression_over_column(cols[0], glotexp.Greatest, expressions=cols[1:]) + return Column.invoke_expression_over_column(cols[0], glotexp.Greatest) def least(*cols: ColumnOrName) -> Column: - columns = [Column.ensure_col(col) for col in cols] - return Column.invoke_expression_over_column( - columns[0], glotexp.Least, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None - ) + if len(cols) > 1: + return Column.invoke_expression_over_column(cols[0], glotexp.Least, expressions=cols[1:]) + return Column.invoke_expression_over_column(cols[0], glotexp.Least) def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: @@ -194,7 +191,7 @@ def log2(col: ColumnOrName) -> Column: def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column: if arg2 is None: return Column.invoke_expression_over_column(arg1, glotexp.Ln) - return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=Column.ensure_col(arg2).expression) + return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=arg2) def rint(col: ColumnOrName) -> Column: @@ -310,7 +307,7 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float] def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_anonymous_function(col1, "POW", col2) + return Column.invoke_expression_over_column(col1, glotexp.Pow, power=col2) def row_number() -> Column: @@ -340,14 +337,13 @@ def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Col def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: if rsd is None: return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct) - return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=Column.ensure_col(rsd).expression) + return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=rsd) def coalesce(*cols: ColumnOrName) -> Column: - columns = [Column.ensure_col(col) for col in cols] - return Column.invoke_expression_over_column( - columns[0], glotexp.Coalesce, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None - ) + if len(cols) > 1: + return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce, expressions=cols[1:]) + return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce) def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: @@ -405,11 +401,13 @@ def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column: def percentile_approx( col: ColumnOrName, percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]], - accuracy: t.Optional[t.Union[ColumnOrLiteral]] = None, + accuracy: t.Optional[t.Union[ColumnOrLiteral, int]] = None, ) -> Column: if accuracy: - return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage, accuracy) - return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage) + return Column.invoke_expression_over_column( + col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy + ) + return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage)) def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: @@ -422,7 +420,7 @@ def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column: def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: if scale is not None: - return Column.invoke_expression_over_column(col, glotexp.Round, decimals=glotexp.convert(scale)) + return Column.invoke_expression_over_column(col, glotexp.Round, decimals=scale) return Column.invoke_expression_over_column(col, glotexp.Round) @@ -433,9 +431,7 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: def shiftleft(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_expression_over_column( - col, glotexp.BitwiseLeftShift, expression=Column.ensure_col(numBits).expression - ) + return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=numBits) def shiftLeft(col: ColumnOrName, numBits: int) -> Column: @@ -443,9 +439,7 @@ def shiftLeft(col: ColumnOrName, numBits: int) -> Column: def shiftright(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_expression_over_column( - col, glotexp.BitwiseRightShift, expression=Column.ensure_col(numBits).expression - ) + return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=numBits) def shiftRight(col: ColumnOrName, numBits: int) -> Column: @@ -466,8 +460,7 @@ def expr(str: str) -> Column: def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column: columns = ensure_list(col) + list(cols) - expressions = [Column.ensure_col(column).expression for column in columns] - return Column(glotexp.Struct(expressions=expressions)) + return Column.invoke_expression_over_column(None, glotexp.Struct, expressions=columns) def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column: @@ -515,7 +508,7 @@ def current_timestamp() -> Column: def date_format(col: ColumnOrName, format: str) -> Column: - return Column.invoke_anonymous_function(col, "DATE_FORMAT", lit(format)) + return Column.invoke_expression_over_column(col, glotexp.TimeToStr, format=lit(format)) def year(col: ColumnOrName) -> Column: @@ -563,15 +556,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=Column.ensure_col(days).expression) + return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days) def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=Column.ensure_col(days).expression) + return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days) def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=Column.ensure_col(start).expression) + return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start) def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column: @@ -586,8 +579,8 @@ def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optiona def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_anonymous_function(col, "TO_DATE", lit(format)) - return Column.invoke_anonymous_function(col, "TO_DATE") + return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate, format=lit(format)) + return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate) def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: @@ -597,11 +590,11 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: def trunc(col: ColumnOrName, format: str) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format).expression) + return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format)) def date_trunc(format: str, timestamp: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format).expression) + return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format)) def next_day(col: ColumnOrName, dayOfWeek: str) -> Column: @@ -614,14 +607,14 @@ def last_day(col: ColumnOrName) -> Column: def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_anonymous_function(col, "FROM_UNIXTIME", lit(format)) - return Column.invoke_anonymous_function(col, "FROM_UNIXTIME") + return Column.invoke_expression_over_column(col, glotexp.UnixToStr, format=lit(format)) + return Column.invoke_expression_over_column(col, glotexp.UnixToStr) def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP", lit(format)) - return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP") + return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format)) + return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix) def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: @@ -738,10 +731,7 @@ def trim(col: ColumnOrName) -> Column: def concat_ws(sep: str, *cols: ColumnOrName) -> Column: - columns = [Column(col) for col in cols] - return Column.invoke_expression_over_column( - None, glotexp.ConcatWs, expressions=[x.expression for x in [lit(sep)] + list(columns)] - ) + return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)) def decode(col: ColumnOrName, charset: str) -> Column: @@ -798,18 +788,14 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column: def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: - return Column.invoke_expression_over_column( - left, glotexp.Levenshtein, expression=Column.ensure_col(right).expression - ) + return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right) def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: substr_col = lit(substr) - pos_column = lit(pos) - str_column = Column.ensure_col(str) if pos is not None: - return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column, pos_column) - return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column) + return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos) + return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col) def lpad(col: ColumnOrName, len: int, pad: str) -> Column: @@ -821,15 +807,15 @@ def rpad(col: ColumnOrName, len: int, pad: str) -> Column: def repeat(col: ColumnOrName, n: int) -> Column: - return Column.invoke_anonymous_function(col, "REPEAT", n) + return Column.invoke_expression_over_column(col, glotexp.Repeat, times=lit(n)) def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column: if limit is not None: return Column.invoke_expression_over_column( - str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=lit(limit).expression + str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=limit ) - return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern).expression) + return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern)) def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column: @@ -879,9 +865,8 @@ def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: - cols = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols # type: ignore - cols = [Column.ensure_col(col).expression for col in cols] # type: ignore - return Column.invoke_expression_over_column(None, glotexp.Array, expressions=cols) + columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols + return Column.invoke_expression_over_column(None, glotexp.Array, expressions=columns) def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: @@ -892,7 +877,7 @@ def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "MAP_FROM_ARRAYS", col2) + return Column.invoke_expression_over_column(None, glotexp.Map, keys=col1, values=col2) def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: @@ -970,7 +955,7 @@ def posexplode_outer(col: ColumnOrName) -> Column: def get_json_object(col: ColumnOrName, path: str) -> Column: - return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path).expression) + return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path)) def json_tuple(col: ColumnOrName, *fields: str) -> Column: @@ -1031,11 +1016,17 @@ def array_max(col: ColumnOrName) -> Column: def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column: if asc is not None: - return Column.invoke_anonymous_function(col, "SORT_ARRAY", lit(asc)) - return Column.invoke_anonymous_function(col, "SORT_ARRAY") + return Column.invoke_expression_over_column(col, glotexp.SortArray, asc=asc) + return Column.invoke_expression_over_column(col, glotexp.SortArray) -def array_sort(col: ColumnOrName) -> Column: +def array_sort( + col: ColumnOrName, + comparator: t.Optional[t.Union[t.Callable[[Column, Column], Column]]] = None, +) -> Column: + if comparator is not None: + f_expression = _get_lambda_from_func(comparator) + return Column.invoke_expression_over_column(col, glotexp.ArraySort, expression=f_expression) return Column.invoke_expression_over_column(col, glotexp.ArraySort) @@ -1108,130 +1099,53 @@ def aggregate( initialValue: ColumnOrName, merge: t.Callable[[Column, Column], Column], finish: t.Optional[t.Callable[[Column], Column]] = None, - accumulator_name: str = "acc", - target_row_name: str = "x", ) -> Column: - merge_exp = glotexp.Lambda( - this=merge(Column(accumulator_name), Column(target_row_name)).expression, - expressions=[ - glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name)), - glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name)), - ], - ) + merge_exp = _get_lambda_from_func(merge) if finish is not None: - finish_exp = glotexp.Lambda( - this=finish(Column(accumulator_name)).expression, - expressions=[glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name))], - ) + finish_exp = _get_lambda_from_func(finish) return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)) return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp)) def transform( - col: ColumnOrName, - f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], - target_row_name: str = "x", - row_count_name: str = "i", + col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]] ) -> Column: - num_arguments = len(signature(f).parameters) - expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))] - columns = [Column(target_row_name)] - if num_arguments > 1: - columns.append(Column(row_count_name)) - expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name))) - - f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions) + f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression)) -def exists(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column: - f_expression = glotexp.Lambda( - this=f(Column(target_row_name)).expression, - expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))], - ) +def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: + f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression)) -def forall(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column: - f_expression = glotexp.Lambda( - this=f(Column(target_row_name)).expression, - expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))], - ) - +def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: + f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression)) -def filter( - col: ColumnOrName, - f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], - target_row_name: str = "x", - row_count_name: str = "i", -) -> Column: - num_arguments = len(signature(f).parameters) - expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))] - columns = [Column(target_row_name)] - if num_arguments > 1: - columns.append(Column(row_count_name)) - expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name))) - - f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions) - return Column.invoke_anonymous_function(col, "FILTER", Column(f_expression)) - - -def zip_with( - left: ColumnOrName, - right: ColumnOrName, - f: t.Callable[[Column, Column], Column], - left_name: str = "x", - right_name: str = "y", -) -> Column: - f_expression = glotexp.Lambda( - this=f(Column(left_name), Column(right_name)).expression, - expressions=[ - glotexp.to_identifier(left_name, quoted=_lambda_quoted(left_name)), - glotexp.to_identifier(right_name, quoted=_lambda_quoted(right_name)), - ], - ) +def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column: + f_expression = _get_lambda_from_func(f) + return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression) + +def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column: + f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression)) -def transform_keys( - col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" -) -> Column: - f_expression = glotexp.Lambda( - this=f(Column(key_name), Column(value_name)).expression, - expressions=[ - glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), - glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)), - ], - ) +def transform_keys(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column: + f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression)) -def transform_values( - col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" -) -> Column: - f_expression = glotexp.Lambda( - this=f(Column(key_name), Column(value_name)).expression, - expressions=[ - glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), - glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)), - ], - ) +def transform_values(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column: + f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression)) -def map_filter( - col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" -) -> Column: - f_expression = glotexp.Lambda( - this=f(Column(key_name), Column(value_name)).expression, - expressions=[ - glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), - glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)), - ], - ) +def map_filter(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column: + f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression)) @@ -1239,20 +1153,18 @@ def map_zip_with( col1: ColumnOrName, col2: ColumnOrName, f: t.Union[t.Callable[[Column, Column, Column], Column]], - key_name: str = "k", - value1: str = "v1", - value2: str = "v2", ) -> Column: - f_expression = glotexp.Lambda( - this=f(Column(key_name), Column(value1), Column(value2)).expression, - expressions=[ - glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), - glotexp.to_identifier(value1, quoted=_lambda_quoted(value1)), - glotexp.to_identifier(value2, quoted=_lambda_quoted(value2)), - ], - ) + f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression)) def _lambda_quoted(value: str) -> t.Optional[bool]: return False if value == "_" else None + + +def _get_lambda_from_func(lambda_expression: t.Callable): + variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames] + return glotexp.Lambda( + this=lambda_expression(*[Column(x) for x in variables]).expression, + expressions=variables, + ) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 0810e0c..63fdb85 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -18,6 +18,36 @@ from sqlglot.helper import list_get from sqlglot.parser import Parser, parse_var_map from sqlglot.tokens import Tokenizer +# (FuncType, Multiplier) +DATE_DELTA_INTERVAL = { + "YEAR": ("ADD_MONTHS", 12), + "MONTH": ("ADD_MONTHS", 1), + "QUARTER": ("ADD_MONTHS", 3), + "WEEK": ("DATE_ADD", 7), + "DAY": ("DATE_ADD", 1), +} + +DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") + + +def _add_date_sql(self, expression): + unit = expression.text("unit").upper() + func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) + modified_increment = ( + int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression + ) + modified_increment = exp.Literal.number(modified_increment) + return f"{func}({self.format_args(expression.this, modified_increment.this)})" + + +def _date_diff_sql(self, expression): + unit = expression.text("unit").upper() + sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" + _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) + multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" + diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})" + return f"{diff_sql}{multiplier_sql}" + def _array_sort(self, expression): if expression.expression: @@ -120,10 +150,14 @@ class Hive(Dialect): "m": "%-M", "ss": "%S", "s": "%-S", - "S": "%f", + "SSSSSS": "%f", "a": "%p", "DD": "%j", "D": "%-j", + "E": "%a", + "EE": "%a", + "EEE": "%a", + "EEEE": "%A", } date_format = "'yyyy-MM-dd'" @@ -207,8 +241,8 @@ class Hive(Dialect): exp.ArraySize: rename_func("SIZE"), exp.ArraySort: _array_sort, exp.With: no_recursive_cte_sql, - exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateAdd: _add_date_sql, + exp.DateDiff: _date_diff_sql, exp.DateStrToDate: rename_func("TO_DATE"), exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 6bf4ff0..572f411 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -71,6 +71,7 @@ class Spark(Hive): length=list_get(args, 1), ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "IIF": exp.If.from_arg_list, } FUNCTION_PARSERS = { @@ -111,6 +112,7 @@ class Spark(Hive): exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})", exp.VariancePop: rename_func("VAR_POP"), + exp.DateFromParts: rename_func("MAKE_DATE"), } WRAP_DERIVED_VALUES = False diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index ef8c82d..0cba6fe 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -1,5 +1,5 @@ from sqlglot import exp -from sqlglot.dialects.dialect import rename_func +from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func from sqlglot.dialects.mysql import MySQL @@ -14,6 +14,8 @@ class StarRocks(MySQL): TRANSFORMS = { **MySQL.Generator.TRANSFORMS, + exp.JSONExtractScalar: arrow_json_extract_sql, + exp.JSONExtract: arrow_json_extract_sql, exp.DateDiff: rename_func("DATEDIFF"), exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToDate: rename_func("TO_DATE"), diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 1f2e50d..107ace7 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,14 +1,149 @@ from sqlglot import exp -from sqlglot.dialects.dialect import Dialect +from sqlglot.dialects.dialect import Dialect, rename_func +from sqlglot.expressions import DataType from sqlglot.generator import Generator +from sqlglot.helper import list_get from sqlglot.parser import Parser +from sqlglot.time import format_time from sqlglot.tokens import Tokenizer, TokenType +FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"} +DATE_DELTA_INTERVAL = { + "year": "year", + "yyyy": "year", + "yy": "year", + "quarter": "quarter", + "qq": "quarter", + "q": "quarter", + "month": "month", + "mm": "month", + "m": "month", + "week": "week", + "ww": "week", + "wk": "week", + "day": "day", + "dd": "day", + "d": "day", +} + + +def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): + def _format_time(args): + return exp_class( + this=list_get(args, 1), + format=exp.Literal.string( + format_time( + list_get(args, 0).name or (TSQL.time_format if default is True else default), + {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping, + ) + ), + ) + + return _format_time + + +def parse_date_delta(exp_class): + def inner_func(args): + unit = DATE_DELTA_INTERVAL.get(list_get(args, 0).name.lower(), "day") + return exp_class(this=list_get(args, 2), expression=list_get(args, 1), unit=unit) + + return inner_func + + +def generate_date_delta(self, e): + func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" + return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" + class TSQL(Dialect): null_ordering = "nulls_are_small" time_format = "'yyyy-mm-dd hh:mm:ss'" + time_mapping = { + "yyyy": "%Y", + "yy": "%y", + "year": "%Y", + "qq": "%q", + "q": "%q", + "quarter": "%q", + "dayofyear": "%j", + "day": "%d", + "dy": "%d", + "y": "%Y", + "week": "%W", + "ww": "%W", + "wk": "%W", + "hour": "%h", + "hh": "%I", + "minute": "%M", + "mi": "%M", + "n": "%M", + "second": "%S", + "ss": "%S", + "s": "%-S", + "millisecond": "%f", + "ms": "%f", + "weekday": "%W", + "dw": "%W", + "month": "%m", + "mm": "%M", + "m": "%-M", + "Y": "%Y", + "YYYY": "%Y", + "YY": "%y", + "MMMM": "%B", + "MMM": "%b", + "MM": "%m", + "M": "%-m", + "dd": "%d", + "d": "%-d", + "HH": "%H", + "H": "%-H", + "h": "%-I", + "S": "%f", + } + + convert_format_mapping = { + "0": "%b %d %Y %-I:%M%p", + "1": "%m/%d/%y", + "2": "%y.%m.%d", + "3": "%d/%m/%y", + "4": "%d.%m.%y", + "5": "%d-%m-%y", + "6": "%d %b %y", + "7": "%b %d, %y", + "8": "%H:%M:%S", + "9": "%b %d %Y %-I:%M:%S:%f%p", + "10": "mm-dd-yy", + "11": "yy/mm/dd", + "12": "yymmdd", + "13": "%d %b %Y %H:%M:ss:%f", + "14": "%H:%M:%S:%f", + "20": "%Y-%m-%d %H:%M:%S", + "21": "%Y-%m-%d %H:%M:%S.%f", + "22": "%m/%d/%y %-I:%M:%S %p", + "23": "%Y-%m-%d", + "24": "%H:%M:%S", + "25": "%Y-%m-%d %H:%M:%S.%f", + "100": "%b %d %Y %-I:%M%p", + "101": "%m/%d/%Y", + "102": "%Y.%m.%d", + "103": "%d/%m/%Y", + "104": "%d.%m.%Y", + "105": "%d-%m-%Y", + "106": "%d %b %Y", + "107": "%b %d, %Y", + "108": "%H:%M:%S", + "109": "%b %d %Y %-I:%M:%S:%f%p", + "110": "%m-%d-%Y", + "111": "%Y/%m/%d", + "112": "%Y%m%d", + "113": "%d %b %Y %H:%M:%S:%f", + "114": "%H:%M:%S:%f", + "120": "%Y-%m-%d %H:%M:%S", + "121": "%Y-%m-%d %H:%M:%S.%f", + } + class Tokenizer(Tokenizer): IDENTIFIERS = ['"', ("[", "]")] @@ -29,19 +164,67 @@ class TSQL(Dialect): "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "XML": TokenType.XML, "SQL_VARIANT": TokenType.VARIANT, + "NVARCHAR(MAX)": TokenType.TEXT, + "VARCHAR(MAX)": TokenType.TEXT, } class Parser(Parser): FUNCTIONS = { **Parser.FUNCTIONS, "CHARINDEX": exp.StrPosition.from_arg_list, + "ISNULL": exp.Coalesce.from_arg_list, + "DATEADD": parse_date_delta(exp.DateAdd), + "DATEDIFF": parse_date_delta(exp.DateDiff), + "DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True), + "DATEPART": tsql_format_time_lambda(exp.TimeToStr), + "GETDATE": exp.CurrentDate.from_arg_list, + "IIF": exp.If.from_arg_list, + "LEN": exp.Length.from_arg_list, + "REPLICATE": exp.Repeat.from_arg_list, + "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, + } + + VAR_LENGTH_DATATYPES = { + DataType.Type.NVARCHAR, + DataType.Type.VARCHAR, + DataType.Type.CHAR, + DataType.Type.NCHAR, } - def _parse_convert(self): + def _parse_convert(self, strict): to = self._parse_types() self._match(TokenType.COMMA) this = self._parse_field() - return self.expression(exp.Cast, this=this, to=to) + + # Retrieve length of datatype and override to default if not specified + if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: + to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) + + # Check whether a conversion with format is applicable + if self._match(TokenType.COMMA): + format_val = self._parse_number().name + if format_val not in TSQL.convert_format_mapping: + raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}") + format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val]) + + # Check whether the convert entails a string to date format + if to.this == DataType.Type.DATE: + return self.expression(exp.StrToDate, this=this, format=format_norm) + # Check whether the convert entails a string to datetime format + elif to.this == DataType.Type.DATETIME: + return self.expression(exp.StrToTime, this=this, format=format_norm) + # Check whether the convert entails a date to string format + elif to.this in self.VAR_LENGTH_DATATYPES: + return self.expression( + exp.Cast if strict else exp.TryCast, + to=to, + this=self.expression(exp.TimeToStr, this=this, format=format_norm), + ) + elif to.this == DataType.Type.TEXT: + return self.expression(exp.TimeToStr, this=this, format=format_norm) + + # Entails a simple cast without any format requirement + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) class Generator(Generator): TYPE_MAPPING = { @@ -52,3 +235,11 @@ class TSQL(Dialect): exp.DataType.Type.DATETIME: "DATETIME2", exp.DataType.Type.VARIANT: "SQL_VARIANT", } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.DateAdd: lambda self, e: generate_date_delta(self, e), + exp.DateDiff: lambda self, e: generate_date_delta(self, e), + exp.CurrentDate: rename_func("GETDATE"), + exp.If: rename_func("IIF"), + } diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index f7717c8..eb7854a 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2411,6 +2411,11 @@ class TimeTrunc(Func, TimeUnit): arg_types = {"this": True, "unit": True, "zone": False} +class DateFromParts(Func): + _sql_names = ["DATEFROMPARTS"] + arg_types = {"year": True, "month": True, "day": True} + + class DateStrToDate(Func): pass @@ -2554,7 +2559,7 @@ class Quantile(AggFunc): class ApproxQuantile(Quantile): - pass + arg_types = {"this": True, "quantile": True, "accuracy": False} class Reduce(Func): @@ -2569,6 +2574,10 @@ class RegexpSplit(Func): arg_types = {"this": True, "expression": True} +class Repeat(Func): + arg_types = {"this": True, "times": True} + + class Round(Func): arg_types = {"this": True, "decimals": False} @@ -2690,7 +2699,7 @@ class TsOrDiToDi(Func): class UnixToStr(Func): - arg_types = {"this": True, "format": True} + arg_types = {"this": True, "format": False} class UnixToTime(Func): @@ -3077,6 +3086,8 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts): ) if from_: update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) + if isinstance(where, Condition): + where = Where(this=where) if where: update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts)) return update @@ -3518,6 +3529,41 @@ def replace_tables(expression, mapping): return expression.transform(_replace_tables) +def replace_placeholders(expression, *args, **kwargs): + """Replace placeholders in an expression. + + Args: + expression (sqlglot.Expression): Expression node to be transformed and replaced + args: Positional names that will substitute unnamed placeholders in the given order + kwargs: Keyword arguments that will substitute named placeholders + + Examples: + >>> from sqlglot import exp, parse_one + >>> replace_placeholders( + ... parse_one("select * from :tbl where ? = ?"), "a", "b", tbl="foo" + ... ).sql() + 'SELECT * FROM foo WHERE a = b' + + Returns: + The mapped expression + """ + + def _replace_placeholders(node, args, **kwargs): + if isinstance(node, Placeholder): + if node.name: + new_name = kwargs.get(node.name) + if new_name: + return to_identifier(new_name) + else: + try: + return to_identifier(next(args)) + except StopIteration: + pass + return node + + return expression.transform(_replace_placeholders, iter(args), **kwargs) + + TRUE = Boolean(this=True) FALSE = Boolean(this=False) NULL = Null() diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 6decd16..1784287 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -47,7 +47,8 @@ class Generator: The default is on the smaller end because the length only represents a segment and not the true line length. Default: 80 - annotations: Whether or not to show annotations in the SQL. + annotations: Whether or not to show annotations in the SQL when `pretty` is True. + Annotations can only be shown in pretty mode otherwise they may clobber resulting sql. Default: True """ @@ -280,7 +281,7 @@ class Generator: raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") def annotation_sql(self, expression): - if self._annotations: + if self._annotations and self.pretty: return f"{self.sql(expression, 'expression')} # {expression.name}" return self.sql(expression, "expression") diff --git a/sqlglot/helper.py b/sqlglot/helper.py index c3a23d3..42965d1 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -194,6 +194,24 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b return words + [None] * (min_num_words - len(words)) +def is_iterable(value: t.Any) -> bool: + """ + Checks if the value is an iterable but does not include strings and bytes + + Examples: + >>> is_iterable([1,2]) + True + >>> is_iterable("test") + False + + Args: + value: The value to check if it is an interable + + Returns: Bool indicating if it is an iterable + """ + return hasattr(value, "__iter__") and not isinstance(value, (str, bytes)) + + def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]: """ Flattens a list that can contain both iterables and non-iterable elements @@ -211,7 +229,7 @@ def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generato Yields non-iterable elements (not including str or byte as iterable) """ for value in values: - if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): + if is_iterable(value): yield from flatten(value) else: yield value diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 47c1c1d..b94313a 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -433,7 +433,8 @@ class Parser: } FUNCTION_PARSERS = { - "CONVERT": lambda self: self._parse_convert(), + "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), + "TRY_CONVERT": lambda self: self._parse_convert(False), "EXTRACT": lambda self: self._parse_extract(), "POSITION": lambda self: self._parse_position(), "SUBSTRING": lambda self: self._parse_substring(), @@ -1512,7 +1513,7 @@ class Parser: return this def _parse_offset(self, this=None): - if not self._match(TokenType.OFFSET): + if not self._match_set((TokenType.OFFSET, TokenType.COMMA)): return this count = self._parse_number() self._match_set((TokenType.ROW, TokenType.ROWS)) @@ -2134,7 +2135,7 @@ class Parser: return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_convert(self): + def _parse_convert(self, strict): this = self._parse_field() if self._match(TokenType.USING): to = self.expression(exp.CharacterSet, this=self._parse_var()) @@ -2142,7 +2143,7 @@ class Parser: to = self._parse_types() else: to = None - return self.expression(exp.Cast, this=this, to=to) + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) def _parse_position(self): args = self._parse_csv(self._parse_bitwise) diff --git a/sqlglot/time.py b/sqlglot/time.py index 16314c5..de28ac0 100644 --- a/sqlglot/time.py +++ b/sqlglot/time.py @@ -14,6 +14,8 @@ def format_time(string, mapping, trie=None): mapping: Dictionary of time format to target time format trie: Optional trie, can be passed in for performance """ + if not string: + return None start = 0 end = 1 size = len(string) diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 10f3b57..97753bd 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -9,7 +9,6 @@ from sqlglot.errors import ErrorLevel class TestFunctions(unittest.TestCase): - @unittest.skip("not yet fixed.") def test_invoke_anonymous(self): for name, func in inspect.getmembers(SF, inspect.isfunction): with self.subTest(f"{name} should not invoke anonymous_function"): @@ -438,13 +437,13 @@ class TestFunctions(unittest.TestCase): def test_pow(self): col_str = SF.pow("cola", "colb") - self.assertEqual("POW(cola, colb)", col_str.sql()) + self.assertEqual("POWER(cola, colb)", col_str.sql()) col = SF.pow(SF.col("cola"), SF.col("colb")) - self.assertEqual("POW(cola, colb)", col.sql()) + self.assertEqual("POWER(cola, colb)", col.sql()) col_float = SF.pow(10.10, "colb") - self.assertEqual("POW(10.1, colb)", col_float.sql()) + self.assertEqual("POWER(10.1, colb)", col_float.sql()) col_float2 = SF.pow("cola", 10.10) - self.assertEqual("POW(cola, 10.1)", col_float2.sql()) + self.assertEqual("POWER(cola, 10.1)", col_float2.sql()) def test_row_number(self): col_str = SF.row_number() @@ -493,6 +492,8 @@ class TestFunctions(unittest.TestCase): self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql()) col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc")) self.assertEqual("COALESCE(cola, colb, colc)", col.sql()) + col_single = SF.coalesce("cola") + self.assertEqual("COALESCE(cola)", col_single.sql()) def test_corr(self): col_str = SF.corr("cola", "colb") @@ -843,8 +844,8 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TO_DATE(cola)", col_str.sql()) col = SF.to_date(SF.col("cola")) self.assertEqual("TO_DATE(cola)", col.sql()) - col_with_format = SF.to_date("cola", "yyyy-MM-dd") - self.assertEqual("TO_DATE(cola, 'yyyy-MM-dd')", col_with_format.sql()) + col_with_format = SF.to_date("cola", "yy-MM-dd") + self.assertEqual("TO_DATE(cola, 'yy-MM-dd')", col_with_format.sql()) def test_to_timestamp(self): col_str = SF.to_timestamp("cola") @@ -883,16 +884,16 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql()) col = SF.from_unixtime(SF.col("cola")) self.assertEqual("FROM_UNIXTIME(cola)", col.sql()) - col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm:ss") - self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm") + self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm')", col_format.sql()) def test_unix_timestamp(self): col_str = SF.unix_timestamp("cola") self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql()) col = SF.unix_timestamp(SF.col("cola")) self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql()) - col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm:ss") - self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm") + self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm')", col_format.sql()) col_current = SF.unix_timestamp() self.assertEqual("UNIX_TIMESTAMP()", col_current.sql()) @@ -1427,6 +1428,13 @@ class TestFunctions(unittest.TestCase): self.assertEqual("ARRAY_SORT(cola)", col_str.sql()) col = SF.array_sort(SF.col("cola")) self.assertEqual("ARRAY_SORT(cola)", col.sql()) + col_comparator = SF.array_sort( + "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x)) + ) + self.assertEqual( + "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)", + col_comparator.sql(), + ) def test_reverse(self): col_str = SF.reverse("cola") @@ -1514,8 +1522,6 @@ class TestFunctions(unittest.TestCase): SF.lit(0), lambda accumulator, target: accumulator + target, lambda accumulator: accumulator * 2, - "accumulator", - "target", ) self.assertEqual( "AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)", @@ -1527,7 +1533,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql()) col = SF.transform(SF.col("cola"), lambda x, i: x * i) self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql()) - col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count, "target", "row_count") + col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count) self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()) @@ -1536,7 +1542,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql()) col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0) self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql()) - col_custom_name = SF.exists("cola", lambda target: target > 0, "target") + col_custom_name = SF.exists("cola", lambda target: target > 0) self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql()) def test_forall(self): @@ -1544,7 +1550,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql()) col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo")) self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql()) - col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"), "target") + col_custom_name = SF.forall("cola", lambda target: target.rlike("foo")) self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql()) def test_filter(self): @@ -1552,9 +1558,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql()) col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i)) self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql()) - col_custom_names = SF.filter( - "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count), "target", "row_count" - ) + col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)) self.assertEqual( "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql() @@ -1565,7 +1569,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql()) col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y)) self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql()) - col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r), "l", "r") + col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r)) self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()) def test_transform_keys(self): @@ -1573,7 +1577,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql()) col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k)) self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql()) - col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key), "key", "_") + col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key)) self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql()) def test_transform_values(self): @@ -1581,7 +1585,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql()) col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v)) self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql()) - col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value), "_", "value") + col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value)) self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()) def test_map_filter(self): @@ -1589,5 +1593,9 @@ class TestFunctions(unittest.TestCase): self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql()) col = SF.map_filter(SF.col("cola"), lambda k, v: k > v) self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql()) - col_custom_names = SF.map_filter("cola", lambda key, value: key > value, "key", "value") + col_custom_names = SF.map_filter("cola", lambda key, value: key > value) self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql()) + + def test_map_zip_with(self): + col = SF.map_zip_with("base", "ratio", lambda k, v1, v2: SF.round(v1 * v2, 2)) + self.assertEqual("MAP_ZIP_WITH(base, ratio, (k, v1, v2) -> ROUND(v1 * v2, 2))", col.sql()) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 14fea9d..050d41e 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -106,6 +106,15 @@ class TestBigQuery(Validator): }, ) self.validate_all( + "CURRENT_DATE", + read={ + "tsql": "GETDATE()", + }, + write={ + "tsql": "GETDATE()", + }, + ) + self.validate_all( "current_datetime", write={ "bigquery": "CURRENT_DATETIME()", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index e1524e9..5d1cf13 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -434,12 +434,7 @@ class TestDialect(Validator): "presto": "DATE_ADD('day', 1, x)", "spark": "DATE_ADD(x, 1)", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", - }, - ) - self.validate_all( - "DATE_ADD(x, y, 'day')", - write={ - "postgres": UnsupportedError, + "tsql": "DATEADD(day, 1, x)", }, ) self.validate_all( @@ -634,11 +629,13 @@ class TestDialect(Validator): read={ "postgres": "x->'y'", "presto": "JSON_EXTRACT(x, 'y')", + "starrocks": "x->'y'", }, write={ "oracle": "JSON_EXTRACT(x, 'y')", "postgres": "x->'y'", "presto": "JSON_EXTRACT(x, 'y')", + "starrocks": "x->'y'", }, ) self.validate_all( @@ -983,6 +980,7 @@ class TestDialect(Validator): ) def test_limit(self): + self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}) self.validate_all( "SELECT x FROM y LIMIT 10", write={ diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 339d1a6..8605bd1 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -282,3 +282,6 @@ TBLPROPERTIES ( "spark": "SELECT ARRAY_SORT(x)", }, ) + + def test_iif(self): + self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 9a6bc36..2a20163 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -71,3 +71,226 @@ class TestTSQL(Validator): "spark": "LOCATE('sub', 'testsubstring')", }, ) + + def test_len(self): + self.validate_all("LEN(x)", write={"spark": "LENGTH(x)"}) + + def test_replicate(self): + self.validate_all("REPLICATE('x', 2)", write={"spark": "REPEAT('x', 2)"}) + + def test_isnull(self): + self.validate_all("ISNULL(x, y)", write={"spark": "COALESCE(x, y)"}) + + def test_jsonvalue(self): + self.validate_all( + "JSON_VALUE(r.JSON, '$.Attr_INT')", + write={"spark": "GET_JSON_OBJECT(r.JSON, '$.Attr_INT')"}, + ) + + def test_datefromparts(self): + self.validate_all( + "SELECT DATEFROMPARTS('2020', 10, 01)", + write={"spark": "SELECT MAKE_DATE('2020', 10, 01)"}, + ) + + def test_datename(self): + self.validate_all( + "SELECT DATENAME(mm,'01-01-1970')", + write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'MMMM')"}, + ) + self.validate_all( + "SELECT DATENAME(dw,'01-01-1970')", + write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'EEEE')"}, + ) + + def test_datepart(self): + self.validate_all( + "SELECT DATEPART(month,'01-01-1970')", + write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'MM')"}, + ) + + def test_convert_date_format(self): + self.validate_all( + "CONVERT(NVARCHAR(200), x)", + write={ + "spark": "CAST(x AS VARCHAR(200))", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR, x)", + write={ + "spark": "CAST(x AS VARCHAR(30))", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR(MAX), x)", + write={ + "spark": "CAST(x AS STRING)", + }, + ) + self.validate_all( + "CONVERT(VARCHAR(200), x)", + write={ + "spark": "CAST(x AS VARCHAR(200))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR, x)", + write={ + "spark": "CAST(x AS VARCHAR(30))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR(MAX), x)", + write={ + "spark": "CAST(x AS STRING)", + }, + ) + self.validate_all( + "CONVERT(CHAR(40), x)", + write={ + "spark": "CAST(x AS CHAR(40))", + }, + ) + self.validate_all( + "CONVERT(CHAR, x)", + write={ + "spark": "CAST(x AS CHAR(30))", + }, + ) + self.validate_all( + "CONVERT(NCHAR(40), x)", + write={ + "spark": "CAST(x AS CHAR(40))", + }, + ) + self.validate_all( + "CONVERT(NCHAR, x)", + write={ + "spark": "CAST(x AS CHAR(30))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR, x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR(40), x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))", + }, + ) + self.validate_all( + "CONVERT(VARCHAR(MAX), x, 121)", + write={ + "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR, x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR(40), x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))", + }, + ) + self.validate_all( + "CONVERT(NVARCHAR(MAX), x, 121)", + write={ + "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(DATE, x, 121)", + write={ + "spark": "TO_DATE(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(DATETIME, x, 121)", + write={ + "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(DATETIME2, x, 121)", + write={ + "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + }, + ) + self.validate_all( + "CONVERT(INT, x)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "CONVERT(INT, x, 121)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "TRY_CONVERT(NVARCHAR, x, 121)", + write={ + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))", + }, + ) + self.validate_all( + "TRY_CONVERT(INT, x)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "TRY_CAST(x AS INT)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "CAST(x AS INT)", + write={ + "spark": "CAST(x AS INT)", + }, + ) + + def test_add_date(self): + self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") + self.validate_all( + "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"} + ) + self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}) + self.validate_all("SELECT DATEADD(wk, 1, '2017/08/25')", write={"spark": "SELECT DATE_ADD('2017/08/25', 7)"}) + + def test_date_diff(self): + self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')") + self.validate_all( + "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", + write={ + "tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", + "spark": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12", + }, + ) + self.validate_all( + "SELECT DATEDIFF(month, 'start','end')", + write={"spark": "SELECT MONTHS_BETWEEN('end', 'start')", "tsql": "SELECT DATEDIFF(month, 'start', 'end')"}, + ) + self.validate_all( + "SELECT DATEDIFF(quarter, 'start', 'end')", write={"spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3"} + ) + + def test_iif(self): + self.validate_identity("SELECT IIF(cond, 'True', 'False')") + self.validate_all( + "SELECT IIF(cond, 'True', 'False');", + write={ + "spark": "SELECT IF(cond, 'True', 'False')", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 67e4cab..d7084ac 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -149,7 +149,6 @@ SELECT 1 AS count FROM test SELECT 1 AS comment FROM test SELECT 1 AS numeric FROM test SELECT 1 AS number FROM test -SELECT 1 AS number # annotation SELECT t.count SELECT DISTINCT x FROM test SELECT DISTINCT x, y FROM test diff --git a/tests/test_build.py b/tests/test_build.py index a432ef1..f51996d 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -330,6 +330,10 @@ class TestBuild(unittest.TestCase): "UPDATE tbl SET x = 1 WHERE y > 0", ), ( + lambda: exp.update("tbl", {"x": 1}, where=exp.condition("y > 0")), + "UPDATE tbl SET x = 1 WHERE y > 0", + ), + ( lambda: exp.update("tbl", {"x": 1}, from_="tbl2"), "UPDATE tbl SET x = 1 FROM tbl2", ), diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 79b4ee5..9af59d9 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -135,6 +135,53 @@ class TestExpressions(unittest.TestCase): "SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a", ) + def test_replace_placeholders(self): + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from :tbl1 JOIN :tbl2 ON :col1 = :col2 WHERE :col3 > 100"), + tbl1="foo", + tbl2="bar", + col1="a", + col2="b", + col3="c", + ).sql(), + "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100", + ) + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from ? JOIN ? ON ? = ? WHERE ? > 100"), + "foo", + "bar", + "a", + "b", + "c", + ).sql(), + "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100", + ) + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from ? WHERE ? > 100"), + "foo", + ).sql(), + "SELECT * FROM foo WHERE ? > 100", + ) + self.assertEqual( + exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(), + "SELECT * FROM :name WHERE ? > 100", + ) + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from (SELECT :col1 FROM ?) WHERE :col2 > 100"), + "tbl1", + "tbl2", + "tbl3", + col1="a", + col2="b", + col3="c", + ).sql(), + "SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100", + ) + def test_named_selects(self): expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) @@ -504,9 +551,24 @@ class TestExpressions(unittest.TestCase): [e.alias_or_name for e in expression.expressions], ["a", "B", "c", "D"], ) - self.assertEqual(expression.sql(), sql) + self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D") self.assertEqual(expression.expressions[2].name, "comment") - self.assertEqual(expression.sql(annotations=False), "SELECT a, b AS B, c, d AS D") + self.assertEqual( + expression.sql(pretty=True, annotations=False), + """SELECT + a, + b AS B, + c, + d AS D""", + ) + self.assertEqual( + expression.sql(pretty=True), + """SELECT + a, + b AS B, + c # comment, + d AS D # another_comment FROM foo""", + ) def test_to_table(self): table_only = exp.to_table("table_name") diff --git a/tests/test_time.py b/tests/test_time.py index 17821c2..bd0e63f 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -5,7 +5,7 @@ from sqlglot.time import format_time class TestTime(unittest.TestCase): def test_format_time(self): - self.assertEqual(format_time("", {}), "") + self.assertEqual(format_time("", {}), None) self.assertEqual(format_time(" ", {}), " ") mapping = {"a": "b", "aa": "c"} self.assertEqual(format_time("a", mapping), "b") |