author    Daniel Baumann <daniel.baumann@progress-linux.org>  2024-02-08 05:38:42 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org>  2024-02-08 05:38:42 +0000
commit    c66e4a33e1a07c439f03fe47f146a6c6482bf6df (patch)
tree      cfdf01111c063b3e50841695e6c2768833aea4dc /sqlglot
parent    Releasing debian version 20.11.0-1.
Merging upstream version 21.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                         |   7
-rw-r--r--  sqlglot/_typing.py                          |   1
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py          |   6
-rw-r--r--  sqlglot/dataframe/sql/functions.py          |  35
-rw-r--r--  sqlglot/dialects/__init__.py                |   3
-rw-r--r--  sqlglot/dialects/bigquery.py                |  35
-rw-r--r--  sqlglot/dialects/clickhouse.py              |  35
-rw-r--r--  sqlglot/dialects/databricks.py              |   4
-rw-r--r--  sqlglot/dialects/dialect.py                 | 132
-rw-r--r--  sqlglot/dialects/doris.py                   |   9
-rw-r--r--  sqlglot/dialects/drill.py                   |  22
-rw-r--r--  sqlglot/dialects/duckdb.py                  |  65
-rw-r--r--  sqlglot/dialects/hive.py                    |  52
-rw-r--r--  sqlglot/dialects/mysql.py                   |  17
-rw-r--r--  sqlglot/dialects/oracle.py                  |   6
-rw-r--r--  sqlglot/dialects/postgres.py                |  26
-rw-r--r--  sqlglot/dialects/presto.py                  |  53
-rw-r--r--  sqlglot/dialects/redshift.py                |  35
-rw-r--r--  sqlglot/dialects/snowflake.py               |  53
-rw-r--r--  sqlglot/dialects/spark.py                   |   5
-rw-r--r--  sqlglot/dialects/spark2.py                  |  12
-rw-r--r--  sqlglot/dialects/sqlite.py                  |  32
-rw-r--r--  sqlglot/dialects/starrocks.py               |   6
-rw-r--r--  sqlglot/dialects/teradata.py                |   3
-rw-r--r--  sqlglot/dialects/trino.py                   |   9
-rw-r--r--  sqlglot/dialects/tsql.py                    |  74
-rw-r--r--  sqlglot/executor/__init__.py                |   2
-rw-r--r--  sqlglot/executor/context.py                 |   4
-rw-r--r--  sqlglot/executor/env.py                     |  20
-rw-r--r--  sqlglot/executor/python.py                  |  16
-rw-r--r--  sqlglot/expressions.py                      | 197
-rw-r--r--  sqlglot/generator.py                        | 162
-rw-r--r--  sqlglot/helper.py                           |  22
-rw-r--r--  sqlglot/jsonpath.py                         | 132
-rw-r--r--  sqlglot/optimizer/__init__.py               |   2
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py  |   6
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py   |   8
-rw-r--r--  sqlglot/optimizer/simplify.py               |  26
-rw-r--r--  sqlglot/parser.py                           |  99
-rw-r--r--  sqlglot/tokens.py                           |   6
-rw-r--r--  sqlglot/transforms.py                       |   4
41 files changed, 1009 insertions, 434 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index d71c06d..2207a28 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -1,3 +1,4 @@
+# ruff: noqa: F401
"""
.. include:: ../README.md
@@ -87,11 +88,13 @@ def parse(
@t.overload
-def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ...
+def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
+ ...
@t.overload
-def parse_one(sql: str, **opts) -> Expression: ...
+def parse_one(sql: str, **opts) -> Expression:
+ ...
def parse_one(
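The overload change above is purely stylistic (explicit `...` bodies on their own line); runtime behavior of `parse_one` is unchanged. A minimal sketch of both overloads in use, assuming sqlglot >= 21.0.1:

import sqlglot
from sqlglot import exp

# Untyped overload: returns a generic Expression
ast = sqlglot.parse_one("SELECT a FROM b")
print(type(ast).__name__)  # Select

# Typed overload: `into` narrows the static return type to exp.Select
select = sqlglot.parse_one("SELECT a FROM b", into=exp.Select)
print(select.sql())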
diff --git a/sqlglot/_typing.py b/sqlglot/_typing.py
index 65f307e..0415aa4 100644
--- a/sqlglot/_typing.py
+++ b/sqlglot/_typing.py
@@ -13,4 +13,5 @@ if t.TYPE_CHECKING:
A = t.TypeVar("A", bound=t.Any)
B = t.TypeVar("B", bound="sqlglot.exp.Binary")
E = t.TypeVar("E", bound="sqlglot.exp.Expression")
+F = t.TypeVar("F", bound="sqlglot.exp.Func")
T = t.TypeVar("T")
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 0bacbf9..7e3f07b 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -140,10 +140,12 @@ class DataFrame:
return cte, name
@t.overload
- def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
+ def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+ ...
@t.overload
- def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
+ def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+ ...
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index a388cb4..29e7c55 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -368,7 +368,10 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls)
+ this = Column.invoke_expression_over_column(col, expression.First)
+ if ignorenulls:
+ return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
+ return this
def grouping_id(*cols: ColumnOrName) -> Column:
@@ -392,7 +395,10 @@ def isnull(col: ColumnOrName) -> Column:
def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls)
+ this = Column.invoke_expression_over_column(col, expression.Last)
+ if ignorenulls:
+ return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
+ return this
def monotonically_increasing_id() -> Column:
@@ -485,31 +491,28 @@ def factorial(col: ColumnOrName) -> Column:
def lag(
col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
) -> Column:
- if default is not None:
- return Column.invoke_anonymous_function(col, "LAG", offset, default)
- if offset != 1:
- return Column.invoke_anonymous_function(col, "LAG", offset)
- return Column.invoke_anonymous_function(col, "LAG")
+ return Column.invoke_expression_over_column(
+ col, expression.Lag, offset=None if offset == 1 else offset, default=default
+ )
def lead(
col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
) -> Column:
- if default is not None:
- return Column.invoke_anonymous_function(col, "LEAD", offset, default)
- if offset != 1:
- return Column.invoke_anonymous_function(col, "LEAD", offset)
- return Column.invoke_anonymous_function(col, "LEAD")
+ return Column.invoke_expression_over_column(
+ col, expression.Lead, offset=None if offset == 1 else offset, default=default
+ )
def nth_value(
col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
) -> Column:
+ this = Column.invoke_expression_over_column(
+ col, expression.NthValue, offset=None if offset == 1 else offset
+ )
if ignoreNulls is not None:
- raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
- if offset != 1:
- return Column.invoke_anonymous_function(col, "NTH_VALUE", offset)
- return Column.invoke_anonymous_function(col, "NTH_VALUE")
+ return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
+ return this
def ntile(n: int) -> Column:
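A hedged sketch of what the rewritten helpers now build, expressed with core sqlglot expressions rather than the dataframe API (assuming sqlglot >= 21.0.1; outputs are illustrative):

from sqlglot import exp

# lag()/lead() now build typed Lag/Lead nodes; the default offset of 1 is dropped
lead = exp.Lead(this=exp.column("x"), offset=exp.Literal.number(2), default=exp.Literal.number(0))
print(lead.sql())  # roughly: LEAD(x, 2, 0)

# first()/last()/nth_value() wrap their node in IgnoreNulls instead of passing a flag
first = exp.IgnoreNulls(this=exp.First(this=exp.column("x")))
print(first.sql())  # roughly: FIRST(x) IGNORE NULLS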
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 04990ac..82552c9 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -1,9 +1,10 @@
+# ruff: noqa: F401
"""
## Dialects
While there is a SQL standard, most SQL engines support a variation of that standard. This makes it difficult
to write portable SQL code. SQLGlot bridges all the different variations, called "dialects", with an extensible
-SQL transpilation framework.
+SQL transpilation framework.
The base `sqlglot.dialects.dialect.Dialect` class implements a generic dialect that aims to be as universal as possible.
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 771ae1a..9068235 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import (
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
- path_to_jsonpath,
regexp_replace_sql,
rename_func,
timestrtotime_sql,
@@ -458,8 +457,10 @@ class BigQuery(Dialect):
return this
- def _parse_table_parts(self, schema: bool = False) -> exp.Table:
- table = super()._parse_table_parts(schema=schema)
+ def _parse_table_parts(
+ self, schema: bool = False, is_db_reference: bool = False
+ ) -> exp.Table:
+ table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
if isinstance(table.this, exp.Identifier) and "." in table.name:
catalog, db, this, *rest = (
t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
@@ -474,10 +475,12 @@ class BigQuery(Dialect):
return table
@t.overload
- def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
+ ...
@t.overload
- def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
+ def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
+ ...
def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
@@ -536,6 +539,8 @@ class BigQuery(Dialect):
UNPIVOT_ALIASES_ARE_IDENTIFIERS = False
JSON_KEY_VALUE_PAIR_SEP = ","
NULL_ORDERING_SUPPORTED = False
+ IGNORE_NULLS_IN_FUNC = True
+ JSON_PATH_SINGLE_QUOTE_ESCAPE = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -554,7 +559,8 @@ class BigQuery(Dialect):
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
- exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
+ exp.DateDiff: lambda self,
+ e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: date_add_interval_sql("DATE", "SUB"),
@@ -565,7 +571,6 @@ class BigQuery(Dialect):
"DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'"
),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
- exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql(false_value="NULL"),
@@ -597,12 +602,13 @@ class BigQuery(Dialect):
]
),
exp.SHA2: lambda self, e: self.func(
- f"SHA256" if e.text("length") == "256" else "SHA512", e.this
+ "SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.StabilityProperty: lambda self, e: (
- f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
+ "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
),
- exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
+ exp.StrToDate: lambda self,
+ e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
@@ -610,9 +616,10 @@ class BigQuery(Dialect):
exp.TimeFromParts: rename_func("TIME"),
exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
+ exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
- exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
+ exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToTime: rename_func("TIME"),
@@ -623,6 +630,12 @@ class BigQuery(Dialect):
exp.VariancePop: rename_func("VAR_POP"),
}
+ SUPPORTED_JSON_PATH_PARTS = {
+ exp.JSONPathKey,
+ exp.JSONPathRoot,
+ exp.JSONPathSubscript,
+ }
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
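With `exp.GetPath` and `path_to_jsonpath` gone, BigQuery JSON extraction goes through `exp.JSONExtract` plus a parsed `exp.JSONPath` restricted to the parts listed in SUPPORTED_JSON_PATH_PARTS. An illustrative transpilation (exact output may vary by version):

import sqlglot

# Snowflake's colon extraction becomes a BigQuery JSON_EXTRACT over a '$'-rooted path
print(sqlglot.transpile("SELECT col:a.b FROM t", read="snowflake", write="bigquery")[0])
# roughly: SELECT JSON_EXTRACT(col, '$.a.b') FROM t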
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 1248edc..1ec15c5 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -8,12 +8,15 @@ from sqlglot.dialects.dialect import (
arg_max_or_min_no_count,
date_delta_sql,
inline_array_sql,
+ json_extract_segments,
+ json_path_key_only_name,
no_pivot_sql,
+ parse_json_extract_path,
rename_func,
var_map_sql,
)
from sqlglot.errors import ParseError
-from sqlglot.helper import seq_get
+from sqlglot.helper import is_int, seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import Token, TokenType
@@ -120,6 +123,9 @@ class ClickHouse(Dialect):
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
+ "JSONEXTRACTSTRING": parse_json_extract_path(
+ exp.JSONExtractScalar, zero_based_indexing=False
+ ),
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"RANDCANONICAL": exp.Rand.from_arg_list,
@@ -354,9 +360,14 @@ class ClickHouse(Dialect):
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
+ is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
this = super()._parse_table(
- schema=schema, joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket
+ schema=schema,
+ joins=joins,
+ alias_tokens=alias_tokens,
+ parse_bracket=parse_bracket,
+ is_db_reference=is_db_reference,
)
if self._match(TokenType.FINAL):
@@ -518,6 +529,12 @@ class ClickHouse(Dialect):
exp.DataType.Type.VARCHAR: "String",
}
+ SUPPORTED_JSON_PATH_PARTS = {
+ exp.JSONPathKey,
+ exp.JSONPathRoot,
+ exp.JSONPathSubscript,
+ }
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**STRING_TYPE_MAPPING,
@@ -570,6 +587,10 @@ class ClickHouse(Dialect):
exp.Explode: rename_func("arrayJoin"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.IsNan: rename_func("isNaN"),
+ exp.JSONExtract: json_extract_segments("JSONExtractString", quoted_index=False),
+ exp.JSONExtractScalar: json_extract_segments("JSONExtractString", quoted_index=False),
+ exp.JSONPathKey: json_path_key_only_name,
+ exp.JSONPathRoot: lambda *_: "",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Nullif: rename_func("nullIf"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -579,7 +600,8 @@ class ClickHouse(Dialect):
exp.Rand: rename_func("randCanonical"),
exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.StartsWith: rename_func("startsWith"),
- exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
+ exp.StrPosition: lambda self,
+ e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
}
@@ -608,6 +630,13 @@ class ClickHouse(Dialect):
"NAMED COLLECTION",
}
+ def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
+ this = self.json_path_part(expression.this)
+ return str(int(this) + 1) if is_int(this) else this
+
+ def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
+ return f"AS {self.sql(expression, 'this')}"
+
def _any_to_has(
self,
expression: exp.EQ | exp.NEQ,
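ClickHouse JSON functions are 1-indexed, which is what the new `_jsonpathsubscript_sql` and the `zero_based_indexing=False` parser hook account for. An illustrative transpilation (assuming sqlglot >= 21.0.1):

import sqlglot

# A zero-based `$.a[0]` subscript is shifted to ClickHouse's 1-based indexing
print(sqlglot.transpile(
    "SELECT JSON_EXTRACT_SCALAR(j, '$.a[0]') FROM t", read="bigquery", write="clickhouse"
)[0])
# roughly: SELECT JSONExtractString(j, 'a', 1) FROM t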
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 8e55b6a..20907db 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -22,6 +22,7 @@ class Databricks(Spark):
"DATEADD": parse_date_delta(exp.DateAdd),
"DATE_ADD": parse_date_delta(exp.DateAdd),
"DATEDIFF": parse_date_delta(exp.DateDiff),
+ "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff),
}
FACTOR = {
@@ -48,6 +49,9 @@ class Databricks(Spark):
exp.DatetimeDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
),
+ exp.TimestampDiff: lambda self, e: self.func(
+ "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+ ),
exp.DatetimeTrunc: timestamptrunc_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
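A quick round-trip of the TIMESTAMPDIFF mapping added above, with both the parser and generator entries in play (illustrative):

import sqlglot

# TIMESTAMPDIFF now parses into exp.TimestampDiff and is regenerated verbatim
print(sqlglot.transpile(
    "SELECT TIMESTAMPDIFF(DAY, a, b)", read="databricks", write="databricks"
)[0])
# roughly: SELECT TIMESTAMPDIFF(DAY, a, b)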
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 6be991b..6e2d190 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import logging
import typing as t
from enum import Enum, auto
from functools import reduce
@@ -7,7 +8,8 @@ from functools import reduce
from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
-from sqlglot.helper import AutoName, flatten, seq_get
+from sqlglot.helper import AutoName, flatten, is_int, seq_get
+from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
@@ -17,7 +19,11 @@ DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsD
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
if t.TYPE_CHECKING:
- from sqlglot._typing import B, E
+ from sqlglot._typing import B, E, F
+
+ JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
+
+logger = logging.getLogger("sqlglot")
class Dialects(str, Enum):
@@ -256,7 +262,7 @@ class Dialect(metaclass=_Dialect):
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
- # Delimiters for quotes, identifiers and the corresponding escape characters
+ # Delimiters for string literals and identifiers
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
@@ -373,7 +379,7 @@ class Dialect(metaclass=_Dialect):
"""
if (
isinstance(expression, exp.Identifier)
- and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
+ and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
and (
not expression.quoted
or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
@@ -440,6 +446,19 @@ class Dialect(metaclass=_Dialect):
return expression
+ def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ if isinstance(path, exp.Literal):
+ path_text = path.name
+ if path.is_number:
+ path_text = f"[{path_text}]"
+
+ try:
+ return parse_json_path(path_text)
+ except ParseError as e:
+ logger.warning(f"Invalid JSON path syntax. {str(e)}")
+
+ return path
+
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenize(sql), sql)
@@ -500,14 +519,12 @@ def if_sql(
return _if_sql
-def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
- return self.binary(expression, "->")
-
+def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
+ this = expression.this
+ if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
+ this.replace(exp.cast(this, "json"))
-def arrow_json_extract_scalar_sql(
- self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
-) -> str:
- return self.binary(expression, "->>")
+ return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
@@ -552,11 +569,6 @@ def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
return self.cast_sql(expression)
-def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
- self.unsupported("Properties unsupported")
- return ""
-
-
def no_comment_column_constraint_sql(
self: Generator, expression: exp.CommentColumnConstraint
) -> str:
@@ -965,32 +977,6 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return _delta_sql
-def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
- from sqlglot.optimizer.simplify import simplify
-
- # Makes sure the path will be evaluated correctly at runtime to include the path root.
- # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
- path = expression.expression
- path = exp.func(
- "if",
- exp.func("startswith", path, "'['"),
- exp.func("concat", "'$'", path),
- exp.func("concat", "'$.'", path),
- )
-
- expression.expression.replace(simplify(path))
- return expression
-
-
-def path_to_jsonpath(
- name: str = "JSON_EXTRACT",
-) -> t.Callable[[Generator, exp.GetPath], str]:
- def _transform(self: Generator, expression: exp.GetPath) -> str:
- return rename_func(name)(self, prepend_dollar_to_path(expression))
-
- return _transform
-
-
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@@ -1003,9 +989,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")
- normalize = lambda identifier: (
- self.dialect.normalize_identifier(identifier).name if identifier else None
- )
+ def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
+ return self.dialect.normalize_identifier(identifier).name if identifier else None
targets = {normalize(expression.this.this)}
@@ -1023,3 +1008,60 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
)
return self.merge_sql(expression)
+
+
+def parse_json_extract_path(
+ expr_type: t.Type[F], zero_based_indexing: bool = True
+) -> t.Callable[[t.List], F]:
+ def _parse_json_extract_path(args: t.List) -> F:
+ segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
+ for arg in args[1:]:
+ if not isinstance(arg, exp.Literal):
+ # We use the fallback parser because we can't really transpile non-literals safely
+ return expr_type.from_arg_list(args)
+
+ text = arg.name
+ if is_int(text):
+ index = int(text)
+ segments.append(
+ exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
+ )
+ else:
+ segments.append(exp.JSONPathKey(this=text))
+
+ # This is done to avoid failing in the expression validator due to the arg count
+ del args[2:]
+ return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+
+ return _parse_json_extract_path
+
+
+def json_extract_segments(
+ name: str, quoted_index: bool = True
+) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
+ def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
+ path = expression.expression
+ if not isinstance(path, exp.JSONPath):
+ return rename_func(name)(self, expression)
+
+ segments = []
+ for segment in path.expressions:
+ path = self.sql(segment)
+ if path:
+ if isinstance(segment, exp.JSONPathPart) and (
+ quoted_index or not isinstance(segment, exp.JSONPathSubscript)
+ ):
+ path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
+
+ segments.append(path)
+
+ return self.func(name, expression.this, *segments)
+
+ return _json_extract_segments
+
+
+def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
+ if isinstance(expression.this, exp.JSONPathWildcard):
+ self.unsupported("Unsupported wildcard in JSONPathKey expression")
+
+ return expression.name
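A sketch of the new `to_json_path` hook defined above: a literal path is parsed into an `exp.JSONPath` tree, and an unparseable path is returned as-is after logging a warning (assuming sqlglot >= 21.0.1):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

path = Dialect().to_json_path(exp.Literal.string("$.a[0].b"))
assert isinstance(path, exp.JSONPath)
print([type(part).__name__ for part in path.expressions])
# roughly: ['JSONPathRoot', 'JSONPathKey', 'JSONPathSubscript', 'JSONPathKey']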
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 6e229b3..7a18e8e 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -55,11 +55,14 @@ class Doris(MySQL):
exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
- exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToUnix: lambda self,
+ e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
- exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level
+ exp.ToChar: lambda self,
+ e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TsOrDsAdd: lambda self,
+ e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimestampTrunc: lambda self, e: self.func(
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 6bca9e7..be23355 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -99,6 +99,7 @@ class Drill(Dialect):
QUERY_HINTS = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
+ SUPPORTS_CREATE_TABLE_LIKE = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -128,10 +129,14 @@ class Drill(Dialect):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
- exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
- exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
- exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
- exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
+ exp.DateToDi: lambda self,
+ e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
+ exp.DiToDate: lambda self,
+ e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
+ exp.If: lambda self,
+ e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
+ exp.ILike: lambda self,
+ e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@@ -141,7 +146,8 @@ class Drill(Dialect):
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
- exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self,
+ e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -149,8 +155,10 @@ class Drill(Dialect):
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
- exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
- exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
+ exp.TsOrDsAdd: lambda self,
+ e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
+ exp.TsOrDiToDi: lambda self,
+ e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
def normalize_func(self, name: str) -> str:
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f55ad70..d7ba729 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import (
NormalizationStrategy,
approx_count_distinct_sql,
arg_max_or_min_no_count,
- arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
binary_from_function,
bool_xor_sql,
@@ -18,11 +17,9 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
inline_array_sql,
no_comment_column_constraint_sql,
- no_properties_sql,
no_safe_divide_sql,
no_timestamp_sql,
pivot_column_names,
- prepend_dollar_to_path,
regexp_extract_sql,
rename_func,
str_position_sql,
@@ -172,6 +169,18 @@ class DuckDB(Dialect):
# https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
+ def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ if isinstance(path, exp.Literal):
+ # DuckDB also supports the JSON pointer syntax, where every path starts with a `/`.
+ # Additionally, it allows accessing the back of lists using the `[#-i]` syntax.
+ # This check ensures we'll avoid trying to parse these as JSON paths, which can
+ # either result in a noisy warning or in an invalid representation of the path.
+ path_text = path.name
+ if path_text.startswith("/") or "[#" in path_text:
+ return path
+
+ return super().to_json_path(path)
+
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -229,6 +238,8 @@ class DuckDB(Dialect):
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
"JSON": exp.ParseJSON.from_arg_list,
+ "JSON_EXTRACT_PATH": parser.parse_extract_json_with_path(exp.JSONExtract),
+ "JSON_EXTRACT_STRING": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
"LIST_HAS": exp.ArrayContains.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
@@ -319,6 +330,9 @@ class DuckDB(Dialect):
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_KEY_VALUE_PAIR_SEP = ","
+ IGNORE_NULLS_IN_FUNC = True
+ JSON_PATH_BRACKETED_KEY_SUPPORTED = False
+ SUPPORTS_CREATE_TABLE_LIKE = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -350,18 +364,18 @@ class DuckDB(Dialect):
"DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
+ exp.DateToDi: lambda self,
+ e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),
- exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
+ exp.DiToDate: lambda self,
+ e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.IsInf: rename_func("ISINF"),
exp.IsNan: rename_func("ISNAN"),
- exp.JSONBExtract: arrow_json_extract_sql,
- exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: arrow_json_extract_sql,
- exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONFormat: _json_format_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
@@ -377,7 +391,6 @@ class DuckDB(Dialect):
# DuckDB doesn't allow qualified columns inside of PIVOT expressions.
# See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
- exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.RegexpReplace: lambda self, e: self.func(
"REGEXP_REPLACE",
@@ -395,7 +408,8 @@ class DuckDB(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: str_to_time_sql,
- exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
+ exp.StrToUnix: lambda self,
+ e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampDiff: lambda self, e: self.func(
@@ -405,9 +419,11 @@ class DuckDB(Dialect):
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
- exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToStr: lambda self,
+ e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
- exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
+ exp.TsOrDiToDi: lambda self,
+ e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
@@ -415,7 +431,8 @@ class DuckDB(Dialect):
exp.cast(e.expression, "TIMESTAMP"),
exp.cast(e.this, "TIMESTAMP"),
),
- exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
+ exp.UnixToStr: lambda self,
+ e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
exp.VariancePop: rename_func("VAR_POP"),
@@ -423,6 +440,13 @@ class DuckDB(Dialect):
exp.Xor: bool_xor_sql,
}
+ SUPPORTED_JSON_PATH_PARTS = {
+ exp.JSONPathKey,
+ exp.JSONPathRoot,
+ exp.JSONPathSubscript,
+ exp.JSONPathWildcard,
+ }
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
@@ -442,11 +466,18 @@ class DuckDB(Dialect):
UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
+ # DuckDB doesn't generally support CREATE TABLE .. properties
+ # https://duckdb.org/docs/sql/statements/create_table.html
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION,
- exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ prop: exp.Properties.Location.UNSUPPORTED
+ for prop in generator.Generator.PROPERTIES_LOCATION
}
+ # There are a few exceptions (e.g. temporary tables) which are supported or
+ # can be transpiled to DuckDB, so we explicitly override them accordingly
+ PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA
+ PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE
+
def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
nano = expression.args.get("nano")
if nano is not None:
@@ -486,10 +517,6 @@ class DuckDB(Dialect):
expression, sep=sep, tablesample_keyword=tablesample_keyword
)
- def getpath_sql(self, expression: exp.GetPath) -> str:
- expression = prepend_dollar_to_path(expression)
- return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}"
-
def interval_sql(self, expression: exp.Interval) -> str:
multiplier: t.Optional[int] = None
unit = expression.text("unit").lower()
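DuckDB's `to_json_path` override above leaves JSON-pointer syntax alone; an illustrative round-trip (assuming sqlglot >= 21.0.1):

import sqlglot

# A pointer-style path ('/a/0') stays a plain literal instead of being parsed
# as JSONPath, avoiding the noisy warning mentioned in the comment above
print(sqlglot.transpile("SELECT j -> '/a/0' FROM t", read="duckdb", write="duckdb")[0])
# roughly: SELECT j -> '/a/0' FROM t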
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 060f9bd..6337ffd 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -192,6 +192,18 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
return f"TO_DATE({this})"
+def _parse_ignore_nulls(
+ exp_class: t.Type[exp.Expression],
+) -> t.Callable[[t.List[exp.Expression]], exp.Expression]:
+ def _parse(args: t.List[exp.Expression]) -> exp.Expression:
+ this = exp_class(this=seq_get(args, 0))
+ if seq_get(args, 1) == exp.true():
+ return exp.IgnoreNulls(this=this)
+ return this
+
+ return _parse
+
+
class Hive(Dialect):
ALIAS_POST_TABLESAMPLE = True
IDENTIFIERS_CAN_START_WITH_DIGIT = True
@@ -298,8 +310,12 @@ class Hive(Dialect):
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "FIRST": _parse_ignore_nulls(exp.First),
+ "FIRST_VALUE": _parse_ignore_nulls(exp.FirstValue),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
+ "LAST": _parse_ignore_nulls(exp.Last),
+ "LAST_VALUE": _parse_ignore_nulls(exp.LastValue),
"LOCATE": locate_to_strposition,
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
@@ -429,6 +445,7 @@ class Hive(Dialect):
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
+ JSON_PATH_SINGLE_QUOTE_ESCAPE = True
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
@@ -437,6 +454,13 @@ class Hive(Dialect):
exp.Union,
}
+ SUPPORTED_JSON_PATH_PARTS = {
+ exp.JSONPathKey,
+ exp.JSONPathRoot,
+ exp.JSONPathSubscript,
+ exp.JSONPathWildcard,
+ }
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIT: "BOOLEAN",
@@ -471,9 +495,12 @@ class Hive(Dialect):
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _add_date_sql,
- exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
- exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
- exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
+ exp.DateToDi: lambda self,
+ e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
+ exp.DiToDate: lambda self,
+ e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
+ exp.FileFormatProperty: lambda self,
+ e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
@@ -502,7 +529,8 @@ class Hive(Dialect):
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
- exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
+ exp.Split: lambda self,
+ e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_time_sql,
@@ -514,7 +542,8 @@ class Hive(Dialect):
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToBase64: rename_func("BASE64"),
- exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
+ exp.TsOrDiToDi: lambda self,
+ e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _add_date_sql,
exp.TsOrDsDiff: _date_diff_sql,
exp.TsOrDsToDate: _to_date_sql,
@@ -528,8 +557,10 @@ class Hive(Dialect):
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.National: lambda self, e: self.national_sql(e, prefix=""),
- exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
- exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
+ exp.ClusteredColumnConstraint: lambda self,
+ e: f"({self.expressions(e, 'this', indent=False)})",
+ exp.NonClusteredColumnConstraint: lambda self,
+ e: f"({self.expressions(e, 'this', indent=False)})",
exp.NotForReplicationColumnConstraint: lambda self, e: "",
exp.OnProperty: lambda self, e: "",
exp.PrimaryKeyColumnConstraint: lambda self, e: "PRIMARY KEY",
@@ -543,6 +574,13 @@ class Hive(Dialect):
exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
+ if isinstance(expression.this, exp.JSONPathWildcard):
+ self.unsupported("Unsupported wildcard in JSONPathKey expression")
+ return ""
+
+ return super()._jsonpathkey_sql(expression)
+
def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
# Hive has no temporary storage provider (there are hive settings though)
return expression
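An illustrative parse through the new `_parse_ignore_nulls` hook (assuming sqlglot >= 21.0.1):

import sqlglot
from sqlglot import exp

# FIRST(x, TRUE) now yields IgnoreNulls(First(x)) rather than a boolean argument
ast = sqlglot.parse_one("SELECT FIRST(x, TRUE) FROM t", read="hive")
print(ast.find(exp.IgnoreNulls))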
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 21a9657..661ef7d 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -6,7 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
- arrow_json_extract_scalar_sql,
+ arrow_json_extract_sql,
date_add_interval_sql,
datestrtodate_sql,
format_time_lambda,
@@ -19,8 +19,8 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
+ parse_date_delta,
parse_date_delta_with_interval,
- path_to_jsonpath,
rename_func,
strposition_to_locate_sql,
)
@@ -306,6 +306,7 @@ class MySQL(Dialect):
format=exp.Literal.string("%B"),
),
"STR_TO_DATE": _str_to_date,
+ "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff),
"TO_DAYS": lambda args: exp.paren(
exp.DateDiff(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@@ -357,6 +358,7 @@ class MySQL(Dialect):
"CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
"CREATE VIEW": _show_parser("CREATE VIEW", target=True),
"DATABASES": _show_parser("DATABASES"),
+ "SCHEMAS": _show_parser("DATABASES"),
"ENGINE": _show_parser("ENGINE", target=True),
"STORAGE ENGINES": _show_parser("ENGINES"),
"ENGINES": _show_parser("ENGINES"),
@@ -630,6 +632,8 @@ class MySQL(Dialect):
VALUES_AS_TABLE = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
+ JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
+ JSON_PATH_BRACKETED_KEY_SUPPORTED = False
JSON_KEY_VALUE_PAIR_SEP = ","
TRANSFORMS = {
@@ -646,10 +650,10 @@ class MySQL(Dialect):
exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
- exp.GetPath: path_to_jsonpath(),
- exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
+ exp.GroupConcat: lambda self,
+ e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
- exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.JSONExtractScalar: arrow_json_extract_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Month: _remove_ts_or_ds_to_date(),
@@ -672,6 +676,9 @@ class MySQL(Dialect):
exp.TableSample: no_tablesample_sql,
exp.TimeFromParts: rename_func("MAKETIME"),
exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
+ exp.TimestampDiff: lambda self, e: self.func(
+ "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+ ),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
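MySQL now sets JSON_TYPE_REQUIRED_FOR_EXTRACTION, so `arrow_json_extract_sql` casts a string-literal operand to JSON before extraction. An illustrative round-trip (output may vary by version):

import sqlglot

print(sqlglot.transpile("""SELECT '{"a": 1}' ->> '$.a'""", read="mysql", write="mysql")[0])
# roughly: SELECT CAST('{"a": 1}' AS JSON) ->> '$.a'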
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 4591d59..0c0d750 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -199,7 +199,8 @@ class Oracle(Dialect):
transforms.eliminate_qualify,
]
),
- exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self,
+ e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
@@ -208,7 +209,8 @@ class Oracle(Dialect):
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
- exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+ exp.UnixToTime: lambda self,
+ e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
}
PROPERTIES_LOCATION = {
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 87f6b02..0404c78 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -7,11 +7,11 @@ from sqlglot.dialects.dialect import (
DATE_ADD_OR_SUB,
Dialect,
any_value_to_max_sql,
- arrow_json_extract_scalar_sql,
- arrow_json_extract_sql,
bool_xor_sql,
datestrtodate_sql,
format_time_lambda,
+ json_extract_segments,
+ json_path_key_only_name,
max_or_greatest,
merge_without_target_sql,
min_or_least,
@@ -20,6 +20,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_pivot_sql,
no_trycast_sql,
+ parse_json_extract_path,
parse_timestamp_trunc,
rename_func,
str_position_sql,
@@ -292,6 +293,8 @@ class Postgres(Dialect):
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
"GENERATE_SERIES": _generate_series,
+ "JSON_EXTRACT_PATH": parse_json_extract_path(exp.JSONExtract),
+ "JSON_EXTRACT_PATH_TEXT": parse_json_extract_path(exp.JSONExtractScalar),
"MAKE_TIME": exp.TimeFromParts.from_arg_list,
"MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
"NOW": exp.CurrentTimestamp.from_arg_list,
@@ -375,8 +378,15 @@ class Postgres(Dialect):
TABLESAMPLE_SIZE_IS_ROWS = False
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
- # https://www.postgresql.org/docs/current/sql-createtable.html
+ JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
SUPPORTS_UNLOGGED_TABLES = True
+ LIKE_PROPERTY_INSIDE_SCHEMA = True
+
+ SUPPORTED_JSON_PATH_PARTS = {
+ exp.JSONPathKey,
+ exp.JSONPathRoot,
+ exp.JSONPathSubscript,
+ }
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -412,11 +422,14 @@ class Postgres(Dialect):
exp.DateSub: _date_add_sql("-"),
exp.Explode: rename_func("UNNEST"),
exp.GroupConcat: _string_agg_sql,
- exp.JSONExtract: arrow_json_extract_sql,
- exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH"),
+ exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
+ exp.JSONPathKey: json_path_key_only_name,
+ exp.JSONPathRoot: lambda *_: "",
+ exp.JSONPathSubscript: lambda self, e: self.json_path_part(e.this),
exp.LastDay: no_last_day_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
@@ -443,7 +456,8 @@ class Postgres(Dialect):
]
),
exp.StrPosition: str_position_sql,
- exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self,
+ e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
exp.TimeFromParts: rename_func("MAKE_TIME"),
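Postgres JSON extraction now round-trips through the segment-based functions registered above. An illustrative sketch (assuming sqlglot >= 21.0.1):

import sqlglot

# JSON_EXTRACT_PATH_TEXT parses into exp.JSONExtractScalar and is regenerated
# from its JSONPath segments by json_extract_segments
print(sqlglot.transpile(
    "SELECT JSON_EXTRACT_PATH_TEXT(j, 'a', 'b') FROM t", read="postgres", write="postgres"
)[0])
# roughly: SELECT JSON_EXTRACT_PATH_TEXT(j, 'a', 'b') FROM t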
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 6cc6030..8691192 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -18,7 +18,6 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_safe_divide_sql,
no_timestamp_sql,
- path_to_jsonpath,
regexp_extract_sql,
rename_func,
right_to_substring_sql,
@@ -150,7 +149,7 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
return expression
-def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str:
+def _first_last_sql(self: Presto.Generator, expression: exp.Func) -> str:
"""
Trino doesn't support FIRST / LAST as functions, but they're valid in the context
of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases
@@ -292,6 +291,7 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
LIMIT_ONLY_LITERALS = True
SUPPORTS_SINGLE_ARG_CONCAT = False
+ LIKE_PROPERTY_INSIDE_SCHEMA = True
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@@ -324,12 +324,18 @@ class Presto(Dialect):
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
exp.ArrayUniqueAgg: rename_func("SET_AGG"),
- exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.AtTimeZone: rename_func("AT_TIMEZONE"),
+ exp.BitwiseAnd: lambda self,
+ e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.BitwiseLeftShift: lambda self,
+ e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",
- exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.BitwiseOr: lambda self,
+ e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.BitwiseRightShift: lambda self,
+ e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.BitwiseXor: lambda self,
+ e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
@@ -344,7 +350,8 @@ class Presto(Dialect):
"DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
+ exp.DateToDi: lambda self,
+ e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "DAY"),
@@ -352,12 +359,14 @@ class Presto(Dialect):
e.this,
),
exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
- exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
+ exp.DiToDate: lambda self,
+ e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
- exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
- exp.GetPath: path_to_jsonpath(),
+ exp.FirstValue: _first_last_sql,
+ exp.FromTimeZone: lambda self,
+ e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@@ -368,6 +377,7 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Last: _first_last_sql,
+ exp.LastValue: _first_last_sql,
exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
@@ -394,26 +404,33 @@ class Presto(Dialect):
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
- exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
+ exp.StrToUnix: lambda self,
+ e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
- exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeStrToUnix: lambda self,
+ e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
+ exp.TimeToStr: lambda self,
+ e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
- exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.ToChar: lambda self,
+ e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
- exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
+ exp.TsOrDiToDi: lambda self,
+ e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.Unhex: rename_func("FROM_HEX"),
- exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
+ exp.UnixToStr: lambda self,
+ e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
- exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
+ exp.UnixToTimeStr: lambda self,
+ e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
exp.WithinGroup: transforms.preprocess(
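An illustrative rendering of the AT_TIMEZONE mapping added above (output may vary by version):

import sqlglot

# AT TIME ZONE is rewritten to Presto's AT_TIMEZONE function
print(sqlglot.transpile("SELECT ts AT TIME ZONE 'UTC'", write="presto")[0])
# roughly: SELECT AT_TIMEZONE(ts, 'UTC')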
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 7194d81..a64c1d4 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
concat_ws_to_dpipe_sql,
date_delta_sql,
generatedasidentitycolumnconstraint_sql,
+ json_extract_segments,
no_tablesample_sql,
rename_func,
)
@@ -20,10 +21,6 @@ if t.TYPE_CHECKING:
from sqlglot._typing import E
-def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
- return f'{self.sql(expression, "this")}."{expression.expression.name}"'
-
-
def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
def _parse_delta(args: t.List) -> E:
expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
@@ -62,6 +59,7 @@ class Redshift(Postgres):
"DATE_ADD": _parse_date_delta(exp.TsOrDsAdd),
"DATEDIFF": _parse_date_delta(exp.TsOrDsDiff),
"DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff),
+ "GETDATE": exp.CurrentTimestamp.from_arg_list,
"LISTAGG": exp.GroupConcat.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
}
@@ -69,6 +67,7 @@ class Redshift(Postgres):
NO_PAREN_FUNCTION_PARSERS = {
**Postgres.Parser.NO_PAREN_FUNCTION_PARSERS,
"APPROXIMATE": lambda self: self._parse_approximate_count(),
+ "SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, transaction=True),
}
def _parse_table(
@@ -77,6 +76,7 @@ class Redshift(Postgres):
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
+ is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
# Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr`
unpivot = self._match(TokenType.UNPIVOT)
@@ -85,6 +85,7 @@ class Redshift(Postgres):
joins=joins,
alias_tokens=alias_tokens,
parse_bracket=parse_bracket,
+ is_db_reference=is_db_reference,
)
return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table
@@ -153,7 +154,6 @@ class Redshift(Postgres):
**Postgres.Tokenizer.KEYWORDS,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
- "SYSDATE": TokenType.CURRENT_TIMESTAMP,
"TOP": TokenType.TOP,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
@@ -180,31 +180,29 @@ class Redshift(Postgres):
exp.DataType.Type.VARBINARY: "VARBYTE",
}
- PROPERTIES_LOCATION = {
- **Postgres.Generator.PROPERTIES_LOCATION,
- exp.LikeProperty: exp.Properties.Location.POST_WITH,
- }
-
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS,
exp.Concat: concat_to_dpipe_sql,
exp.ConcatWs: concat_ws_to_dpipe_sql,
- exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
- exp.CurrentTimestamp: lambda self, e: "SYSDATE",
+ exp.ApproxDistinct: lambda self,
+ e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
+ exp.CurrentTimestamp: lambda self, e: (
+ "SYSDATE" if e.args.get("transaction") else "GETDATE()"
+ ),
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.FromBase: rename_func("STRTOL"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
- exp.JSONExtract: _json_sql,
- exp.JSONExtractScalar: _json_sql,
+ exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
exp.GroupConcat: rename_func("LISTAGG"),
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
- exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+ exp.SortKeyProperty: lambda self,
+ e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
@@ -228,6 +226,13 @@ class Redshift(Postgres):
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+ if expression.is_type(exp.DataType.Type.JSON):
+ # Redshift doesn't support a JSON type, so casting to it is treated as a noop
+ return self.sql(expression, "this")
+
+ return super().cast_sql(expression, safe_prefix=safe_prefix)
+
def datatype_sql(self, expression: exp.DataType) -> str:
"""
Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
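
A minimal sketch of the intended SYSDATE/GETDATE round trip, using the public sqlglot.transpile API (the printed strings are illustrative, not taken from the test suite):

import sqlglot

# SYSDATE returns the transaction start time while GETDATE() returns the
# statement start time, so the parser keeps them apart via the `transaction`
# flag on exp.CurrentTimestamp.
print(sqlglot.transpile("SELECT SYSDATE", read="redshift", write="redshift")[0])
print(sqlglot.transpile("SELECT GETDATE()", read="redshift", write="redshift")[0])
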
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 281167d..37f9761 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -21,19 +21,13 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.expressions import Literal
-from sqlglot.helper import seq_get
+from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
-def _check_int(s: str) -> bool:
- if s[0] in ("-", "+"):
- return s[1:].isdigit()
- return s.isdigit()
-
-
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
if len(args) == 2:
@@ -53,7 +47,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
return exp.TimeStrToTime.from_arg_list(args)
if first_arg.is_string:
- if _check_int(first_arg.this):
+ if is_int(first_arg.this):
# case: <integer>
return exp.UnixToTime.from_arg_list(args)
@@ -241,7 +235,6 @@ DATE_PART_MAPPING = {
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
- "NSECONDS": "NANOSECOND",
"EPOCH": "EPOCH_SECOND",
"EPOCH_SECONDS": "EPOCH_SECOND",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
@@ -291,7 +284,9 @@ def _parse_colon_get_path(
path = exp.Literal.string(path.sql(dialect="snowflake"))
# The extraction operator : is left-associative
- this = self.expression(exp.GetPath, this=this, expression=path)
+ this = self.expression(
+ exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
+ )
if target_type:
this = exp.cast(this, target_type)
@@ -411,6 +406,9 @@ class Snowflake(Dialect):
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"FLATTEN": exp.Explode.from_arg_list,
+ "GET_PATH": lambda args, dialect: exp.JSONExtract(
+ this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
+ ),
"IFF": exp.If.from_arg_list,
"LAST_DAY": lambda args: exp.LastDay(
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
@@ -474,6 +472,8 @@ class Snowflake(Dialect):
"TERSE SCHEMAS": _show_parser("SCHEMAS"),
"OBJECTS": _show_parser("OBJECTS"),
"TERSE OBJECTS": _show_parser("OBJECTS"),
+ "TABLES": _show_parser("TABLES"),
+ "TERSE TABLES": _show_parser("TABLES"),
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"COLUMNS": _show_parser("COLUMNS"),
@@ -534,7 +534,9 @@ class Snowflake(Dialect):
return table
- def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+ def _parse_table_parts(
+ self, schema: bool = False, is_db_reference: bool = False
+ ) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
if self._match(TokenType.STRING, advance=False):
table = self._parse_string()
@@ -550,7 +552,9 @@ class Snowflake(Dialect):
self._match(TokenType.L_PAREN)
while self._curr and not self._match(TokenType.R_PAREN):
if self._match_text_seq("FILE_FORMAT", "=>"):
- file_format = self._parse_string() or super()._parse_table_parts()
+ file_format = self._parse_string() or super()._parse_table_parts(
+ is_db_reference=is_db_reference
+ )
elif self._match_text_seq("PATTERN", "=>"):
pattern = self._parse_string()
else:
@@ -560,7 +564,7 @@ class Snowflake(Dialect):
table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
else:
- table = super()._parse_table_parts(schema=schema)
+ table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
return self._parse_at_before(table)
@@ -587,6 +591,8 @@ class Snowflake(Dialect):
# which is syntactically valid but has no effect on the output
terse = self._tokens[self._index - 2].text.upper() == "TERSE"
+ history = self._match_text_seq("HISTORY")
+
like = self._parse_string() if self._match(TokenType.LIKE) else None
if self._match(TokenType.IN):
@@ -597,7 +603,7 @@ class Snowflake(Dialect):
if self._curr:
scope = self._parse_table_parts()
elif self._curr:
- scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE"
+ scope_kind = "SCHEMA" if this in ("OBJECTS", "TABLES") else "TABLE"
scope = self._parse_table_parts()
return self.expression(
@@ -605,6 +611,7 @@ class Snowflake(Dialect):
**{
"terse": terse,
"this": this,
+ "history": history,
"like": like,
"scope": scope,
"scope_kind": scope_kind,
@@ -715,8 +722,10 @@ class Snowflake(Dialect):
),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: if_sql(name="IFF", false_value="NULL"),
- exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]",
+ exp.JSONExtract: rename_func("GET_PATH"),
+ exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"),
exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
+ exp.JSONPathRoot: lambda *_: "",
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -745,7 +754,8 @@ class Snowflake(Dialect):
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
- exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self,
+ e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Struct: lambda self, e: self.func(
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
@@ -771,6 +781,12 @@ class Snowflake(Dialect):
exp.Xor: rename_func("BOOLXOR"),
}
+ SUPPORTED_JSON_PATH_PARTS = {
+ exp.JSONPathKey,
+ exp.JSONPathRoot,
+ exp.JSONPathSubscript,
+ }
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
@@ -841,6 +857,7 @@ class Snowflake(Dialect):
def show_sql(self, expression: exp.Show) -> str:
terse = "TERSE " if expression.args.get("terse") else ""
+ history = " HISTORY" if expression.args.get("history") else ""
like = self.sql(expression, "like")
like = f" LIKE {like}" if like else ""
@@ -861,9 +878,7 @@ class Snowflake(Dialect):
if from_:
from_ = f" FROM {from_}"
- return (
- f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"
- )
+ return f"SHOW {terse}{expression.name}{history}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
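
A quick probe of the new colon-path handling (a sketch; the table and column names are made up):

from sqlglot import exp, parse_one

# The : extraction operator now builds a generic exp.JSONExtract whose
# expression is a structured exp.JSONPath, replacing the removed GetPath node.
node = parse_one("SELECT col:a.b FROM t", read="snowflake").find(exp.JSONExtract)
print(type(node.expression).__name__)  # JSONPath
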
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 624f76e..4c5c131 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -4,6 +4,7 @@ import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
+from sqlglot.dialects.hive import _parse_ignore_nulls
from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
@@ -45,9 +46,7 @@ class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
- "ANY_VALUE": lambda args: exp.AnyValue(
- this=seq_get(args, 0), ignore_nulls=seq_get(args, 1)
- ),
+ "ANY_VALUE": _parse_ignore_nulls(exp.AnyValue),
"DATEDIFF": _parse_datediff,
}
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index e4bb30e..9378d99 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -187,8 +187,10 @@ class Spark2(Hive):
TRANSFORMS = {
**Hive.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
- exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
- exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+ exp.ArraySum: lambda self,
+ e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ exp.AtTimeZone: lambda self,
+ e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.DateFromParts: rename_func("MAKE_DATE"),
@@ -198,7 +200,8 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
- exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+ exp.FromTimeZone: lambda self,
+ e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
@@ -212,7 +215,8 @@ class Spark2(Hive):
e.args.get("position"),
),
exp.StrToDate: _str_to_date,
- exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self,
+ e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 244a96e..b292c81 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
any_value_to_max_sql,
- arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
concat_to_dpipe_sql,
count_if_to_sum,
@@ -28,6 +27,12 @@ def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str:
return self.func("DATE", expression.this, modifier)
+def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> str:
+ if expression.expressions:
+ return self.function_fallback_sql(expression)
+ return arrow_json_extract_sql(self, expression)
+
+
def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this
@@ -85,6 +90,14 @@ class SQLite(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
+ JSON_PATH_BRACKETED_KEY_SUPPORTED = False
+ SUPPORTS_CREATE_TABLE_LIKE = False
+
+ SUPPORTED_JSON_PATH_PARTS = {
+ exp.JSONPathKey,
+ exp.JSONPathRoot,
+ exp.JSONPathSubscript,
+ }
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -120,10 +133,8 @@ class SQLite(Dialect):
exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.ILike: no_ilike_sql,
- exp.JSONExtract: arrow_json_extract_sql,
- exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
- exp.JSONBExtract: arrow_json_extract_sql,
- exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
+ exp.JSONExtract: _json_extract_sql,
+ exp.JSONExtractScalar: arrow_json_extract_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
@@ -141,11 +152,18 @@ class SQLite(Dialect):
exp.TryCast: no_trycast_sql,
}
+ # SQLite doesn't generally support CREATE TABLE .. properties
+ # https://www.sqlite.org/lang_createtable.html
PROPERTIES_LOCATION = {
- k: exp.Properties.Location.UNSUPPORTED
- for k, v in generator.Generator.PROPERTIES_LOCATION.items()
+ prop: exp.Properties.Location.UNSUPPORTED
+ for prop in generator.Generator.PROPERTIES_LOCATION
}
+ # There are a few exceptions (e.g. temporary tables) which are supported or
+ # can be transpiled to SQLite, so we explicitly override them accordingly
+ PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA
+ PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE
+
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
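
The new _json_extract_sql dispatch can be sketched as follows (output strings are illustrative):

import sqlglot

# A single path can use SQLite's arrow operator, but variadic paths must
# fall back to the JSON_EXTRACT function form.
print(sqlglot.transpile("SELECT JSON_EXTRACT(x, '$.a')", read="sqlite", write="sqlite")[0])
print(sqlglot.transpile("SELECT JSON_EXTRACT(x, '$.a', '$.b')", read="sqlite", write="sqlite")[0])
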
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 2dba1c1..8838f34 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -44,12 +44,14 @@ class StarRocks(MySQL):
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.RegexpLike: rename_func("REGEXP"),
- exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToUnix: lambda self,
+ e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.TimeStrToDate: rename_func("TO_DATE"),
- exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.UnixToStr: lambda self,
+ e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 6dbad15..7f9a11a 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -200,7 +200,8 @@ class Teradata(Dialect):
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
- exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
+ exp.StrToDate: lambda self,
+ e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index eddb70a..1bbed67 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -11,9 +11,16 @@ class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
- exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ exp.ArraySum: lambda self,
+ e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.Merge: merge_without_target_sql,
}
+ SUPPORTED_JSON_PATH_PARTS = {
+ exp.JSONPathKey,
+ exp.JSONPathRoot,
+ exp.JSONPathSubscript,
+ }
+
class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index a5e04da..70ea97e 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -14,7 +14,6 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
parse_date_delta,
- path_to_jsonpath,
rename_func,
timestrtotime_sql,
trim_sql,
@@ -266,13 +265,32 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
)
-def _parse_len(args: t.List) -> exp.Length:
- this = seq_get(args, 0)
+def _parse_as_text(
+ klass: t.Type[exp.Expression],
+) -> t.Callable[[t.List[exp.Expression]], exp.Expression]:
+ def _parse(args: t.List[exp.Expression]) -> exp.Expression:
+ this = seq_get(args, 0)
+
+ if this and not this.is_string:
+ this = exp.cast(this, exp.DataType.Type.TEXT)
+
+ expression = seq_get(args, 1)
+ kwargs = {"this": this}
+
+ if expression:
+ kwargs["expression"] = expression
- if this and not this.is_string:
- this = exp.cast(this, exp.DataType.Type.TEXT)
+ return klass(**kwargs)
- return exp.Length(this=this)
+ return _parse
+
+
+def _json_extract_sql(
+ self: TSQL.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
+) -> str:
+ json_query = rename_func("JSON_QUERY")(self, expression)
+ json_value = rename_func("JSON_VALUE")(self, expression)
+ return self.func("ISNULL", json_query, json_value)
class TSQL(Dialect):
@@ -441,8 +459,11 @@ class TSQL(Dialect):
"HASHBYTES": _parse_hashbytes,
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
- "JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
- "LEN": _parse_len,
+ "JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract),
+ "JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
+ "LEN": _parse_as_text(exp.Length),
+ "LEFT": _parse_as_text(exp.Left),
+ "RIGHT": _parse_as_text(exp.Right),
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
@@ -677,6 +698,7 @@ class TSQL(Dialect):
SUPPORTS_SINGLE_ARG_CONCAT = False
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
+ JSON_PATH_BRACKETED_KEY_SUPPORTED = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@@ -688,6 +710,12 @@ class TSQL(Dialect):
exp.Update,
}
+ SUPPORTED_JSON_PATH_PARTS = {
+ exp.JSONPathKey,
+ exp.JSONPathRoot,
+ exp.JSONPathSubscript,
+ }
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
@@ -712,9 +740,10 @@ class TSQL(Dialect):
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
- exp.GetPath: path_to_jsonpath("JSON_VALUE"),
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
+ exp.JSONExtract: _json_extract_sql,
+ exp.JSONExtractScalar: _json_extract_sql,
exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
@@ -831,15 +860,21 @@ class TSQL(Dialect):
exists = expression.args.pop("exists", None)
sql = super().create_sql(expression)
+ like_property = expression.find(exp.LikeProperty)
+ if like_property:
+ ctas_expression = like_property.this
+ else:
+ ctas_expression = expression.expression
+
table = expression.find(exp.Table)
# Convert CTAS statement to SELECT .. INTO ..
- if kind == "TABLE" and expression.expression:
- ctas_with = expression.expression.args.get("with")
+ if kind == "TABLE" and ctas_expression:
+ ctas_with = ctas_expression.args.get("with")
if ctas_with:
ctas_with = ctas_with.pop()
- subquery = expression.expression
+ subquery = ctas_expression
if isinstance(subquery, exp.Subqueryable):
subquery = subquery.subquery()
@@ -847,6 +882,9 @@ class TSQL(Dialect):
select_into.set("into", exp.Into(this=table))
select_into.set("with", ctas_with)
+ if like_property:
+ select_into.limit(0, copy=False)
+
sql = self.sql(select_into)
if exists:
@@ -937,9 +975,19 @@ class TSQL(Dialect):
return f"CONSTRAINT {this} {expressions}"
def length_sql(self, expression: exp.Length) -> str:
+ return self._uncast_text(expression, "LEN")
+
+ def right_sql(self, expression: exp.Right) -> str:
+ return self._uncast_text(expression, "RIGHT")
+
+ def left_sql(self, expression: exp.Left) -> str:
+ return self._uncast_text(expression, "LEFT")
+
+ def _uncast_text(self, expression: exp.Expression, name: str) -> str:
this = expression.this
if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT):
this_sql = self.sql(this, "this")
else:
this_sql = self.sql(this)
- return self.func("LEN", this_sql)
+ expression_sql = self.sql(expression, "expression")
+ return self.func(name, this_sql, expression_sql if expression_sql else None)
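
A hedged sketch of the LEN/LEFT/RIGHT round trip (output strings are illustrative):

import sqlglot

# _parse_as_text casts the input to TEXT on the way in; _uncast_text strips
# that cast again when generating T-SQL, so the cast only surfaces when
# transpiling to other dialects.
print(sqlglot.transpile("SELECT LEFT(x, 2)", read="tsql", write="tsql")[0])
print(sqlglot.transpile("SELECT LEFT(x, 2)", read="tsql", write="duckdb")[0])
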
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index 304981b..c8f9148 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -10,7 +10,6 @@ import logging
import time
import typing as t
-from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
@@ -23,7 +22,6 @@ logger = logging.getLogger("sqlglot")
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
- from sqlglot.executor.table import Tables
from sqlglot.expressions import Expression
from sqlglot.schema import Schema
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index d7952c1..e4c4040 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -44,9 +44,9 @@ class Context:
for other in self.tables.values():
if self._table.columns != other.columns:
- raise Exception(f"Columns are different.")
+ raise Exception("Columns are different.")
if len(self._table.rows) != len(other.rows):
- raise Exception(f"Rows are different.")
+ raise Exception("Rows are different.")
return self._table
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 6c01edc..218a8e0 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -6,7 +6,7 @@ from functools import wraps
from sqlglot import exp
from sqlglot.generator import Generator
-from sqlglot.helper import PYTHON_VERSION
+from sqlglot.helper import PYTHON_VERSION, is_int, seq_get
class reverse_key:
@@ -143,6 +143,22 @@ def arrayjoin(this, expression, null=None):
return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)
+@null_if_any("this", "expression")
+def jsonextract(this, expression):
+ for path_segment in expression:
+ if isinstance(this, dict):
+ this = this.get(path_segment)
+ elif isinstance(this, list) and is_int(path_segment):
+ this = seq_get(this, int(path_segment))
+ else:
+ raise NotImplementedError(f"Unable to extract value for {this} at {path_segment}.")
+
+ if this is None:
+ break
+
+ return this
+
+
ENV = {
"exp": exp,
# aggs
@@ -175,12 +191,12 @@ ENV = {
"DOT": null_if_any(lambda e, this: e[this]),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
- "GETPATH": null_if_any(lambda this, e: this.get(e)),
"GT": null_if_any(lambda this, e: this > e),
"GTE": null_if_any(lambda this, e: this >= e),
"IF": lambda predicate, true, false: true if predicate else false,
"INTDIV": null_if_any(lambda e, this: e // this),
"INTERVAL": interval,
+ "JSONEXTRACT": jsonextract,
"LEFT": null_if_any(lambda this, e: this[:e]),
"LIKE": null_if_any(
lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))
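
The new executor primitive can be exercised directly (a sketch; the segment list mirrors what the Python generator emits for a path like $.a[0].b):

from sqlglot.executor.env import ENV

# Dict keys are looked up as-is; list indices arrive as strings and are
# converted with is_int before seq_get.
print(ENV["JSONEXTRACT"]({"a": [{"b": 1}]}, ["a", "0", "b"]))  # 1
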
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 7ff9608..c0becbe 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -9,7 +9,7 @@ from sqlglot.errors import ExecuteError
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import RowReader, Table
-from sqlglot.helper import csv_reader, subclasses
+from sqlglot.helper import csv_reader, ensure_list, subclasses
class PythonExecutor:
@@ -368,7 +368,7 @@ def _rename(self, e):
if isinstance(e, exp.Func) and e.is_var_len_args:
*head, tail = values
- return self.func(e.key, *head, *tail)
+ return self.func(e.key, *head, *ensure_list(tail))
return self.func(e.key, *values)
except Exception as ex:
@@ -429,18 +429,24 @@ class Python(Dialect):
exp.Between: _rename,
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
- exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
+ exp.Column: lambda self,
+ e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Concat: lambda self, e: self.func(
"SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions
),
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Div: _div_sql,
- exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
- exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
+ exp.Extract: lambda self,
+ e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
+ exp.In: lambda self,
+ e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
exp.Is: lambda self, e: (
self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
),
+ exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]",
+ exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'",
+ exp.JSONPathSubscript: lambda self, e: f"'{e.this}'",
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
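
End to end, these changes make JSON extraction work in the Python executor (a sketch assuming dict-valued columns are passed through untouched):

from sqlglot.executor import execute

table = {"t": [{"x": {"a": 1}}]}
# JSON_EXTRACT is routed through the JSONEXTRACT primitive registered in ENV.
print(execute("SELECT JSON_EXTRACT(x, '$.a') AS a FROM t", tables=table).rows)
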
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index a95a73e..3234c99 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -29,6 +29,7 @@ from sqlglot.helper import (
camel_to_snake_case,
ensure_collection,
ensure_list,
+ is_int,
seq_get,
subclasses,
)
@@ -175,13 +176,7 @@ class Expression(metaclass=_Expression):
"""
Checks whether a Literal expression is an integer.
"""
- if self.is_number:
- try:
- int(self.name)
- return True
- except ValueError:
- pass
- return False
+ return self.is_number and is_int(self.name)
@property
def is_star(self) -> bool:
@@ -493,8 +488,8 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
- for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
- if not type(node) is self.__class__:
+ for node, _, _ in self.dfs(prune=lambda n, p, *_: p and type(n) is not self.__class__):
+ if type(node) is not self.__class__:
yield node.unnest() if unnest and not isinstance(node, Subquery) else node
def __str__(self) -> str:
@@ -553,10 +548,12 @@ class Expression(metaclass=_Expression):
return new_node
@t.overload
- def replace(self, expression: E) -> E: ...
+ def replace(self, expression: E) -> E:
+ ...
@t.overload
- def replace(self, expression: None) -> None: ...
+ def replace(self, expression: None) -> None:
+ ...
def replace(self, expression):
"""
@@ -610,7 +607,8 @@ class Expression(metaclass=_Expression):
>>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql()
'SELECT x, z FROM y'
"""
- assert isinstance(self, type_)
+ if not isinstance(self, type_):
+ raise AssertionError(f"{self} is not {type_}.")
return self
def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
@@ -1133,6 +1131,7 @@ class SetItem(Expression):
class Show(Expression):
arg_types = {
"this": True,
+ "history": False,
"terse": False,
"target": False,
"offset": False,
@@ -1676,7 +1675,6 @@ class Index(Expression):
"amp": False, # teradata
"include": False,
"partition_by": False, # teradata
- "where": False, # postgres partial indexes
}
@@ -2573,7 +2571,7 @@ class HistoricalData(Expression):
class Table(Expression):
arg_types = {
- "this": True,
+ "this": False,
"alias": False,
"db": False,
"catalog": False,
@@ -3664,6 +3662,7 @@ class DataType(Expression):
BINARY = auto()
BIT = auto()
BOOLEAN = auto()
+ BPCHAR = auto()
CHAR = auto()
DATE = auto()
DATE32 = auto()
@@ -3805,6 +3804,7 @@ class DataType(Expression):
dtype: DATA_TYPE,
dialect: DialectType = None,
udt: bool = False,
+ copy: bool = True,
**kwargs,
) -> DataType:
"""
@@ -3815,7 +3815,8 @@ class DataType(Expression):
dialect: the dialect to use for parsing `dtype`, in case it's a string.
udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
DataType, thus creating a user-defined type.
- kawrgs: additional arguments to pass in the constructor of DataType.
+ copy: whether or not to copy the data type.
+ kwargs: additional arguments to pass to the constructor of DataType.
Returns:
The constructed DataType object.
@@ -3837,7 +3838,7 @@ class DataType(Expression):
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
- return dtype
+ return maybe_copy(dtype, copy)
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
@@ -3855,7 +3856,7 @@ class DataType(Expression):
True, if and only if there is a type in `dtypes` which is equal to this DataType.
"""
for dtype in dtypes:
- other = DataType.build(dtype, udt=True)
+ other = DataType.build(dtype, copy=False, udt=True)
if (
other.expressions
@@ -4001,7 +4002,7 @@ class Dot(Binary):
def build(self, expressions: t.Sequence[Expression]) -> Dot:
"""Build a Dot object with a sequence of expressions."""
if len(expressions) < 2:
- raise ValueError(f"Dot requires >= 2 expressions.")
+ raise ValueError("Dot requires >= 2 expressions.")
return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions))
@@ -4128,10 +4129,6 @@ class Sub(Binary):
pass
-class ArrayOverlaps(Binary):
- pass
-
-
# Unary Expressions
# (NOT a)
class Unary(Condition):
@@ -4469,6 +4466,10 @@ class ArrayJoin(Func):
arg_types = {"this": True, "expression": True, "null": False}
+class ArrayOverlaps(Binary, Func):
+ pass
+
+
class ArraySize(Func):
arg_types = {"this": True, "expression": False}
@@ -4490,15 +4491,37 @@ class Avg(AggFunc):
class AnyValue(AggFunc):
- arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False}
+ arg_types = {"this": True, "having": False, "max": False}
+
+
+class Lag(AggFunc):
+ arg_types = {"this": True, "offset": False, "default": False}
+
+
+class Lead(AggFunc):
+ arg_types = {"this": True, "offset": False, "default": False}
+
+
+# Some dialects distinguish between FIRST and FIRST_VALUE: FIRST is usually an
+# aggregate function, while FIRST_VALUE is a window function
+class First(AggFunc):
+ pass
+
+
+class Last(AggFunc):
+ pass
+
+
+class FirstValue(AggFunc):
+ pass
-class First(Func):
- arg_types = {"this": True, "ignore_nulls": False}
+class LastValue(AggFunc):
+ pass
-class Last(Func):
- arg_types = {"this": True, "ignore_nulls": False}
+class NthValue(AggFunc):
+ arg_types = {"this": True, "offset": True}
class Case(Func):
@@ -4611,7 +4634,7 @@ class CurrentTime(Func):
class CurrentTimestamp(Func):
- arg_types = {"this": False}
+ arg_types = {"this": False, "transaction": False}
class CurrentUser(Func):
@@ -4712,6 +4735,7 @@ class TimestampSub(Func, TimeUnit):
class TimestampDiff(Func, TimeUnit):
+ _sql_names = ["TIMESTAMPDIFF", "TIMESTAMP_DIFF"]
arg_types = {"this": True, "expression": True, "unit": False}
@@ -4857,6 +4881,59 @@ class IsInf(Func):
_sql_names = ["IS_INF", "ISINF"]
+class JSONPath(Expression):
+ arg_types = {"expressions": True}
+
+ @property
+ def output_name(self) -> str:
+ last_segment = self.expressions[-1].this
+ return last_segment if isinstance(last_segment, str) else ""
+
+
+class JSONPathPart(Expression):
+ arg_types = {}
+
+
+class JSONPathFilter(JSONPathPart):
+ arg_types = {"this": True}
+
+
+class JSONPathKey(JSONPathPart):
+ arg_types = {"this": True}
+
+
+class JSONPathRecursive(JSONPathPart):
+ arg_types = {"this": False}
+
+
+class JSONPathRoot(JSONPathPart):
+ pass
+
+
+class JSONPathScript(JSONPathPart):
+ arg_types = {"this": True}
+
+
+class JSONPathSlice(JSONPathPart):
+ arg_types = {"start": False, "end": False, "step": False}
+
+
+class JSONPathSelector(JSONPathPart):
+ arg_types = {"this": True}
+
+
+class JSONPathSubscript(JSONPathPart):
+ arg_types = {"this": True}
+
+
+class JSONPathUnion(JSONPathPart):
+ arg_types = {"expressions": True}
+
+
+class JSONPathWildcard(JSONPathPart):
+ pass
+
+
class FormatJson(Expression):
pass
@@ -4940,18 +5017,30 @@ class JSONBContains(Binary):
class JSONExtract(Binary, Func):
+ arg_types = {"this": True, "expression": True, "expressions": False}
_sql_names = ["JSON_EXTRACT"]
+ is_var_len_args = True
+
+ @property
+ def output_name(self) -> str:
+ return self.expression.output_name if not self.expressions else ""
-class JSONExtractScalar(JSONExtract):
+class JSONExtractScalar(Binary, Func):
+ arg_types = {"this": True, "expression": True, "expressions": False}
_sql_names = ["JSON_EXTRACT_SCALAR"]
+ is_var_len_args = True
+
+ @property
+ def output_name(self) -> str:
+ return self.expression.output_name
-class JSONBExtract(JSONExtract):
+class JSONBExtract(Binary, Func):
_sql_names = ["JSONB_EXTRACT"]
-class JSONBExtractScalar(JSONExtract):
+class JSONBExtractScalar(Binary, Func):
_sql_names = ["JSONB_EXTRACT_SCALAR"]
@@ -4972,15 +5061,6 @@ class ParseJSON(Func):
is_var_len_args = True
-# https://docs.snowflake.com/en/sql-reference/functions/get_path
-class GetPath(Func):
- arg_types = {"this": True, "expression": True}
-
- @property
- def output_name(self) -> str:
- return self.expression.output_name
-
-
class Least(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -5476,6 +5556,8 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()}
+JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, (JSONPathPart,))
+
# Helpers
@t.overload
@@ -5487,7 +5569,8 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
-) -> E: ...
+) -> E:
+ ...
@t.overload
@@ -5499,7 +5582,8 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
-) -> E: ...
+) -> E:
+ ...
def maybe_parse(
@@ -5539,7 +5623,7 @@ def maybe_parse(
return sql_or_expression
if sql_or_expression is None:
- raise ParseError(f"SQL cannot be None")
+ raise ParseError("SQL cannot be None")
import sqlglot
@@ -5551,11 +5635,13 @@ def maybe_parse(
@t.overload
-def maybe_copy(instance: None, copy: bool = True) -> None: ...
+def maybe_copy(instance: None, copy: bool = True) -> None:
+ ...
@t.overload
-def maybe_copy(instance: E, copy: bool = True) -> E: ...
+def maybe_copy(instance: E, copy: bool = True) -> E:
+ ...
def maybe_copy(instance, copy=True):
@@ -6174,17 +6260,19 @@ def paren(expression: ExpOrStr, copy: bool = True) -> Paren:
return Paren(this=maybe_parse(expression, copy=copy))
-SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
+SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$")
@t.overload
-def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...
+def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
+ ...
@t.overload
def to_identifier(
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
-) -> Identifier: ...
+) -> Identifier:
+ ...
def to_identifier(name, quoted=None, copy=True):
@@ -6256,11 +6344,13 @@ def to_interval(interval: str | Literal) -> Interval:
@t.overload
-def to_table(sql_path: str | Table, **kwargs) -> Table: ...
+def to_table(sql_path: str | Table, **kwargs) -> Table:
+ ...
@t.overload
-def to_table(sql_path: None, **kwargs) -> None: ...
+def to_table(sql_path: None, **kwargs) -> None:
+ ...
def to_table(
@@ -6460,7 +6550,7 @@ def column(
return this
-def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
+def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast:
"""Cast an expression to a data type.
Example:
@@ -6470,12 +6560,13 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
Args:
expression: The expression to cast.
to: The datatype to cast to.
+ copy: Whether or not to copy the supplied expressions.
Returns:
The new Cast instance.
"""
- expression = maybe_parse(expression, **opts)
- data_type = DataType.build(to, **opts)
+ expression = maybe_parse(expression, copy=copy, **opts)
+ data_type = DataType.build(to, copy=copy, **opts)
expression = Cast(this=expression, to=data_type)
expression.type = data_type
return expression
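
The new copy parameter on DataType.build is observable directly (a minimal sketch):

from sqlglot import exp

dt = exp.DataType.build("TEXT")
# copy=False hands back the same instance, which is what is_type relies on
# to avoid redundant allocations on its hot path.
assert exp.DataType.build(dt, copy=False) is dt
assert exp.DataType.build(dt) is not dt
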
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 8e3ff9b..568dcb4 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -9,6 +9,7 @@ from functools import reduce
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
+from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSON_PATH_PART_TRANSFORMS
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
@@ -21,7 +22,18 @@ logger = logging.getLogger("sqlglot")
ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)")
-class Generator:
+class _Generator(type):
+ def __new__(cls, clsname, bases, attrs):
+ klass = super().__new__(cls, clsname, bases, attrs)
+
+ # Remove transforms that correspond to unsupported JSONPathPart expressions
+ for part in ALL_JSON_PATH_PARTS - klass.SUPPORTED_JSON_PATH_PARTS:
+ klass.TRANSFORMS.pop(part, None)
+
+ return klass
+
+
+class Generator(metaclass=_Generator):
"""
Generator converts a given syntax tree to the corresponding SQL string.
@@ -58,19 +70,23 @@ class Generator:
Default: True
"""
- TRANSFORMS = {
- exp.DateAdd: lambda self, e: self.func(
- "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
- ),
- exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
+ TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
+ **JSON_PATH_PART_TRANSFORMS,
+ exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
+ exp.CaseSpecificColumnConstraint: lambda self,
+ e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
- exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
+ exp.CharacterSetProperty: lambda self,
+ e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
- exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
+ exp.ClusteredColumnConstraint: lambda self,
+ e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
- exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
- exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
+ exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
+ exp.DateAdd: lambda self, e: self.func(
+ "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
+ ),
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
@@ -85,29 +101,33 @@ class Generator:
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
+ exp.NonClusteredColumnConstraint: lambda self,
+ e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
- exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION",
- exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
+ exp.OnCommitProperty: lambda self,
+ e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
- exp.RemoteWithConnectionModelProperty: lambda self, e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
+ exp.RemoteWithConnectionModelProperty: lambda self,
+ e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
- exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
+ exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlReadWriteProperty: lambda self, e: e.name,
- exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
+ exp.SqlSecurityProperty: lambda self,
+ e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda self, e: e.name,
- exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
+ exp.TemporaryProperty: lambda self, e: "TEMPORARY",
+ exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
- exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
- exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
- exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
+ exp.TransientProperty: lambda self, e: "TRANSIENT",
+ exp.UppercaseColumnConstraint: lambda self, e: "UPPERCASE",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.VolatileProperty: lambda self, e: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
@@ -117,6 +137,10 @@ class Generator:
# True: Full Support, None: No support, False: No support in window specifications
NULL_ORDERING_SUPPORTED: t.Optional[bool] = True
+ # Whether IGNORE NULLS is rendered inside the aggregate function or after it,
+ # e.g. FIRST(x IGNORE NULLS) OVER (...) vs FIRST(x) IGNORE NULLS OVER (...)
+ IGNORE_NULLS_IN_FUNC = False
+
# Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False
@@ -266,6 +290,24 @@ class Generator:
# Whether or not UNLOGGED tables can be created
SUPPORTS_UNLOGGED_TABLES = False
+ # Whether or not the CREATE TABLE LIKE statement is supported
+ SUPPORTS_CREATE_TABLE_LIKE = True
+
+ # Whether or not the LikeProperty needs to be specified inside of the schema clause
+ LIKE_PROPERTY_INSIDE_SCHEMA = False
+
+ # Whether or not the JSON extraction operators expect a value of type JSON
+ JSON_TYPE_REQUIRED_FOR_EXTRACTION = False
+
+ # Whether or not bracketed keys like ["foo"] are supported in JSON paths
+ JSON_PATH_BRACKETED_KEY_SUPPORTED = True
+
+ # Whether or not to escape keys using single quotes in JSON paths
+ JSON_PATH_SINGLE_QUOTE_ESCAPE = False
+
+ # The JSONPathPart expressions supported by this dialect
+ SUPPORTED_JSON_PATH_PARTS = ALL_JSON_PATH_PARTS.copy()
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -641,8 +683,6 @@ class Generator:
if callable(transform):
sql = transform(self, expression)
- elif transform:
- sql = transform
elif isinstance(expression, exp.Expression):
exp_handler_name = f"{expression.key}_sql"
@@ -802,7 +842,7 @@ class Generator:
desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
- return f"PRIMARY KEY"
+ return "PRIMARY KEY"
def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str:
this = self.sql(expression, "this")
@@ -1218,9 +1258,21 @@ class Generator:
return f"{property_name}={self.sql(expression, 'this')}"
def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
- options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
- options = f" {options}" if options else ""
- return f"LIKE {self.sql(expression, 'this')}{options}"
+ if self.SUPPORTS_CREATE_TABLE_LIKE:
+ options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
+ options = f" {options}" if options else ""
+
+ like = f"LIKE {self.sql(expression, 'this')}{options}"
+ if self.LIKE_PROPERTY_INSIDE_SCHEMA and not isinstance(expression.parent, exp.Schema):
+ like = f"({like})"
+
+ return like
+
+ if expression.expressions:
+ self.unsupported("Transpilation of LIKE property options is unsupported")
+
+ select = exp.select("*").from_(expression.this).limit(0)
+ return f"AS {self.sql(select)}"
def fallbackproperty_sql(self, expression: exp.FallbackProperty) -> str:
no = "NO " if expression.args.get("no") else ""
@@ -2367,6 +2419,31 @@ class Generator:
def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}"
+ def jsonpath_sql(self, expression: exp.JSONPath) -> str:
+ path = self.expressions(expression, sep="", flat=True).lstrip(".")
+ return f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
+
+ def json_path_part(self, expression: int | str | exp.JSONPathPart) -> str:
+ if isinstance(expression, exp.JSONPathPart):
+ transform = self.TRANSFORMS.get(expression.__class__)
+ if not callable(transform):
+ self.unsupported(f"Unsupported JSONPathPart type {expression.__class__.__name__}")
+ return ""
+
+ return transform(self, expression)
+
+ if isinstance(expression, int):
+ return str(expression)
+
+ if self.JSON_PATH_SINGLE_QUOTE_ESCAPE:
+ escaped = expression.replace("'", "\\'")
+ escaped = f"\\'{expression}\\'"
+ else:
+ escaped = expression.replace('"', '\\"')
+ escaped = f'"{escaped}"'
+
+ return escaped
+
def formatjson_sql(self, expression: exp.FormatJson) -> str:
return f"{self.sql(expression, 'this')} FORMAT JSON"
@@ -2620,6 +2697,9 @@ class Generator:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
+ def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
+ return self.func("CURRENT_TIMESTAMP", expression.this)
+
def collate_sql(self, expression: exp.Collate) -> str:
if self.COLLATE_IS_FUNC:
return self.function_fallback_sql(expression)
@@ -2761,10 +2841,20 @@ class Generator:
return f"DISTINCT{this}{on}"
def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
- return f"{self.sql(expression, 'this')} IGNORE NULLS"
+ return self._embed_ignore_nulls(expression, "IGNORE NULLS")
def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
- return f"{self.sql(expression, 'this')} RESPECT NULLS"
+ return self._embed_ignore_nulls(expression, "RESPECT NULLS")
+
+ def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
+ if self.IGNORE_NULLS_IN_FUNC:
+ this = expression.find(exp.AggFunc)
+ if this:
+ sql = self.sql(this)
+ sql = sql[:-1] + f" {text})"
+ return sql
+
+ return f"{self.sql(expression, 'this')} {text}"
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
@@ -2935,7 +3025,7 @@ class Generator:
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
- return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
+ return self.indent("\n" + ",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)
def text_width(self, args: t.Iterable) -> int:
@@ -3279,6 +3369,22 @@ class Generator:
return self.func("LAST_DAY", expression.this)
+ def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
+ this = expression.this
+ if isinstance(this, exp.JSONPathWildcard):
+ this = self.json_path_part(this)
+ return f".{this}" if this else ""
+
+ if exp.SAFE_IDENTIFIER_RE.match(this):
+ return f".{this}"
+
+ this = self.json_path_part(this)
+ return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"
+
+ def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
+ this = self.json_path_part(expression.this)
+ return f"[{this}]" if this else ""
+
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
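
The effect of the _Generator metaclass can be sketched against a dialect with a restricted SUPPORTED_JSON_PATH_PARTS set (illustrative):

from sqlglot import exp
from sqlglot.dialects.snowflake import Snowflake

# Snowflake only declares key/root/subscript support, so the metaclass
# dropped e.g. the wildcard transform at class-creation time.
print(exp.JSONPathWildcard in Snowflake.Generator.TRANSFORMS)  # False
print(exp.JSONPathKey in Snowflake.Generator.TRANSFORMS)       # True
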
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index de737be..9799fe2 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -53,11 +53,13 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
@t.overload
-def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
+def ensure_list(value: t.Collection[T]) -> t.List[T]:
+ ...
@t.overload
-def ensure_list(value: T) -> t.List[T]: ...
+def ensure_list(value: T) -> t.List[T]:
+ ...
def ensure_list(value):
@@ -79,11 +81,13 @@ def ensure_list(value):
@t.overload
-def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
+def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
+ ...
@t.overload
-def ensure_collection(value: T) -> t.Collection[T]: ...
+def ensure_collection(value: T) -> t.Collection[T]:
+ ...
def ensure_collection(value):
@@ -232,7 +236,7 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
for node, deps in tuple(dag.items()):
for dep in deps:
- if not dep in dag:
+ if dep not in dag:
dag[dep] = set()
while dag:
@@ -316,6 +320,14 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
return new
+def is_int(text: str) -> bool:
+ try:
+ int(text)
+ return True
+ except ValueError:
+ return False
+
+
def name_sequence(prefix: str) -> t.Callable[[], str]:
"""Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
sequence = count()
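
A minimal sketch of the new helper:

from sqlglot.helper import is_int

# Unlike str.isdigit, int() accepts a leading sign, which is why the old
# dialect-local _check_int helpers could be replaced wholesale.
assert is_int("-42") and is_int("+7") and not is_int("4.2")
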
diff --git a/sqlglot/jsonpath.py b/sqlglot/jsonpath.py
index c410d11..129a4e6 100644
--- a/sqlglot/jsonpath.py
+++ b/sqlglot/jsonpath.py
@@ -2,8 +2,8 @@ from __future__ import annotations
import typing as t
+import sqlglot.expressions as exp
from sqlglot.errors import ParseError
-from sqlglot.expressions import SAFE_IDENTIFIER_RE
from sqlglot.tokens import Token, Tokenizer, TokenType
if t.TYPE_CHECKING:
@@ -36,20 +36,8 @@ class JSONPathTokenizer(Tokenizer):
STRING_ESCAPES = ["\\"]
-JSONPathNode = t.Dict[str, t.Any]
-
-
-def _node(kind: str, value: t.Any = None, **kwargs: t.Any) -> JSONPathNode:
- node = {"kind": kind, **kwargs}
-
- if value is not None:
- node["value"] = value
-
- return node
-
-
-def parse(path: str) -> t.List[JSONPathNode]:
- """Takes in a JSONPath string and converts into a list of nodes."""
+def parse(path: str) -> exp.JSONPath:
+ """Takes in a JSON path string and parses it into a JSONPath expression."""
tokens = JSONPathTokenizer().tokenize(path)
size = len(tokens)
@@ -89,7 +77,7 @@ def parse(path: str) -> t.List[JSONPathNode]:
if token:
return token.text
if _match(TokenType.STAR):
- return _node("wildcard")
+ return exp.JSONPathWildcard()
if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN):
script = _prev().text == "("
start = i
@@ -100,9 +88,9 @@ def parse(path: str) -> t.List[JSONPathNode]:
if _curr() in (TokenType.R_BRACKET, None):
break
_advance()
- return _node(
- "script" if script else "filter", path[tokens[start].start : tokens[i].end]
- )
+
+ expr_type = exp.JSONPathScript if script else exp.JSONPathFilter
+ return expr_type(this=path[tokens[start].start : tokens[i].end])
number = "-" if _match(TokenType.DASH) else ""
@@ -112,6 +100,7 @@ def parse(path: str) -> t.List[JSONPathNode]:
if number:
return int(number)
+
return False
def _parse_slice() -> t.Any:
@@ -121,9 +110,10 @@ def parse(path: str) -> t.List[JSONPathNode]:
if end is None and step is None:
return start
- return _node("slice", start=start, end=end, step=step)
- def _parse_bracket() -> JSONPathNode:
+ return exp.JSONPathSlice(start=start, end=end, step=step)
+
+ def _parse_bracket() -> exp.JSONPathPart:
literal = _parse_slice()
if isinstance(literal, str) or literal is not False:
@@ -136,13 +126,15 @@ def parse(path: str) -> t.List[JSONPathNode]:
if len(indexes) == 1:
if isinstance(literal, str):
- node = _node("key", indexes[0])
- elif isinstance(literal, dict) and literal["kind"] in ("script", "filter"):
- node = _node("selector", indexes[0])
+ node: exp.JSONPathPart = exp.JSONPathKey(this=indexes[0])
+ elif isinstance(literal, (exp.JSONPathScript, exp.JSONPathFilter)):
+ node = exp.JSONPathSelector(this=indexes[0])
else:
- node = _node("subscript", indexes[0])
+ node = exp.JSONPathSubscript(this=indexes[0])
else:
- node = _node("union", indexes)
+ node = exp.JSONPathUnion(expressions=indexes)
else:
raise ParseError(_error("Cannot have empty segment"))
@@ -150,66 +142,56 @@ def parse(path: str) -> t.List[JSONPathNode]:
return node
- nodes = []
+ # We canonicalize the JSON path AST to always start with a "root" element,
+ # so that paths like "field" are generated as "$.field"
+ _match(TokenType.DOLLAR)
+ expressions: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
while _curr():
- if _match(TokenType.DOLLAR):
- nodes.append(_node("root"))
- elif _match(TokenType.DOT):
+ if _match(TokenType.DOT) or _match(TokenType.COLON):
recursive = _prev().text == ".."
- value = _match(TokenType.VAR) or _match(TokenType.STAR)
- nodes.append(
- _node("recursive" if recursive else "child", value=value.text if value else None)
- )
+
+ if _match(TokenType.VAR) or _match(TokenType.IDENTIFIER):
+ value: t.Optional[str | exp.JSONPathWildcard] = _prev().text
+ elif _match(TokenType.STAR):
+ value = exp.JSONPathWildcard()
+ else:
+ value = None
+
+ if recursive:
+ expressions.append(exp.JSONPathRecursive(this=value))
+ elif value:
+ expressions.append(exp.JSONPathKey(this=value))
+ else:
+ raise ParseError(_error("Expected key name or * after DOT"))
elif _match(TokenType.L_BRACKET):
- nodes.append(_parse_bracket())
- elif _match(TokenType.VAR):
- nodes.append(_node("key", _prev().text))
+ expressions.append(_parse_bracket())
+ elif _match(TokenType.VAR) or _match(TokenType.IDENTIFIER):
+ expressions.append(exp.JSONPathKey(this=_prev().text))
elif _match(TokenType.STAR):
- nodes.append(_node("wildcard"))
- elif _match(TokenType.PARAMETER):
- nodes.append(_node("current"))
+ expressions.append(exp.JSONPathWildcard())
else:
raise ParseError(_error(f"Unexpected {tokens[i].token_type}"))
- return nodes
+ return exp.JSONPath(expressions=expressions)
-MAPPING = {
- "child": lambda n: f".{n['value']}" if n.get("value") is not None else "",
- "filter": lambda n: f"?{n['value']}",
- "key": lambda n: (
- f".{n['value']}" if SAFE_IDENTIFIER_RE.match(n["value"]) else f'[{generate([n["value"]])}]'
- ),
- "recursive": lambda n: f"..{n['value']}" if n.get("value") is not None else "..",
- "root": lambda _: "$",
- "script": lambda n: f"({n['value']}",
- "slice": lambda n: ":".join(
- "" if p is False else generate([p])
- for p in [n["start"], n["end"], n["step"]]
+JSON_PATH_PART_TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
+ exp.JSONPathFilter: lambda _, e: f"?{e.this}",
+ exp.JSONPathKey: lambda self, e: self._jsonpathkey_sql(e),
+ exp.JSONPathRecursive: lambda _, e: f"..{e.this or ''}",
+ exp.JSONPathRoot: lambda *_: "$",
+ exp.JSONPathScript: lambda _, e: f"({e.this}",
+ exp.JSONPathSelector: lambda self, e: f"[{self.json_path_part(e.this)}]",
+ exp.JSONPathSlice: lambda self, e: ":".join(
+ "" if p is False else self.json_path_part(p)
+ for p in [e.args.get("start"), e.args.get("end"), e.args.get("step")]
if p is not None
),
- "selector": lambda n: f"[{generate([n['value']])}]",
- "subscript": lambda n: f"[{generate([n['value']])}]",
- "union": lambda n: f"[{','.join(generate([p]) for p in n['value'])}]",
- "wildcard": lambda _: "*",
+ exp.JSONPathSubscript: lambda self, e: self._jsonpathsubscript_sql(e),
+ exp.JSONPathUnion: lambda self,
+ e: f"[{','.join(self.json_path_part(p) for p in e.expressions)}]",
+ exp.JSONPathWildcard: lambda *_: "*",
}
-
-def generate(
- nodes: t.List[JSONPathNode],
- mapping: t.Optional[t.Dict[str, t.Callable[[JSONPathNode], str]]] = None,
-) -> str:
- mapping = MAPPING if mapping is None else mapping
- path = []
-
- for node in nodes:
- if isinstance(node, dict):
- path.append(mapping[node["kind"]](node))
- elif isinstance(node, str):
- escaped = node.replace('"', '\\"')
- path.append(f'"{escaped}"')
- else:
- path.append(str(node))
-
- return "".join(path)
+ALL_JSON_PATH_PARTS = set(JSON_PATH_PART_TRANSFORMS)
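
Usage-wise, parse now returns an AST instead of a node list (a sketch):

from sqlglot.jsonpath import parse

# The implied root makes "field" and "$.field" equivalent.
path = parse("field")
print([type(p).__name__ for p in path.expressions])  # ['JSONPathRoot', 'JSONPathKey']
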
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py
index ee48006..34ea6cb 100644
--- a/sqlglot/optimizer/__init__.py
+++ b/sqlglot/optimizer/__init__.py
@@ -1,3 +1,5 @@
+# ruff: noqa: F401
+
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.scope import (
Scope,
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index f2a0990..d22a998 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -10,11 +10,13 @@ if t.TYPE_CHECKING:
@t.overload
-def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+ ...
@t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
+ ...
def normalize_identifiers(expression, dialect=None):
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index e3aaebc..53490bf 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -8,10 +8,10 @@ from sqlglot.schema import ensure_schema
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
+
# Selection to use if selection list is empty
-DEFAULT_SELECTION = lambda is_agg: alias(
- exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_"
-)
+def default_selection(is_agg: bool) -> exp.Alias:
+ return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
@@ -129,7 +129,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(DEFAULT_SELECTION(is_agg))
+ new_selections.append(default_selection(is_agg))
scope.expression.select(*new_selections, append=False, copy=False)
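
Swapping the named lambda for a def is the idiomatic spelling (ruff flags
assigned lambdas as E731) and gives the fallback a real name in tracebacks
plus room for annotations; behavior is unchanged. A before/after sketch with
simplified return values standing in for the real expressions:

    # Before: an anonymous function bound to a constant-style name.
    DEFAULT = lambda is_agg: "MAX(1)" if is_agg else "1"  # noqa: E731

    # After: same behavior, but introspectable and annotatable.
    def default_selection(is_agg: bool) -> str:
        return "MAX(1)" if is_agg else "1"

    assert DEFAULT(True) == default_selection(True) == "MAX(1)"
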
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index d5b9119..90357dd 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -104,7 +104,6 @@ def simplify(
if root:
expression.replace(node)
-
return node
expression = while_changing(expression, _simplify)
@@ -174,16 +173,20 @@ def simplify_not(expression):
if isinstance(this, exp.Paren):
condition = this.unnest()
if isinstance(condition, exp.And):
- return exp.or_(
- exp.not_(condition.left, copy=False),
- exp.not_(condition.right, copy=False),
- copy=False,
+ return exp.paren(
+ exp.or_(
+ exp.not_(condition.left, copy=False),
+ exp.not_(condition.right, copy=False),
+ copy=False,
+ )
)
if isinstance(condition, exp.Or):
- return exp.and_(
- exp.not_(condition.left, copy=False),
- exp.not_(condition.right, copy=False),
- copy=False,
+ return exp.paren(
+ exp.and_(
+ exp.not_(condition.left, copy=False),
+ exp.not_(condition.right, copy=False),
+ copy=False,
+ )
)
if is_null(condition):
return exp.null()
@@ -490,7 +493,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, COMPARISONS):
l, r = expression.left, expression.right
- if not l.__class__ in INVERSE_OPS:
+ if l.__class__ not in INVERSE_OPS:
return expression
if r.is_number:
@@ -714,8 +717,7 @@ def simplify_concat(expression):
"""Reduces all groups that contain string literals by concatenating them."""
if not isinstance(expression, CONCATS) or (
# We can't reduce a CONCAT_WS call if we don't statically know the separator
- isinstance(expression, exp.ConcatWs)
- and not expression.expressions[0].is_string
+ isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
):
return expression
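
The De Morgan expansion in simplify_not now keeps the parentheses it consumed:
NOT (a AND b) becomes (NOT a OR NOT b) rather than a bare NOT a OR NOT b, so
the rewrite cannot rebind precedence when the negated group sits under a
tighter-binding operator. A hedged check (the exact simplified output may vary
by version):

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # Without the wrapping paren, the expansion would read as
    # (c AND NOT a) OR NOT b, which is a different predicate.
    print(simplify(parse_one("c AND NOT (a AND b)")).sql())
    # expected along the lines of: c AND (NOT a OR NOT b)
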
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index c091605..a89e4fa 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -60,6 +60,19 @@ def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)
+def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
+ def _parser(args: t.List, dialect: Dialect) -> E:
+ expression = expr_type(
+ this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
+ )
+ if len(args) > 2 and expr_type is exp.JSONExtract:
+ expression.set("expressions", args[2:])
+
+ return expression
+
+ return _parser
+
+
class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
@@ -102,6 +115,9 @@ class Parser(metaclass=_Parser):
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
+ "JSON_EXTRACT": parse_extract_json_with_path(exp.JSONExtract),
+ "JSON_EXTRACT_SCALAR": parse_extract_json_with_path(exp.JSONExtractScalar),
+ "JSON_EXTRACT_PATH_TEXT": parse_extract_json_with_path(exp.JSONExtractScalar),
"LIKE": parse_like,
"LOG": parse_logarithm,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
@@ -175,6 +191,7 @@ class Parser(metaclass=_Parser):
TokenType.NCHAR,
TokenType.VARCHAR,
TokenType.NVARCHAR,
+ TokenType.BPCHAR,
TokenType.TEXT,
TokenType.MEDIUMTEXT,
TokenType.LONGTEXT,
@@ -295,6 +312,7 @@ class Parser(metaclass=_Parser):
TokenType.ASC,
TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
+ TokenType.BPCHAR,
TokenType.CACHE,
TokenType.CASE,
TokenType.COLLATE,
@@ -531,12 +549,12 @@ class Parser(metaclass=_Parser):
TokenType.ARROW: lambda self, this, path: self.expression(
exp.JSONExtract,
this=this,
- expression=path,
+ expression=self.dialect.to_json_path(path),
),
TokenType.DARROW: lambda self, this, path: self.expression(
exp.JSONExtractScalar,
this=this,
- expression=path,
+ expression=self.dialect.to_json_path(path),
),
TokenType.HASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtract,
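
The -> and ->> column operators get the same normalization: their right-hand
side now passes through dialect.to_json_path, so arrow syntax and the
function-style spellings converge on one canonical tree. A quick probe,
assuming the default dialect keeps the base arrow operators:

    from sqlglot import exp, parse_one

    e = parse_one("SELECT x -> '$.a[0]' FROM t")
    extract = e.find(exp.JSONExtract)
    print(isinstance(extract.expression, exp.JSONPath))  # expected: True
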
@@ -1334,7 +1352,9 @@ class Parser(metaclass=_Parser):
exp.Drop,
comments=start.comments,
exists=exists or self._parse_exists(),
- this=self._parse_table(schema=True),
+ this=self._parse_table(
+ schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
+ ),
kind=kind,
temporary=temporary,
materialized=materialized,
@@ -1422,7 +1442,9 @@ class Parser(metaclass=_Parser):
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
- table_parts = self._parse_table_parts(schema=True)
+ table_parts = self._parse_table_parts(
+ schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA
+ )
# exp.Properties.Location.POST_NAME
self._match(TokenType.COMMA)
@@ -2499,11 +2521,11 @@ class Parser(metaclass=_Parser):
elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"):
text = "ALL ROWS PER MATCH"
if self._match_text_seq("SHOW", "EMPTY", "MATCHES"):
- text += f" SHOW EMPTY MATCHES"
+ text += " SHOW EMPTY MATCHES"
elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"):
- text += f" OMIT EMPTY MATCHES"
+ text += " OMIT EMPTY MATCHES"
elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"):
- text += f" WITH UNMATCHED ROWS"
+ text += " WITH UNMATCHED ROWS"
rows = exp.var(text)
else:
rows = None
@@ -2511,9 +2533,9 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("AFTER", "MATCH", "SKIP"):
text = "AFTER MATCH SKIP"
if self._match_text_seq("PAST", "LAST", "ROW"):
- text += f" PAST LAST ROW"
+ text += " PAST LAST ROW"
elif self._match_text_seq("TO", "NEXT", "ROW"):
- text += f" TO NEXT ROW"
+ text += " TO NEXT ROW"
elif self._match_text_seq("TO", "FIRST"):
text += f" TO FIRST {self._advance_any().text}" # type: ignore
elif self._match_text_seq("TO", "LAST"):
@@ -2772,7 +2794,7 @@ class Parser(metaclass=_Parser):
or self._parse_placeholder()
)
- def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+ def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table:
catalog = None
db = None
table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema)
@@ -2788,8 +2810,15 @@ class Parser(metaclass=_Parser):
db = table
table = self._parse_table_part(schema=schema) or ""
- if not table:
+ if is_db_reference:
+ catalog = db
+ db = table
+ table = None
+
+ if not table and not is_db_reference:
self.raise_error(f"Expected table name but got {self._curr}")
+ if not db and is_db_reference:
+ self.raise_error(f"Expected database name but got {self._curr}")
return self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
@@ -2801,6 +2830,7 @@ class Parser(metaclass=_Parser):
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
+ is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
if lateral:
@@ -2823,7 +2853,11 @@ class Parser(metaclass=_Parser):
bracket = parse_bracket and self._parse_bracket(None)
bracket = self.expression(exp.Table, this=bracket) if bracket else None
this = t.cast(
- exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema))
+ exp.Expression,
+ bracket
+ or self._parse_bracket(
+ self._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
+ ),
)
if schema:
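
The is_db_reference flag threads from the DROP/CREATE SCHEMA call sites above
down into _parse_table_parts: when the named object is a schema, a dotted name
a.b must bind as catalog.db with no table part, and each mode now raises a
targeted error when its required part is missing. A hedged sketch of the
intended binding:

    from sqlglot import parse_one

    table = parse_one("DROP SCHEMA a.b").this
    # Expected: catalog 'a' and db 'b' with no table name, instead of the
    # old misreading of b as a table inside database a.
    print(table.catalog, table.db, repr(table.name))  # expected: a b ''
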
@@ -3650,7 +3684,6 @@ class Parser(metaclass=_Parser):
identifier = allow_identifiers and self._parse_id_var(
any_token=False, tokens=(TokenType.VAR,)
)
-
if identifier:
tokens = self.dialect.tokenize(identifier.name)
@@ -3818,12 +3851,14 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
def _parse_column(self) -> t.Optional[exp.Expression]:
+ this = self._parse_column_reference()
+ return self._parse_column_ops(this) if this else self._parse_bracket(this)
+
+ def _parse_column_reference(self) -> t.Optional[exp.Expression]:
this = self._parse_field()
if isinstance(this, exp.Identifier):
this = self.expression(exp.Column, this=this)
- elif not this:
- return self._parse_bracket(this)
- return self._parse_column_ops(this)
+ return this
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
@@ -3837,13 +3872,7 @@ class Parser(metaclass=_Parser):
if not field:
self.raise_error("Expected type")
elif op and self._curr:
- self._advance()
- value = self._prev.text
- field = (
- exp.Literal.number(value)
- if self._prev.token_type == TokenType.NUMBER
- else exp.Literal.string(value)
- )
+ field = self._parse_column_reference()
else:
field = self._parse_field(anonymous_func=True, any_token=True)
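
_parse_column is split so the identifier-to-Column step stands alone as
_parse_column_reference, and the column-op loop above now reuses it for the
field that follows an operator instead of hand-building a literal from the
previous token. The shape of the refactor in miniature, with hypothetical
names and none of sqlglot's real signatures:

    import typing as t

    def parse_reference(token: str) -> t.Optional[dict]:
        # The extracted step: turn a bare identifier into a column node.
        return {"column": token} if token.isidentifier() else None

    def parse_column(token: str) -> t.Optional[dict]:
        ref = parse_reference(token)
        # The real parser would continue with _parse_column_ops(ref) here.
        return ref

    print(parse_column("a"), parse_column("$bad"))  # {'column': 'a'} None
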
@@ -4375,7 +4404,10 @@ class Parser(metaclass=_Parser):
options[kind] = action
return self.expression(
- exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore
+ exp.ForeignKey,
+ expressions=expressions,
+ reference=reference,
+ **options, # type: ignore
)
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
@@ -4692,10 +4724,12 @@ class Parser(metaclass=_Parser):
return None
@t.overload
- def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
+ ...
@t.overload
- def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
+ def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
+ ...
def _parse_json_object(self, agg=False):
star = self._parse_star()
@@ -4937,6 +4971,13 @@ class Parser(metaclass=_Parser):
# (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
# and Snowflake chose to do the same for familiarity
# https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
+ if isinstance(this, exp.AggFunc):
+ ignore_respect = this.find(exp.IgnoreNulls, exp.RespectNulls)
+
+ if ignore_respect and ignore_respect is not this:
+ ignore_respect.replace(ignore_respect.this)
+ this = self.expression(ignore_respect.__class__, this=this)
+
this = self._parse_respect_or_ignore_nulls(this)
# bigquery select from window x AS (partition by ...)
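
For aggregates that accept the null-handling clause inside the parentheses
(Oracle-style FIRST_VALUE(x IGNORE NULLS), which Snowflake mirrors), the
parser now hoists a nested IgnoreNulls/RespectNulls marker out of the argument
and wraps the function itself, so both spellings normalize to one shape. A
hedged probe of the expected tree:

    from sqlglot import exp, parse_one

    q = parse_one("SELECT LAST_VALUE(x IGNORE NULLS) OVER (ORDER BY y) FROM t", read="oracle")
    marker = q.find(exp.IgnoreNulls)
    # Expected: the marker sits above the aggregate, not inside its argument.
    print(marker is not None and isinstance(marker.this, exp.LastValue))
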
@@ -5732,12 +5773,14 @@ class Parser(metaclass=_Parser):
return True
@t.overload
- def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
+ def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
+ ...
@t.overload
def _replace_columns_with_dots(
self, this: t.Optional[exp.Expression]
- ) -> t.Optional[exp.Expression]: ...
+ ) -> t.Optional[exp.Expression]:
+ ...
def _replace_columns_with_dots(self, this):
if isinstance(this, exp.Dot):
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 8a363d2..87a4924 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -125,6 +125,7 @@ class TokenType(AutoName):
NCHAR = auto()
VARCHAR = auto()
NVARCHAR = auto()
+ BPCHAR = auto()
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
@@ -801,6 +802,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR": TokenType.NVARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
+ "BPCHAR": TokenType.BPCHAR,
"STR": TokenType.TEXT,
"STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT,
@@ -1141,7 +1143,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
- while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
+ while not self._end and self.WHITE_SPACE.get(self._peek) is not TokenType.BREAK:
self._advance(alnum=True)
self._comments.append(self._text[comment_start_size:])
@@ -1259,7 +1261,7 @@ class Tokenizer(metaclass=_Tokenizer):
if base:
try:
int(text, base)
- except:
+ except Exception:
raise TokenError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
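
BPCHAR is Postgres's internal name for the blank-padded CHAR type; tokenizing
it separately lets it flow through the TYPE_TOKENS additions in the parser
hunks above instead of failing to tokenize. The same file also narrows a bare
except to except Exception, so control-flow exceptions such as
KeyboardInterrupt are no longer swallowed by the numeric-literal check. A
hedged round trip, assuming a matching DataType entry was added alongside the
token:

    import sqlglot

    print(sqlglot.transpile("CAST(x AS BPCHAR)", read="postgres", write="postgres")[0])
    # expected roughly: CAST(x AS BPCHAR)
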
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 0da65b5..f13569f 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -485,8 +485,8 @@ def preprocess(
expression_type = type(expression)
expression = transforms[0](expression)
- for t in transforms[1:]:
- expression = t(expression)
+ for transform in transforms[1:]:
+ expression = transform(expression)
_sql_handler = getattr(self, expression.key + "_sql", None)
if _sql_handler: