Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--  sqlglot/dialects/bigquery.py     7
-rw-r--r--  sqlglot/dialects/clickhouse.py  15
-rw-r--r--  sqlglot/dialects/dialect.py     16
-rw-r--r--  sqlglot/dialects/duckdb.py      19
-rw-r--r--  sqlglot/dialects/hive.py        49
-rw-r--r--  sqlglot/dialects/mysql.py       30
-rw-r--r--  sqlglot/dialects/postgres.py     6
-rw-r--r--  sqlglot/dialects/presto.py       5
-rw-r--r--  sqlglot/dialects/redshift.py    41
-rw-r--r--  sqlglot/dialects/snowflake.py   23
-rw-r--r--  sqlglot/dialects/spark2.py      61
-rw-r--r--  sqlglot/dialects/starrocks.py    2
-rw-r--r--  sqlglot/dialects/tsql.py        83
13 files changed, 264 insertions, 93 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 35892f7..fd9965c 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
+ binary_from_function,
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
@@ -15,6 +16,7 @@ from sqlglot.dialects.dialect import (
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
+ regexp_replace_sql,
rename_func,
timestrtotime_sql,
ts_or_ds_to_date_sql,
@@ -39,7 +41,7 @@ def _date_add_sql(
def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
- if not isinstance(expression.unnest().parent, exp.From):
+ if not expression.find_ancestor(exp.From, exp.Join):
return self.values_sql(expression)
alias = expression.args.get("alias")
@@ -279,7 +281,7 @@ class BigQuery(Dialect):
),
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
- "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+ "DIV": binary_from_function(exp.IntDiv),
"GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
"MD5": exp.MD5Digest.from_arg_list,
"TO_HEX": _parse_to_hex,
@@ -415,6 +417,7 @@ class BigQuery(Dialect):
e.args.get("position"),
e.args.get("occurrence"),
),
+ exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.ReturnsProperty: _returnsproperty_sql,
exp.Select: transforms.preprocess(
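
In bigquery.py, DIV now routes through the shared binary_from_function helper, REGEXP_REPLACE gains a generator via the new shared regexp_replace_sql, and the VALUES-to-UNNEST rewrite also fires when the VALUES sits under a JOIN rather than only directly under FROM. A minimal parse-side sketch (uses sqlglot's top-level API; not re-verified against this exact revision):

import sqlglot
from sqlglot import exp

# DIV(10, 3) now builds an exp.IntDiv node through binary_from_function.
ast = sqlglot.parse_one("SELECT DIV(10, 3)", read="bigquery")
assert ast.find(exp.IntDiv) is not None
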
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 9126c4b..8f60df2 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -64,6 +64,7 @@ class ClickHouse(Dialect):
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"UNIQ": exp.ApproxDistinct.from_arg_list,
+ "XOR": lambda args: exp.Xor(expressions=args),
}
FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
@@ -95,6 +96,7 @@ class ClickHouse(Dialect):
TokenType.ASOF,
TokenType.ANTI,
TokenType.SEMI,
+ TokenType.ARRAY,
}
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
@@ -103,6 +105,7 @@ class ClickHouse(Dialect):
TokenType.ANTI,
TokenType.SETTINGS,
TokenType.FORMAT,
+ TokenType.ARRAY,
}
LOG_DEFAULTS_TO_LN = True
@@ -160,8 +163,11 @@ class ClickHouse(Dialect):
schema: bool = False,
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
+ parse_bracket: bool = False,
) -> t.Optional[exp.Expression]:
- this = super()._parse_table(schema=schema, joins=joins, alias_tokens=alias_tokens)
+ this = super()._parse_table(
+ schema=schema, joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket
+ )
if self._match(TokenType.FINAL):
this = self.expression(exp.Final, this=this)
@@ -204,8 +210,10 @@ class ClickHouse(Dialect):
self._match_set(self.JOIN_KINDS) and self._prev,
)
- def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
- join = super()._parse_join(skip_join_token)
+ def _parse_join(
+ self, skip_join_token: bool = False, parse_bracket: bool = False
+ ) -> t.Optional[exp.Join]:
+ join = super()._parse_join(skip_join_token=skip_join_token, parse_bracket=True)
if join:
join.set("global", join.args.pop("method", None))
@@ -318,6 +326,7 @@ class ClickHouse(Dialect):
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
+ exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
}
PROPERTIES_LOCATION = {
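
In clickhouse.py, xor() is variadic, so it parses into exp.Xor with an expressions list and is generated back through self.func("xor", ...); removing ARRAY from the alias-token sets (and passing parse_bracket=True through _parse_join) is what makes ARRAY JOIN parseable. A rough usage sketch (example queries are illustrative, not from the commit's tests):

import sqlglot

# Variadic xor() round-trips through the new exp.Xor(expressions=...) shape.
print(sqlglot.transpile("SELECT xor(a, b, c)", read="clickhouse", write="clickhouse")[0])

# ARRAY is no longer consumed as a table alias, so ARRAY JOIN parses.
sqlglot.parse_one("SELECT 1 FROM t ARRAY JOIN arr", read="clickhouse")
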
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 5376dff..8c84639 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -12,6 +12,8 @@ from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
+B = t.TypeVar("B", bound=exp.Binary)
+
class Dialects(str, Enum):
DIALECT = ""
@@ -630,6 +632,16 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
)
+def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
+ bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
+ if bad_args:
+ self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
+
+ return self.func(
+ "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
+ )
+
+
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
names = []
for agg in aggregations:
@@ -650,3 +662,7 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
return names
+
+
+def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
+ return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
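
dialect.py gains two shared helpers: binary_from_function collapses the repeated two-argument lambda pattern seen across dialects, and regexp_replace_sql emits a three-argument REGEXP_REPLACE while flagging position/occurrence/parameters as unsupported. For example:

from sqlglot import exp
from sqlglot.dialects.dialect import binary_from_function

# Equivalent to:
#   lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1))
parse_div = binary_from_function(exp.IntDiv)
node = parse_div([exp.column("a"), exp.column("b")])
assert isinstance(node, exp.IntDiv)
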
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 1d8a7fb..219b1aa 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
+ binary_from_function,
date_trunc_to_time,
datestrtodate_sql,
format_time_lambda,
@@ -16,6 +17,7 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
pivot_column_names,
regexp_extract_sql,
+ regexp_replace_sql,
rename_func,
str_position_sql,
str_to_time_sql,
@@ -103,7 +105,6 @@ class DuckDB(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
- "~": TokenType.RLIKE,
":=": TokenType.EQ,
"//": TokenType.DIV,
"ATTACH": TokenType.COMMAND,
@@ -128,6 +129,11 @@ class DuckDB(Dialect):
class Parser(parser.Parser):
CONCAT_NULL_OUTPUTS_STRING = True
+ BITWISE = {
+ **parser.Parser.BITWISE,
+ TokenType.TILDA: exp.RegexpLike,
+ }
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
@@ -158,6 +164,7 @@ class DuckDB(Dialect):
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
+ "XOR": binary_from_function(exp.BitwiseXor),
}
TYPE_TOKENS = {
@@ -190,6 +197,7 @@ class DuckDB(Dialect):
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
+ exp.BitwiseXor: lambda self, e: self.func("XOR", e.this, e.expression),
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
exp.CurrentDate: lambda self, e: "CURRENT_DATE",
exp.CurrentTime: lambda self, e: "CURRENT_TIME",
@@ -203,7 +211,7 @@ class DuckDB(Dialect):
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateSub: _date_delta_sql,
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
+ "DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
@@ -217,8 +225,15 @@ class DuckDB(Dialect):
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
+ exp.MonthsBetween: lambda self, e: self.func(
+ "DATEDIFF",
+ "'month'",
+ exp.cast(e.expression, "timestamp"),
+ exp.cast(e.this, "timestamp"),
+ ),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
+ exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,
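
In duckdb.py, the tokenizer rule mapping "~" to RLIKE is dropped; TILDA is instead resolved in the parser's BITWISE table as exp.RegexpLike. XOR maps to the bitwise exp.BitwiseXor (DuckDB's XOR is bitwise, not logical), MONTHS_BETWEEN is emulated with DATEDIFF over timestamp casts, and the DateDiff unit lookup switches to `or 'day'` so a unit key that is present but None still falls back correctly. Sketch (expected outputs follow the mappings above, not re-verified):

import sqlglot

# '~' reaches the parser as TILDA and becomes exp.RegexpLike, which DuckDB
# renders as REGEXP_MATCHES.
print(sqlglot.transpile("SELECT 'abc' ~ 'a'", read="duckdb", write="duckdb")[0])

# XOR(1, 2) is bitwise in DuckDB, hence exp.BitwiseXor.
print(sqlglot.transpile("SELECT XOR(1, 2)", read="duckdb", write="duckdb")[0])
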
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index f968f6a..e131434 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_trycast_sql,
regexp_extract_sql,
+ regexp_replace_sql,
rename_func,
right_to_substring_sql,
strposition_to_locate_sql,
@@ -211,6 +212,7 @@ class Hive(Dialect):
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
"MSCK REPAIR": TokenType.COMMAND,
+ "REFRESH": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
@@ -270,6 +272,11 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
+ FUNCTION_PARSERS = {
+ **parser.Parser.FUNCTION_PARSERS,
+ "TRANSFORM": lambda self: self._parse_transform(),
+ }
+
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
@@ -277,6 +284,40 @@ class Hive(Dialect):
),
}
+ def _parse_transform(self) -> exp.Transform | exp.QueryTransform:
+ args = self._parse_csv(self._parse_lambda)
+ self._match_r_paren()
+
+ row_format_before = self._parse_row_format(match_row=True)
+
+ record_writer = None
+ if self._match_text_seq("RECORDWRITER"):
+ record_writer = self._parse_string()
+
+ if not self._match(TokenType.USING):
+ return exp.Transform.from_arg_list(args)
+
+ command_script = self._parse_string()
+
+ self._match(TokenType.ALIAS)
+ schema = self._parse_schema()
+
+ row_format_after = self._parse_row_format(match_row=True)
+ record_reader = None
+ if self._match_text_seq("RECORDREADER"):
+ record_reader = self._parse_string()
+
+ return self.expression(
+ exp.QueryTransform,
+ expressions=args,
+ command_script=command_script,
+ schema=schema,
+ row_format_before=row_format_before,
+ record_writer=record_writer,
+ row_format_after=row_format_after,
+ record_reader=record_reader,
+ )
+
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
@@ -363,11 +404,13 @@ class Hive(Dialect):
exp.Max: max_or_greatest,
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
+ exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpExtract: regexp_extract_sql,
+ exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.Right: right_to_substring_sql,
@@ -396,7 +439,6 @@ class Hive(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
- exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
@@ -410,6 +452,11 @@ class Hive(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def rowformatserdeproperty_sql(self, expression: exp.RowFormatSerdeProperty) -> str:
+ serde_props = self.sql(expression, "serde_properties")
+ serde_props = f" {serde_props}" if serde_props else ""
+ return f"ROW FORMAT SERDE {self.sql(expression, 'this')}{serde_props}"
+
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index e4de934..5d65f77 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -427,6 +427,7 @@ class MySQL(Dialect):
TABLE_HINTS = True
DUPLICATE_KEY_UPDATE_WITH_SET = False
QUERY_HINT_SEP = " "
+ VALUES_AS_TABLE = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -473,19 +474,32 @@ class MySQL(Dialect):
LIMIT_FETCH = "LIMIT"
+ # MySQL doesn't support many datatypes in cast.
+ # https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_cast
+ CAST_MAPPING = {
+ exp.DataType.Type.BIGINT: "SIGNED",
+ exp.DataType.Type.BOOLEAN: "SIGNED",
+ exp.DataType.Type.INT: "SIGNED",
+ exp.DataType.Type.TEXT: "CHAR",
+ exp.DataType.Type.UBIGINT: "UNSIGNED",
+ exp.DataType.Type.VARCHAR: "CHAR",
+ }
+
+ def xor_sql(self, expression: exp.Xor) -> str:
+ if expression.expressions:
+ return self.expressions(expression, sep=" XOR ")
+ return super().xor_sql(expression)
+
def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
- """(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
- if expression.to.this == exp.DataType.Type.BIGINT:
- to = "SIGNED"
- elif expression.to.this == exp.DataType.Type.UBIGINT:
- to = "UNSIGNED"
- else:
- return super().cast_sql(expression)
+ to = self.CAST_MAPPING.get(expression.to.this)
- return f"CAST({self.sql(expression, 'this')} AS {to})"
+ if to:
+ expression = expression.copy()
+ expression.to.set("this", to)
+ return super().cast_sql(expression)
def show_sql(self, expression: exp.Show) -> str:
this = f" {expression.name}"
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 7706456..d11cbd7 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -282,7 +282,6 @@ class Postgres(Dialect):
VAR_SINGLE_TOKENS = {"$"}
class Parser(parser.Parser):
- STRICT_CAST = False
CONCAT_NULL_OUTPUTS_STRING = True
FUNCTIONS = {
@@ -318,6 +317,11 @@ class Postgres(Dialect):
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS,
+ TokenType.END: lambda self: self._parse_commit_or_rollback(),
+ }
+
def _parse_factor(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_exponent, self.FACTOR)
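
In postgres.py, END is an alias for COMMIT, so it now parses through _parse_commit_or_rollback; dropping STRICT_CAST = False restores the base parser's strict CAST handling. Sketch (expected output, not re-verified):

import sqlglot

# END now parses as a commit rather than an unsupported statement.
print(sqlglot.transpile("END", read="postgres", write="postgres")[0])
# Expected: COMMIT
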
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 7d35c67..265c6e5 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ binary_from_function,
date_trunc_to_time,
format_time_lambda,
if_sql,
@@ -198,6 +199,10 @@ class Presto(Dialect):
**parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
+ "BITWISE_AND": binary_from_function(exp.BitwiseAnd),
+ "BITWISE_NOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)),
+ "BITWISE_OR": binary_from_function(exp.BitwiseOr),
+ "BITWISE_XOR": binary_from_function(exp.BitwiseXor),
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 09edd55..f687ba7 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -27,6 +27,11 @@ class Redshift(Postgres):
class Parser(Postgres.Parser):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS,
+ "ADD_MONTHS": lambda args: exp.DateAdd(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ expression=seq_get(args, 1),
+ unit=exp.var("month"),
+ ),
"DATEADD": lambda args: exp.DateAdd(
this=exp.TsOrDsToDate(this=seq_get(args, 2)),
expression=seq_get(args, 1),
@@ -37,7 +42,6 @@ class Redshift(Postgres):
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
unit=seq_get(args, 0),
),
- "NVL": exp.Coalesce.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
}
@@ -87,6 +91,7 @@ class Redshift(Postgres):
LOCKING_READS_SUPPORTED = False
RENAME_TABLE_WITH_DB = False
QUERY_HINTS = False
+ VALUES_AS_TABLE = False
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@@ -129,40 +134,6 @@ class Redshift(Postgres):
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
- def values_sql(self, expression: exp.Values) -> str:
- """
- Converts `VALUES...` expression into a series of unions.
-
- Note: If you have a lot of unions then this will result in a large number of recursive statements to
- evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
- very slow.
- """
-
- # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
- if not expression.find_ancestor(exp.From, exp.Join):
- return super().values_sql(expression)
-
- column_names = expression.alias and expression.args["alias"].columns
-
- selects = []
- rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
-
- for i, row in enumerate(rows):
- if i == 0 and column_names:
- row = [
- exp.alias_(value, column_name)
- for value, column_name in zip(row, column_names)
- ]
-
- selects.append(exp.Select(expressions=row))
-
- subquery_expression: exp.Select | exp.Union = selects[0]
- if len(selects) > 1:
- for select in selects[1:]:
- subquery_expression = exp.union(subquery_expression, select, distinct=False)
-
- return self.subquery_sql(subquery_expression.subquery(expression.alias))
-
def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
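
In redshift.py, the hand-rolled values_sql override is gone; setting VALUES_AS_TABLE = False hands the VALUES-to-UNION rewrite to the base generator, which is presumably where this commit relocates the logic. ADD_MONTHS also parses into a DateAdd with a month unit, and the NVL mapping is dropped, likely because NVL is now handled generically as COALESCE. A sketch of the expected rewrite:

import sqlglot

sql = "SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t(a, b)"
# Expected shape: a subquery of SELECTs joined by UNION ALL, aliased t.
print(sqlglot.transpile(sql, write="redshift")[0])
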
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 715a84c..499e085 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
+def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@@ -137,7 +137,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
# https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args: t.List) -> exp.Expression:
+def _div0_to_if(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -145,13 +145,13 @@ def _div0_to_if(args: t.List) -> exp.Expression:
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args: t.List) -> exp.Expression:
+def _zeroifnull_to_if(args: t.List) -> exp.If:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _nullifzero_to_if(args: t.List) -> exp.Expression:
+def _nullifzero_to_if(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
@@ -164,12 +164,21 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
-def _parse_convert_timezone(args: t.List) -> exp.Expression:
+def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
if len(args) == 3:
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
+def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
+ regexp_replace = exp.RegexpReplace.from_arg_list(args)
+
+ if not regexp_replace.args.get("replacement"):
+ regexp_replace.set("replacement", exp.Literal.string(""))
+
+ return regexp_replace
+
+
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
@@ -223,13 +232,14 @@ class Snowflake(Dialect):
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
+ "REGEXP_REPLACE": _parse_regexp_replace,
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEDIFF": _parse_datediff,
"TIMESTAMPDIFF": _parse_datediff,
"TO_ARRAY": exp.Array.from_arg_list,
- "TO_TIMESTAMP": _snowflake_to_timestamp,
+ "TO_TIMESTAMP": _parse_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _zeroifnull_to_if,
}
@@ -242,7 +252,6 @@ class Snowflake(Dialect):
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
- TokenType.RLIKE,
TokenType.TABLE,
}
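
In snowflake.py, besides the rename of _snowflake_to_timestamp and the tighter return annotations, REGEXP_REPLACE now defaults its replacement to the empty string at parse time. Sketch:

import sqlglot

# A missing replacement argument is filled in with '' while parsing.
print(sqlglot.transpile("SELECT REGEXP_REPLACE(s, p)", read="snowflake", write="snowflake")[0])
# Expected: SELECT REGEXP_REPLACE(s, p, '')
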
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index f909e8c..dcaa524 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -2,9 +2,11 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, parser, transforms
+from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
+ binary_from_function,
create_with_partitions_sql,
+ format_time_lambda,
pivot_column_names,
rename_func,
trim_sql,
@@ -108,47 +110,36 @@ class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
- "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
- "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
- "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- ),
- "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- ),
- "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
- "IIF": exp.If.from_arg_list,
"AGGREGATE": exp.Reduce.from_arg_list,
- "DAYOFWEEK": lambda args: exp.DayOfWeek(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DAYOFMONTH": lambda args: exp.DayOfMonth(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DAYOFYEAR": lambda args: exp.DayOfYear(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "WEEKOFYEAR": lambda args: exp.WeekOfYear(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DATE_TRUNC": lambda args: exp.TimestampTrunc(
- this=seq_get(args, 1),
- unit=exp.var(seq_get(args, 0)),
- ),
- "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+ "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"BOOLEAN": _parse_as_cast("boolean"),
"DATE": _parse_as_cast("date"),
+ "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+ this=seq_get(args, 1), unit=exp.var(seq_get(args, 0))
+ ),
+ "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DOUBLE": _parse_as_cast("double"),
"FLOAT": _parse_as_cast("float"),
+ "IIF": exp.If.from_arg_list,
"INT": _parse_as_cast("int"),
+ "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
+ "RLIKE": exp.RegexpLike.from_arg_list,
+ "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
+ "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
"STRING": _parse_as_cast("string"),
"TIMESTAMP": _parse_as_cast("timestamp"),
+ "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args)
+ if len(args) == 1
+ else format_time_lambda(exp.StrToTime, "spark")(args),
+ "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
+ "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+ "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS,
+ **Hive.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -207,6 +198,13 @@ class Spark2(Hive):
exp.Map: _map_sql,
exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
exp.Reduce: rename_func("AGGREGATE"),
+ exp.RegexpReplace: lambda self, e: self.func(
+ "REGEXP_REPLACE",
+ e.this,
+ e.expression,
+ e.args["replacement"],
+ e.args.get("position"),
+ ),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
@@ -224,6 +222,7 @@ class Spark2(Hive):
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
TRANSFORMS.pop(exp.Left)
+ TRANSFORMS.pop(exp.MonthsBetween)
TRANSFORMS.pop(exp.Right)
WRAP_DERIVED_VALUES = False
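
In spark2.py, the FUNCTIONS table is re-sorted and compacted with binary_from_function; new entries include RLIKE and an overloaded TO_TIMESTAMP (one argument is a plain cast, two go through Spark's time-format mapping). FUNCTION_PARSERS now extends Hive's table instead of the base parser's, so the new TRANSFORM parser is inherited. Sketch of the TO_TIMESTAMP split (expected outputs, not re-verified):

import sqlglot

print(sqlglot.transpile("SELECT TO_TIMESTAMP(x)", read="spark2", write="spark2")[0])
# Expected: SELECT CAST(x AS TIMESTAMP)
print(sqlglot.transpile("SELECT TO_TIMESTAMP(x, 'yyyy-MM-dd')", read="spark2", write="spark2")[0])
# Expected: SELECT TO_TIMESTAMP(x, 'yyyy-MM-dd')
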
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 0390113..baa62e8 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -20,6 +20,8 @@ class StarRocks(MySQL):
}
class Generator(MySQL.Generator):
+ CAST_MAPPING = {}
+
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
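
In starrocks.py, the generator inherits from MySQL but supports more CAST targets, so it resets CAST_MAPPING to an empty dict and casts flow through the ordinary TYPE_MAPPING instead:

import sqlglot

print(sqlglot.transpile("SELECT CAST(x AS TEXT)", write="starrocks")[0])
# Expected: SELECT CAST(x AS STRING), via TYPE_MAPPING rather than CAST_MAPPING
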
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index b77c2c0..01d5001 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -138,7 +138,8 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
if isinstance(expression, exp.NumberToStr)
else exp.Literal.string(
format_time(
- expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING)
+ expression.text("format"),
+ t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
)
)
)
@@ -314,7 +315,9 @@ class TSQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
- this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
+ this=seq_get(args, 1),
+ substr=seq_get(args, 0),
+ position=seq_get(args, 2),
),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@@ -365,6 +368,55 @@ class TSQL(Dialect):
CONCAT_NULL_OUTPUTS_STRING = True
+ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
+ """Applies to SQL Server and Azure SQL Database
+ COMMIT [ { TRAN | TRANSACTION }
+ [ transaction_name | @tran_name_variable ] ]
+ [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ]
+
+ ROLLBACK { TRAN | TRANSACTION }
+ [ transaction_name | @tran_name_variable
+ | savepoint_name | @savepoint_variable ]
+ """
+ rollback = self._prev.token_type == TokenType.ROLLBACK
+
+ self._match_texts({"TRAN", "TRANSACTION"})
+ this = self._parse_id_var()
+
+ if rollback:
+ return self.expression(exp.Rollback, this=this)
+
+ durability = None
+ if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
+ self._match_text_seq("DELAYED_DURABILITY")
+ self._match(TokenType.EQ)
+
+ if self._match_text_seq("OFF"):
+ durability = False
+ else:
+ self._match(TokenType.ON)
+ durability = True
+
+ self._match_r_paren()
+
+ return self.expression(exp.Commit, this=this, durability=durability)
+
+ def _parse_transaction(self) -> exp.Transaction | exp.Command:
+ """Applies to SQL Server and Azure SQL Database
+ BEGIN { TRAN | TRANSACTION }
+ [ { transaction_name | @tran_name_variable }
+ [ WITH MARK [ 'description' ] ]
+ ]
+ """
+ if self._match_texts(("TRAN", "TRANSACTION")):
+ transaction = self.expression(exp.Transaction, this=self._parse_id_var())
+ if self._match_text_seq("WITH", "MARK"):
+ transaction.set("mark", self._parse_string())
+
+ return transaction
+
+ return self._parse_as_command(self._prev)
+
def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
@@ -496,7 +548,9 @@ class TSQL(Dialect):
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
- "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
+ "HASHBYTES",
+ exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
+ e.this,
),
exp.TimeToStr: _format_sql,
}
@@ -539,3 +593,26 @@ class TSQL(Dialect):
into = self.sql(expression, "into")
into = self.seg(f"INTO {into}") if into else ""
return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}"
+
+ def transaction_sql(self, expression: exp.Transaction) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ mark = self.sql(expression, "mark")
+ mark = f" WITH MARK {mark}" if mark else ""
+ return f"BEGIN TRANSACTION{this}{mark}"
+
+ def commit_sql(self, expression: exp.Commit) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ durability = expression.args.get("durability")
+ durability = (
+ f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})"
+ if durability is not None
+ else ""
+ )
+ return f"COMMIT TRANSACTION{this}{durability}"
+
+ def rollback_sql(self, expression: exp.Rollback) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ return f"ROLLBACK TRANSACTION{this}"