author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-12-19 11:01:55 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-12-19 11:01:55 +0000
commit     f1c2dbe3b17a0d5edffbb65b85b642d0bb2756c5 (patch)
tree       5dce0fe2a11381761496eb973c20750f44db56d5 /sqlglot
parent     Releasing debian version 20.1.0-1. (diff)
Merging upstream version 20.3.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dialects/__init__.py               |  20
-rw-r--r--  sqlglot/dialects/bigquery.py               |   7
-rw-r--r--  sqlglot/dialects/clickhouse.py             |   5
-rw-r--r--  sqlglot/dialects/dialect.py                | 101
-rw-r--r--  sqlglot/dialects/drill.py                  |   1
-rw-r--r--  sqlglot/dialects/duckdb.py                 |  37
-rw-r--r--  sqlglot/dialects/hive.py                   |   1
-rw-r--r--  sqlglot/dialects/mysql.py                  |   2
-rw-r--r--  sqlglot/dialects/postgres.py               |   1
-rw-r--r--  sqlglot/dialects/presto.py                 |  41
-rw-r--r--  sqlglot/dialects/snowflake.py              |  80
-rw-r--r--  sqlglot/dialects/teradata.py               |  24
-rw-r--r--  sqlglot/dialects/tsql.py                   |   7
-rw-r--r--  sqlglot/executor/python.py                 |   3
-rw-r--r--  sqlglot/expressions.py                     |  46
-rw-r--r--  sqlglot/generator.py                       |  67
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py  |  33
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py      |  17
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py   |  13
-rw-r--r--  sqlglot/optimizer/scope.py                 |  39
-rw-r--r--  sqlglot/optimizer/simplify.py              | 153
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py     |   2
-rw-r--r--  sqlglot/parser.py                          | 129
-rw-r--r--  sqlglot/planner.py                         |  11
-rw-r--r--  sqlglot/tokens.py                          |  91
25 files changed, 642 insertions(+), 289 deletions(-)
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 8212669..04990ac 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -12,7 +12,7 @@ classes as needed.
### Implementing a custom Dialect
-Consider the following example:
+Creating a new SQL dialect may seem complicated at first, but it is actually quite simple in SQLGlot:
```python
from sqlglot import exp
@@ -23,9 +23,10 @@ from sqlglot.tokens import Tokenizer, TokenType
class Custom(Dialect):
class Tokenizer(Tokenizer):
- QUOTES = ["'", '"']
- IDENTIFIERS = ["`"]
+ QUOTES = ["'", '"'] # Strings can be delimited by either single or double quotes
+ IDENTIFIERS = ["`"] # Identifiers can be delimited by backticks
+ # Associates certain meaningful words with tokens that capture their intent
KEYWORDS = {
**Tokenizer.KEYWORDS,
"INT64": TokenType.BIGINT,
@@ -33,8 +34,12 @@ class Custom(Dialect):
}
class Generator(Generator):
- TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"}
+ # Specifies how AST nodes, i.e. subclasses of exp.Expression, should be converted into SQL
+ TRANSFORMS = {
+ exp.Array: lambda self, e: f"[{self.expressions(e)}]",
+ }
+ # Specifies how AST nodes representing data types should be converted into SQL
TYPE_MAPPING = {
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
@@ -48,10 +53,9 @@ class Custom(Dialect):
}
```
-This is a typical example of adding a new dialect implementation in SQLGlot: we specify its identifier and string
-delimiters, as well as what tokens it uses for its types and how they're associated with SQLGlot types. Since
-the `Expression` classes are common for each dialect supported in SQLGlot, we may also need to override the generation
-logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping.
+The above example demonstrates how certain parts of the base `Dialect` class can be overridden to match a different
+specification. Even though it is a fairly realistic starting point, we strongly encourage the reader to study existing
+dialect implementations in order to understand how their various components can be modified, depending on the use-case.
----
"""
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 2a9dde9..1b06cbf 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -215,6 +215,7 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
class BigQuery(Dialect):
+ WEEK_OFFSET = -1
UNNEST_COLUMN_ONLY = True
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
@@ -437,11 +438,7 @@ class BigQuery(Dialect):
elif isinstance(this, exp.Literal):
table_name = this.name
- if (
- self._curr
- and self._prev.end == self._curr.start - 1
- and self._parse_var(any_token=True)
- ):
+ if self._is_connected() and self._parse_var(any_token=True):
table_name += self._prev.text
this = exp.Identifier(this=table_name, quoted=True)
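A hedged sketch of what the new `WEEK_OFFSET = -1` enables downstream (output is an assumption): the simplifier's week flooring becomes Sunday-based for BigQuery.

```python
import sqlglot
from sqlglot.optimizer.simplify import simplify

# 2023-12-20 is a Wednesday. Under the default Monday-based weeks the floor
# would be 2023-12-18; with BigQuery's WEEK_OFFSET = -1 it should be Sunday.
expr = sqlglot.parse_one("DATE_TRUNC(CAST('2023-12-20' AS DATE), WEEK)", read="bigquery")
print(simplify(expr, dialect="bigquery").sql("bigquery"))
# Assumed output: CAST('2023-12-17' AS DATE)
```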
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index da182aa..7a3f897 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -83,6 +83,11 @@ class ClickHouse(Dialect):
}
class Parser(parser.Parser):
+ # Tested in ClickHouse's playground, it seems that the following two queries do the same thing
+ # * select x from t1 union all select x from t2 limit 1;
+ # * select x from t1 union all (select x from t2 limit 1);
+ MODIFIERS_ATTACHED_TO_UNION = False
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index c7cea64..b7eef45 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -21,11 +21,14 @@ DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
class Dialects(str, Enum):
+ """Dialects supported by SQLGLot."""
+
DIALECT = ""
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
DATABRICKS = "databricks"
+ DORIS = "doris"
DRILL = "drill"
DUCKDB = "duckdb"
HIVE = "hive"
@@ -43,16 +46,22 @@ class Dialects(str, Enum):
TERADATA = "teradata"
TRINO = "trino"
TSQL = "tsql"
- Doris = "doris"
class NormalizationStrategy(str, AutoName):
"""Specifies the strategy according to which identifiers should be normalized."""
- LOWERCASE = auto() # Unquoted identifiers are lowercased
- UPPERCASE = auto() # Unquoted identifiers are uppercased
- CASE_SENSITIVE = auto() # Always case-sensitive, regardless of quotes
- CASE_INSENSITIVE = auto() # Always case-insensitive, regardless of quotes
+ LOWERCASE = auto()
+ """Unquoted identifiers are lowercased."""
+
+ UPPERCASE = auto()
+ """Unquoted identifiers are uppercased."""
+
+ CASE_SENSITIVE = auto()
+ """Always case-sensitive, regardless of quotes."""
+
+ CASE_INSENSITIVE = auto()
+ """Always case-insensitive, regardless of quotes."""
class _Dialect(type):
@@ -117,6 +126,7 @@ class _Dialect(type):
klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
+ klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()
@@ -131,74 +141,84 @@ class _Dialect(type):
class Dialect(metaclass=_Dialect):
- # Determines the base index offset for arrays
INDEX_OFFSET = 0
+ """Determines the base index offset for arrays."""
+
+ WEEK_OFFSET = 0
+ """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
- # If true unnest table aliases are considered only as column aliases
UNNEST_COLUMN_ONLY = False
+ """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
- # Determines whether or not the table alias comes after tablesample
ALIAS_POST_TABLESAMPLE = False
+ """Determines whether or not the table alias comes after tablesample."""
- # Specifies the strategy according to which identifiers should be normalized.
NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
+ """Specifies the strategy according to which identifiers should be normalized."""
- # Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False
+ """Determines whether or not an unquoted identifier can start with a digit."""
- # Determines whether or not the DPIPE token ('||') is a string concatenation operator
DPIPE_IS_STRING_CONCAT = True
+ """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
- # Determines whether or not CONCAT's arguments must be strings
STRICT_STRING_CONCAT = False
+ """Determines whether or not `CONCAT`'s arguments must be strings."""
- # Determines whether or not user-defined data types are supported
SUPPORTS_USER_DEFINED_TYPES = True
+ """Determines whether or not user-defined data types are supported."""
- # Determines whether or not SEMI/ANTI JOINs are supported
SUPPORTS_SEMI_ANTI_JOIN = True
+ """Determines whether or not `SEMI` or `ANTI` joins are supported."""
- # Determines how function names are going to be normalized
NORMALIZE_FUNCTIONS: bool | str = "upper"
+ """Determines how function names are going to be normalized."""
- # Determines whether the base comes first in the LOG function
LOG_BASE_FIRST = True
+ """Determines whether the base comes first in the `LOG` function."""
- # Indicates the default null ordering method to use if not explicitly set
- # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
NULL_ORDERING = "nulls_are_small"
+ """
+ Indicates the default `NULL` ordering method to use if not explicitly set.
+ Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
+ """
- # Whether the behavior of a / b depends on the types of a and b.
- # False means a / b is always float division.
- # True means a / b is integer division if both a and b are integers.
TYPED_DIVISION = False
+ """
+ Whether the behavior of `a / b` depends on the types of `a` and `b`.
+ False means `a / b` is always float division.
+ True means `a / b` is integer division if both `a` and `b` are integers.
+ """
- # False means 1 / 0 throws an error.
- # True means 1 / 0 returns null.
SAFE_DIVISION = False
+ """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
- # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
CONCAT_COALESCE = False
+ """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
- # Custom time mappings in which the key represents dialect time format
- # and the value represents a python time format
TIME_MAPPING: t.Dict[str, str] = {}
+ """Associates this dialect's time formats with their equivalent Python `strftime` format."""
# https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
- # special syntax cast(x as date format 'yyyy') defaults to time_mapping
FORMAT_MAPPING: t.Dict[str, str] = {}
+ """
+ Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
+ If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
+ """
- # Mapping of an unescaped escape sequence to the corresponding character
ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+ """Mapping of an unescaped escape sequence to the corresponding character."""
- # Columns that are auto-generated by the engine corresponding to this dialect
- # Such columns may be excluded from SELECT * queries, for example
PSEUDOCOLUMNS: t.Set[str] = set()
+ """
+ Columns that are auto-generated by the engine corresponding to this dialect.
+ For example, such columns may be excluded from `SELECT *` queries.
+ """
# --- Autofilled ---
@@ -221,13 +241,15 @@ class Dialect(metaclass=_Dialect):
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
- # Delimiters for bit, hex and byte literals
+ # Delimiters for bit, hex, byte and unicode literals
BIT_START: t.Optional[str] = None
BIT_END: t.Optional[str] = None
HEX_START: t.Optional[str] = None
HEX_END: t.Optional[str] = None
BYTE_START: t.Optional[str] = None
BYTE_END: t.Optional[str] = None
+ UNICODE_START: t.Optional[str] = None
+ UNICODE_END: t.Optional[str] = None
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
@@ -275,6 +297,7 @@ class Dialect(metaclass=_Dialect):
def format_time(
cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
+ """Converts a time format in this dialect to its equivalent Python `strftime` format."""
if isinstance(expression, str):
return exp.Literal.string(
# the time formats are quoted
@@ -306,9 +329,9 @@ class Dialect(metaclass=_Dialect):
"""
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
- For example, an identifier like FoO would be resolved as foo in Postgres, because it
+ For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
- it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
+ it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are
@@ -356,8 +379,8 @@ class Dialect(metaclass=_Dialect):
Args:
text: The text to check.
identify:
- "always" or `True`: Always returns true.
- "safe": True if the identifier is case-insensitive.
+ `"always"` or `True`: Always returns `True`.
+ `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
@@ -371,6 +394,14 @@ class Dialect(metaclass=_Dialect):
return False
def quote_identifier(self, expression: E, identify: bool = True) -> E:
+ """
+ Adds quotes to a given identifier.
+
+ Args:
+ expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
+ identify: If set to `False`, the quotes will only be added if the identifier is deemed
+ "unsafe", with respect to its characters and this dialect's normalization strategy.
+ """
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
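A small sketch of the normalization strategies documented above, using the public `Dialect.get_or_raise` and `normalize_identifier` shown in this file (outputs follow the docstring's own Postgres/Snowflake example):

```python
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")
print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO
```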
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 70c96f8..c9b31a0 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -81,7 +81,6 @@ class Drill(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
- ENCODE = "utf-8"
class Parser(parser.Parser):
STRICT_CAST = False
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index b94e3a6..41afad8 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -84,11 +84,35 @@ def _parse_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
+def _parse_make_timestamp(args: t.List) -> exp.Expression:
+ if len(args) == 1:
+ return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS)
+
+ return exp.TimestampFromParts(
+ year=seq_get(args, 0),
+ month=seq_get(args, 1),
+ day=seq_get(args, 2),
+ hour=seq_get(args, 3),
+ min=seq_get(args, 4),
+ sec=seq_get(args, 5),
+ )
+
+
def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
- args = [
- f"'{e.name or e.this.name}': {self.sql(e.expressions[0]) if isinstance(e, exp.Bracket) else self.sql(e, 'expression')}"
- for e in expression.expressions
- ]
+ args: t.List[str] = []
+ for expr in expression.expressions:
+ if isinstance(expr, exp.Alias):
+ key = expr.alias
+ value = expr.this
+ else:
+ key = expr.name or expr.this.name
+ if isinstance(expr, exp.Bracket):
+ value = expr.expressions[0]
+ else:
+ value = expr.expression
+
+ args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}")
+
return f"{{{', '.join(args)}}}"
@@ -189,9 +213,7 @@ class DuckDB(Dialect):
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
- "MAKE_TIMESTAMP": lambda args: exp.UnixToTime(
- this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
- ),
+ "MAKE_TIMESTAMP": _parse_make_timestamp,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
),
@@ -339,6 +361,7 @@ class DuckDB(Dialect):
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
+ exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
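A sketch of the new `_parse_make_timestamp` dispatch (the round-trip output is an assumption): the one-argument form is still treated as an epoch-microseconds conversion, while the six-argument form maps to the new `exp.TimestampFromParts` node and generates back to `MAKE_TIMESTAMP`.

```python
import sqlglot
from sqlglot import exp

one_arg = sqlglot.parse_one("MAKE_TIMESTAMP(1702983715000000)", read="duckdb")
print(isinstance(one_arg, exp.UnixToTime))  # True

print(sqlglot.transpile(
    "MAKE_TIMESTAMP(2023, 12, 19, 11, 1, 55)", read="duckdb", write="duckdb"
)[0])
# Assumed: MAKE_TIMESTAMP(2023, 12, 19, 11, 1, 55)
```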
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 0723e37..65c85bb 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -240,7 +240,6 @@ class Hive(Dialect):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
- ENCODE = "utf-8"
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index cfc6e83..5fe3d82 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -650,7 +650,7 @@ class MySQL(Dialect):
exp.Min: min_or_least,
exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
- exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+ exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}",
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
[
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index fefddee..bf65edf 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -277,6 +277,7 @@ class Postgres(Dialect):
"CONSTRAINT TRIGGER": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
+ "EXEC": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"JSONB": TokenType.JSONB,
"MONEY": TokenType.MONEY,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 10a6074..360ab65 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -186,6 +186,27 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
return ""
+def _to_int(expression: exp.Expression) -> exp.Expression:
+ if not expression.type:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ annotate_types(expression)
+ if expression.type and expression.type.this not in exp.DataType.INTEGER_TYPES:
+ return exp.cast(expression, to=exp.DataType.Type.BIGINT)
+ return expression
+
+
+def _parse_to_char(args: t.List) -> exp.TimeToStr:
+ fmt = seq_get(args, 1)
+ if isinstance(fmt, exp.Literal):
+ # We uppercase this to match Teradata's format mapping keys
+ fmt.set("this", fmt.this.upper())
+
+ # We use "teradata" on purpose here, because the time formats are different in Presto.
+ # See https://prestodb.io/docs/current/functions/teradata.html?highlight=to_char#to_char
+ return format_time_lambda(exp.TimeToStr, "teradata")(args)
+
+
class Presto(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_last"
@@ -201,6 +222,12 @@ class Presto(Dialect):
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
class Tokenizer(tokens.Tokenizer):
+ UNICODE_STRINGS = [
+ (prefix + q, q)
+ for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES)
+ for prefix in ("U&", "u&")
+ ]
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
@@ -253,8 +280,9 @@ class Presto(Dialect):
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
- "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
+ "TO_CHAR": _parse_to_char,
"TO_HEX": exp.Hex.from_arg_list,
+ "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
@@ -315,7 +343,12 @@ class Presto(Dialect):
exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
- "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ "DATE_ADD",
+ exp.Literal.string(e.text("unit") or "day"),
+ _to_int(
+ e.expression,
+ ),
+ e.this,
),
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
@@ -325,7 +358,7 @@ class Presto(Dialect):
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "day"),
- e.expression * -1,
+ _to_int(e.expression * -1),
e.this,
),
exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
@@ -354,6 +387,7 @@ class Presto(Dialect):
exp.Right: right_to_substring_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
+ exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.Select: transforms.preprocess(
[
transforms.eliminate_qualify,
@@ -377,6 +411,7 @@ class Presto(Dialect):
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
+ exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
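A hedged sketch of the new `TO_CHAR` handling (the output is an assumption): the format literal is uppercased, parsed with the Teradata mapping as the comment explains, and generated back as Presto's `DATE_FORMAT`.

```python
import sqlglot

print(sqlglot.transpile("TO_CHAR(ts, 'yyyy-mm-dd')", read="presto", write="presto")[0])
# Assumed output: DATE_FORMAT(ts, '%Y-%m-%d')
```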
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index cdbc071..f09a990 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -293,7 +293,6 @@ class Snowflake(Dialect):
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEDIFF": _parse_datediff,
"TIMESTAMPDIFF": _parse_datediff,
- "TO_ARRAY": exp.Array.from_arg_list,
"TO_TIMESTAMP": _parse_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _zeroifnull_to_if,
@@ -369,36 +368,58 @@ class Snowflake(Dialect):
return lateral
+ def _parse_at_before(self, table: exp.Table) -> exp.Table:
+ # https://docs.snowflake.com/en/sql-reference/constructs/at-before
+ index = self._index
+ if self._match_texts(("AT", "BEFORE")):
+ this = self._prev.text.upper()
+ kind = (
+ self._match(TokenType.L_PAREN)
+ and self._match_texts(self.HISTORICAL_DATA_KIND)
+ and self._prev.text.upper()
+ )
+ expression = self._match(TokenType.FARROW) and self._parse_bitwise()
+
+ if expression:
+ self._match_r_paren()
+ when = self.expression(
+ exp.HistoricalData, this=this, kind=kind, expression=expression
+ )
+ table.set("when", when)
+ else:
+ self._retreat(index)
+
+ return table
+
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
- table: t.Optional[exp.Expression] = None
- if self._match_text_seq("@"):
- table_name = "@"
- while self._curr:
- self._advance()
- table_name += self._prev.text
- if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
- break
- while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
- table_name += self._prev.text
-
- table = exp.var(table_name)
- elif self._match(TokenType.STRING, advance=False):
+ if self._match(TokenType.STRING, advance=False):
table = self._parse_string()
+ elif self._match_text_seq("@", advance=False):
+ table = self._parse_location_path()
+ else:
+ table = None
if table:
file_format = None
pattern = None
- if self._match_text_seq("(", "FILE_FORMAT", "=>"):
- file_format = self._parse_string() or super()._parse_table_parts()
- if self._match_text_seq(",", "PATTERN", "=>"):
+ self._match(TokenType.L_PAREN)
+ while self._curr and not self._match(TokenType.R_PAREN):
+ if self._match_text_seq("FILE_FORMAT", "=>"):
+ file_format = self._parse_string() or super()._parse_table_parts()
+ elif self._match_text_seq("PATTERN", "=>"):
pattern = self._parse_string()
- self._match_r_paren()
+ else:
+ break
+
+ self._match(TokenType.COMMA)
- return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+ table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+ else:
+ table = super()._parse_table_parts(schema=schema)
- return super()._parse_table_parts(schema=schema)
+ return self._parse_at_before(table)
def _parse_id_var(
self,
@@ -438,17 +459,17 @@ class Snowflake(Dialect):
def _parse_location(self) -> exp.LocationProperty:
self._match(TokenType.EQ)
+ return self.expression(exp.LocationProperty, this=self._parse_location_path())
- parts = [self._parse_var(any_token=True)]
+ def _parse_location_path(self) -> exp.Var:
+ parts = [self._advance_any(ignore_reserved=True)]
- while self._match(TokenType.SLASH):
- if self._curr and self._prev.end + 1 == self._curr.start:
- parts.append(self._parse_var(any_token=True))
- else:
- parts.append(exp.Var(this=""))
- return self.expression(
- exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts))
- )
+ # We avoid consuming a comma token because external tables like @foo and @bar
+ # can be joined in a query with a comma separator.
+ while self._is_connected() and not self._match(TokenType.COMMA, advance=False):
+ parts.append(self._advance_any(ignore_reserved=True))
+
+ return exp.var("".join(part.text for part in parts if part))
class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\", "'"]
@@ -562,6 +583,7 @@ class Snowflake(Dialect):
"TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
),
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
+ exp.ToArray: rename_func("TO_ARRAY"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
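A sketch of the new `_parse_at_before` hook (round-trip behavior is an assumption): Snowflake time-travel clauses are stored on the table as an `exp.HistoricalData` node under the `when` arg and re-emitted by the generator.

```python
import sqlglot

# The statement id here is a made-up placeholder value.
sql = "SELECT * FROM t AT (STATEMENT => '8e5d0ca9-005e-44e6')"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
# Assumed to round-trip unchanged.
```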
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 141d9c0..0ccc567 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -12,22 +12,30 @@ class Teradata(Dialect):
TYPED_DIVISION = True
TIME_MAPPING = {
- "Y": "%Y",
- "YYYY": "%Y",
"YY": "%y",
- "MMMM": "%B",
+ "Y4": "%Y",
+ "YYYY": "%Y",
+ "M4": "%B",
+ "M3": "%b",
+ "M": "%-M",
+ "MI": "%M",
+ "MM": "%m",
"MMM": "%b",
- "DD": "%d",
+ "MMMM": "%B",
"D": "%-d",
- "HH": "%H",
+ "DD": "%d",
+ "D3": "%j",
+ "DDD": "%j",
"H": "%-H",
- "MM": "%M",
- "M": "%-M",
- "SS": "%S",
+ "HH": "%H",
+ "HH24": "%H",
"S": "%-S",
+ "SS": "%S",
"SSSSSS": "%f",
"E": "%a",
"EE": "%a",
+ "E3": "%a",
+ "E4": "%A",
"EEE": "%a",
"EEEE": "%A",
}
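With the expanded mapping above, `Dialect.format_time` can translate a quoted Teradata format string into its Python `strftime` equivalent; a minimal sketch (the printed value follows directly from the mapping):

```python
from sqlglot.dialects.teradata import Teradata

print(Teradata.format_time("'YYYY-MM-DD HH24:MI:SS'").name)
# Assumed: %Y-%m-%d %H:%M:%S
```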
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index c3d4f0a..165a703 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -701,6 +701,13 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def set_operation(self, expression: exp.Union, op: str) -> str:
+ limit = expression.args.get("limit")
+ if limit:
+ return self.sql(expression.limit(limit.pop(), copy=False))
+
+ return super().set_operation(expression, op)
+
def setitem_sql(self, expression: exp.SetItem) -> str:
this = expression.this
if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter):
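A hedged sketch of the `set_operation` override above (the output shape, including the `_l_0` alias, is an assumption): since T-SQL cannot attach `LIMIT` to a set operation, the limit is popped off the union and re-applied via `exp.Union.limit`, which wraps the union in a subquery selected with `TOP`.

```python
import sqlglot

print(sqlglot.transpile("SELECT x FROM a UNION SELECT x FROM b LIMIT 10", write="tsql")[0])
# Assumed shape: SELECT TOP 10 * FROM (SELECT x FROM a UNION SELECT x FROM b) AS _l_0
```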
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index e1e597d..3277e65 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -343,6 +343,9 @@ class PythonExecutor:
else:
sink.rows = left.rows + right.rows
+ if not math.isinf(step.limit):
+ sink.rows = sink.rows[0 : step.limit]
+
return self.context({step.name: sink})
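A sketch of the executor change (the row count is an assumption, and presumes the planner now sets `step.limit` for set operations): `LIMIT` is applied to union results inside the Python executor.

```python
from sqlglot.executor import execute

tables = {"t1": [{"x": 1}, {"x": 2}], "t2": [{"x": 3}]}
result = execute("SELECT x FROM t1 UNION ALL SELECT x FROM t2 LIMIT 2", tables=tables)
print(len(result.rows))  # Assumed: 2
```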
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 99722be..8246769 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1105,14 +1105,7 @@ class Create(DDL):
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy
class Clone(Expression):
- arg_types = {
- "this": True,
- "when": False,
- "kind": False,
- "shallow": False,
- "expression": False,
- "copy": False,
- }
+ arg_types = {"this": True, "shallow": False, "copy": False}
class Describe(Expression):
@@ -1213,6 +1206,10 @@ class RawString(Condition):
pass
+class UnicodeString(Condition):
+ arg_types = {"this": True, "escape": False}
+
+
class Column(Condition):
arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False}
@@ -1967,7 +1964,12 @@ class Offset(Expression):
class Order(Expression):
- arg_types = {"this": False, "expressions": True}
+ arg_types = {"this": False, "expressions": True, "interpolate": False}
+
+
+# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier
+class WithFill(Expression):
+ arg_types = {"from": False, "to": False, "step": False}
# hive specific sorts
@@ -1985,7 +1987,7 @@ class Sort(Order):
class Ordered(Expression):
- arg_types = {"this": True, "desc": False, "nulls_first": True}
+ arg_types = {"this": True, "desc": False, "nulls_first": True, "with_fill": False}
class Property(Expression):
@@ -2522,6 +2524,11 @@ class IndexTableHint(Expression):
arg_types = {"this": True, "expressions": False, "target": False}
+# https://docs.snowflake.com/en/sql-reference/constructs/at-before
+class HistoricalData(Expression):
+ arg_types = {"this": True, "kind": True, "expression": True}
+
+
class Table(Expression):
arg_types = {
"this": True,
@@ -2538,6 +2545,7 @@ class Table(Expression):
"pattern": False,
"index": False,
"ordinality": False,
+ "when": False,
}
@property
@@ -4310,6 +4318,11 @@ class Array(Func):
is_var_len_args = True
+# https://docs.snowflake.com/en/sql-reference/functions/to_array
+class ToArray(Func):
+ pass
+
+
# https://docs.snowflake.com/en/sql-reference/functions/to_char
# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html
class ToChar(Func):
@@ -5233,6 +5246,19 @@ class UnixToTimeStr(Func):
pass
+class TimestampFromParts(Func):
+ """Constructs a timestamp given its constituent parts."""
+
+ arg_types = {
+ "year": True,
+ "month": True,
+ "day": True,
+ "hour": True,
+ "min": True,
+ "sec": True,
+ }
+
+
class Upper(Func):
_sql_names = ["UPPER", "UCASE"]
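The new node can also be constructed directly; a minimal sketch (the DuckDB rendering follows from the `TRANSFORMS` entry added earlier in this diff):

```python
from sqlglot import exp

ts = exp.TimestampFromParts(
    year=exp.Literal.number(2023),
    month=exp.Literal.number(12),
    day=exp.Literal.number(19),
    hour=exp.Literal.number(11),
    min=exp.Literal.number(1),
    sec=exp.Literal.number(55),
)
print(ts.sql("duckdb"))  # Assumed: MAKE_TIMESTAMP(2023, 12, 19, 11, 1, 55)
```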
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index f3f9060..c571e8f 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -862,15 +862,7 @@ class Generator:
this = self.sql(expression, "this")
shallow = "SHALLOW " if expression.args.get("shallow") else ""
keyword = "COPY" if expression.args.get("copy") and self.SUPPORTS_TABLE_COPY else "CLONE"
- this = f"{shallow}{keyword} {this}"
- when = self.sql(expression, "when")
-
- if when:
- kind = self.sql(expression, "kind")
- expr = self.sql(expression, "expression")
- return f"{this} {when} ({kind} => {expr})"
-
- return this
+ return f"{shallow}{keyword} {this}"
def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
@@ -923,6 +915,14 @@ class Generator:
return f"{self.dialect.BYTE_START}{this}{self.dialect.BYTE_END}"
return this
+ def unicodestring_sql(self, expression: exp.UnicodeString) -> str:
+ this = self.sql(expression, "this")
+ if self.dialect.UNICODE_START:
+ escape = self.sql(expression, "escape")
+ escape = f" UESCAPE {escape}" if escape else ""
+ return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}"
+ return this
+
def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"))
return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}"
@@ -1400,6 +1400,12 @@ class Generator:
target = f" FOR {target}" if target else ""
return f"{this}{target} ({self.expressions(expression, flat=True)})"
+ def historicaldata_sql(self, expression: exp.HistoricalData) -> str:
+ this = self.sql(expression, "this")
+ kind = self.sql(expression, "kind")
+ expr = self.sql(expression, "expression")
+ return f"{this} ({kind} => {expr})"
+
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = ".".join(
self.sql(part)
@@ -1436,6 +1442,10 @@ class Generator:
ordinality = f" WITH ORDINALITY{alias}"
alias = ""
+ when = self.sql(expression, "when")
+ if when:
+ table = f"{table} {when}"
+
return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
@@ -1784,7 +1794,24 @@ class Generator:
def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else this
- return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
+ order = self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
+ interpolated_values = [
+ f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}"
+ for named_expression in expression.args.get("interpolate") or []
+ ]
+ interpolate = (
+ f" INTERPOLATE ({', '.join(interpolated_values)})" if interpolated_values else ""
+ )
+ return f"{order}{interpolate}"
+
+ def withfill_sql(self, expression: exp.WithFill) -> str:
+ from_sql = self.sql(expression, "from")
+ from_sql = f" FROM {from_sql}" if from_sql else ""
+ to_sql = self.sql(expression, "to")
+ to_sql = f" TO {to_sql}" if to_sql else ""
+ step_sql = self.sql(expression, "step")
+ step_sql = f" STEP {step_sql}" if step_sql else ""
+ return f"WITH FILL{from_sql}{to_sql}{step_sql}"
def cluster_sql(self, expression: exp.Cluster) -> str:
return self.op_expressions("CLUSTER BY", expression)
@@ -1826,7 +1853,10 @@ class Generator:
this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
nulls_sort_change = ""
- return f"{this}{sort_order}{nulls_sort_change}"
+ with_fill = self.sql(expression, "with_fill")
+ with_fill = f" {with_fill}" if with_fill else ""
+
+ return f"{this}{sort_order}{nulls_sort_change}{with_fill}"
def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
partition = self.partition_by_sql(expression)
@@ -3048,11 +3078,24 @@ class Generator:
def operator_sql(self, expression: exp.Operator) -> str:
return self.binary(expression, f"OPERATOR({self.sql(expression, 'operator')})")
+ def toarray_sql(self, expression: exp.ToArray) -> str:
+ arg = expression.this
+ if not arg.type:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ arg = annotate_types(arg)
+
+ if arg.is_type(exp.DataType.Type.ARRAY):
+ return self.sql(arg)
+
+ cond_for_null = arg.is_(exp.null())
+ return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg])))
+
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
- expression = simplify(expression)
+ expression = simplify(expression, dialect=self.dialect)
return expression
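A sketch of the new `order_sql`/`withfill_sql` output (round-tripping is an assumption): ClickHouse's `ORDER BY ... WITH FILL` and the optional `INTERPOLATE` clause are now generated.

```python
import sqlglot

sql = "SELECT d FROM t ORDER BY d WITH FILL FROM 1 TO 10 STEP 1 INTERPOLATE (x AS x + 1)"
print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
# Assumed to round-trip unchanged.
```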
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 1ab7768..1230cea 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -95,9 +95,6 @@ def eliminate_subqueries(expression):
def _eliminate(scope, existing_ctes, taken):
- if scope.is_union:
- return _eliminate_union(scope, existing_ctes, taken)
-
if scope.is_derived_table:
return _eliminate_derived_table(scope, existing_ctes, taken)
@@ -105,36 +102,6 @@ def _eliminate(scope, existing_ctes, taken):
return _eliminate_cte(scope, existing_ctes, taken)
-def _eliminate_union(scope, existing_ctes, taken):
- duplicate_cte_alias = existing_ctes.get(scope.expression)
-
- alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
-
- taken[alias] = scope
-
- # Try to maintain the selections
- expressions = scope.expression.selects
- selects = [
- exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
- for e in expressions
- if e.alias_or_name
- ]
- # If not all selections have an alias, just select *
- if len(selects) != len(expressions):
- selects = ["*"]
-
- scope.expression.replace(
- exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
- )
-
- if not duplicate_cte_alias:
- existing_ctes[scope.expression] = alias
- return exp.CTE(
- this=scope.expression,
- alias=exp.TableAlias(this=exp.to_identifier(alias)),
- )
-
-
def _eliminate_derived_table(scope, existing_ctes, taken):
# This makes sure that we don't:
# - drop the "pivot" arg from a pivoted subquery
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index a74bea7..ea148cc 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -174,6 +174,22 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
for col in inner_projections[selection].find_all(exp.Column)
)
+ def _is_recursive():
+ # Recursive CTEs look like this:
+ # WITH RECURSIVE cte AS (
+ # SELECT * FROM x <-- inner scope
+ # UNION ALL
+ # SELECT * FROM cte <-- outer scope
+ # )
+ cte = inner_scope.expression.parent
+ node = outer_scope.expression.parent
+
+ while node:
+ if node is cte:
+ return True
+ node = node.parent
+ return False
+
return (
isinstance(outer_scope.expression, exp.Select)
and not outer_scope.expression.is_star
@@ -197,6 +213,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
)
and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
+ and not _is_recursive()
)
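A hedged sketch of the guard above (the unchanged output is an assumption): `_is_recursive` keeps the optimizer from merging a recursive CTE's self-reference into itself.

```python
import sqlglot
from sqlglot.optimizer.merge_subqueries import merge_subqueries

sql = (
    "WITH RECURSIVE cte AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM cte) "
    "SELECT n FROM cte"
)
# The recursive branch's reference to "cte" should be left intact.
print(merge_subqueries(sqlglot.parse_one(sql)).sql())
```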
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index f7348b5..10ff13a 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -4,7 +4,7 @@ from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
-def pushdown_predicates(expression):
+def pushdown_predicates(expression, dialect=None):
"""
Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
@@ -36,7 +36,7 @@ def pushdown_predicates(expression):
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
- pushdown(where.this, selected_sources, scope_ref_count)
+ pushdown(where.this, selected_sources, scope_ref_count, dialect)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@@ -44,17 +44,20 @@ def pushdown_predicates(expression):
name = join.alias_or_name
if name in scope.selected_sources:
pushdown(
- join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
+ join.args.get("on"),
+ {name: scope.selected_sources[name]},
+ scope_ref_count,
+ dialect,
)
return expression
-def pushdown(condition, sources, scope_ref_count):
+def pushdown(condition, sources, scope_ref_count, dialect):
if not condition:
return
- condition = condition.replace(simplify(condition))
+ condition = condition.replace(simplify(condition, dialect=dialect))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
predicates = list(
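A sketch of the new `dialect` parameter being threaded into `simplify()` (the output mirrors the function's existing docstring example and is otherwise an assumption):

```python
import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
print(pushdown_predicates(sqlglot.parse_one(sql), dialect="duckdb").sql())
# Assumed: SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE
```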
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index b7e527e..d34857d 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -37,6 +37,7 @@ class Scope:
For example:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
+ cte_sources (dict[str, Scope]): Sources from CTEs
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list of it's alias of this scope, this is that list of columns.
For example:
@@ -61,11 +62,14 @@ class Scope:
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
+ cte_sources=None,
):
self.expression = expression
self.sources = sources or {}
- self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
+ self.lateral_sources = lateral_sources or {}
+ self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
+ self.sources.update(self.cte_sources)
self.outer_column_list = outer_column_list or []
self.parent = parent
self.scope_type = scope_type
@@ -92,13 +96,17 @@ class Scope:
self._pivots = None
self._references = None
- def branch(self, expression, scope_type, chain_sources=None, **kwargs):
+ def branch(
+ self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
+ ):
"""Branch from the current scope to a new, inner scope"""
return Scope(
expression=expression.unnest(),
- sources={**self.cte_sources, **(chain_sources or {})},
+ sources=sources.copy() if sources else None,
parent=self,
scope_type=scope_type,
+ cte_sources={**self.cte_sources, **(cte_sources or {})},
+ lateral_sources=lateral_sources.copy() if lateral_sources else None,
**kwargs,
)
@@ -306,20 +314,6 @@ class Scope:
return self._references
@property
- def cte_sources(self):
- """
- Sources that are CTEs.
-
- Returns:
- dict[str, Scope]: Mapping of source alias to Scope
- """
- return {
- alias: scope
- for alias, scope in self.sources.items()
- if isinstance(scope, Scope) and scope.is_cte
- }
-
- @property
def external_columns(self):
"""
Columns that appear to reference sources in outer scopes.
@@ -515,7 +509,10 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
- yield from _traverse_subqueries(scope)
+ if scope.is_root:
+ yield from _traverse_select(scope)
+ else:
+ yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
@@ -572,7 +569,7 @@ def _traverse_ctes(scope):
for child_scope in _traverse_scope(
scope.branch(
cte.this,
- chain_sources=sources,
+ cte_sources=sources,
outer_column_list=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
@@ -584,12 +581,14 @@ def _traverse_ctes(scope):
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
+ child_scope.cte_sources[alias] = recursive_scope
# append the final child_scope yielded
if child_scope:
scope.cte_scopes.append(child_scope)
scope.sources.update(sources)
+ scope.cte_sources.update(sources)
def _is_derived_table(expression: exp.Subquery) -> bool:
@@ -725,7 +724,7 @@ def _traverse_ddl(scope):
yield from _traverse_ctes(scope)
query_scope = scope.branch(
- scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+ scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
)
query_scope._collect()
query_scope._ctes = scope.ctes + query_scope._ctes
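A small sketch of the refactor above (the printed list is an assumption): `cte_sources` is now stored on the `Scope` directly and threaded through `branch()`, instead of being recomputed from `sources`.

```python
import sqlglot
from sqlglot.optimizer.scope import build_scope

scope = build_scope(sqlglot.parse_one("WITH c AS (SELECT 1 AS x) SELECT x FROM c"))
print(list(scope.cte_sources))  # Assumed: ['c']
```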
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index d4e2e60..6ae08d0 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import datetime
import functools
import itertools
@@ -6,10 +8,17 @@ from collections import deque
from decimal import Decimal
import sqlglot
-from sqlglot import exp
+from sqlglot import Dialect, exp
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import DialectType
+
+ DateTruncBinaryTransform = t.Callable[
+ [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
+ ]
+
# Final means that an expression should not be simplified
FINAL = "final"
@@ -18,7 +27,9 @@ class UnsupportedUnit(Exception):
pass
-def simplify(expression, constant_propagation=False):
+def simplify(
+ expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
+):
"""
Rewrite sqlglot AST to simplify expressions.
@@ -36,15 +47,18 @@ def simplify(expression, constant_propagation=False):
sqlglot.Expression: simplified expression
"""
+ dialect = Dialect.get_or_raise(dialect)
+
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
for group in expression.find_all(exp.Group):
select = group.parent
+ assert select
groups = set(group.expressions)
group.meta[FINAL] = True
- for e in select.selects:
+ for e in select.expressions:
for node, *_ in e.walk():
if node in groups:
e.meta[FINAL] = True
@@ -84,7 +98,8 @@ def simplify(expression, constant_propagation=False):
node = simplify_literals(node, root)
node = simplify_equality(node)
node = simplify_parens(node)
- node = simplify_datetrunc_predicate(node)
+ node = simplify_datetrunc(node, dialect)
+ node = sort_comparison(node)
if root:
expression.replace(node)
@@ -117,14 +132,30 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
This is done because comparison simplification is only done on lt/lte/gt/gte.
"""
if isinstance(expression, exp.Between):
- return exp.and_(
+ negate = isinstance(expression.parent, exp.Not)
+
+ expression = exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
+
+ if negate:
+ expression = exp.paren(expression, copy=False)
+
return expression
+COMPLEMENT_COMPARISONS = {
+ exp.LT: exp.GTE,
+ exp.GT: exp.LTE,
+ exp.LTE: exp.GT,
+ exp.GTE: exp.LT,
+ exp.EQ: exp.NEQ,
+ exp.NEQ: exp.EQ,
+}
+
+
def simplify_not(expression):
"""
Demorgan's Law
@@ -132,10 +163,15 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
- if is_null(expression.this):
+ this = expression.this
+ if is_null(this):
return exp.null()
- if isinstance(expression.this, exp.Paren):
- condition = expression.this.unnest()
+ if this.__class__ in COMPLEMENT_COMPARISONS:
+ return COMPLEMENT_COMPARISONS[this.__class__](
+ this=this.this, expression=this.expression
+ )
+ if isinstance(this, exp.Paren):
+ condition = this.unnest()
if isinstance(condition, exp.And):
return exp.or_(
exp.not_(condition.left, copy=False),
@@ -150,14 +186,14 @@ def simplify_not(expression):
)
if is_null(condition):
return exp.null()
- if always_true(expression.this):
+ if always_true(this):
return exp.false()
- if is_false(expression.this):
+ if is_false(this):
return exp.true()
- if isinstance(expression.this, exp.Not):
+ if isinstance(this, exp.Not):
# double negation
# NOT NOT x -> x
- return expression.this.this
+ return this.this
return expression
@@ -249,12 +285,6 @@ def _simplify_comparison(expression, left, right, or_=False):
except StopIteration:
return expression
- # make sure the comparison is always of the form x > 1 instead of 1 < x
- if left.__class__ in INVERSE_COMPARISONS and l == ll:
- left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
- if right.__class__ in INVERSE_COMPARISONS and r == rl:
- right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
-
if l.is_number and r.is_number:
l = float(l.name)
r = float(r.name)
@@ -397,13 +427,7 @@ def propagate_constants(expression, root=True):
# TODO: create a helper that can be used to detect nested literal expressions such
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
- pass
- elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
- l, r = r, l
- else:
- continue
-
- constant_mapping[l] = (id(l), r)
+ constant_mapping[l] = (id(l), r)
if constant_mapping:
for column in find_all_in_scope(expression, exp.Column):
@@ -458,11 +482,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, COMPARISONS):
l, r = expression.left, expression.right
- if l.__class__ in INVERSE_OPS:
- pass
- elif r.__class__ in INVERSE_OPS:
- l, r = r, l
- else:
+ if not l.__class__ in INVERSE_OPS:
return expression
if r.is_number:
@@ -650,7 +670,7 @@ def simplify_coalesce(expression):
# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
- if _is_constant(other):
+ if _is_constant(arg):
break
else:
return expression
@@ -752,7 +772,7 @@ def simplify_conditionals(expression):
DateRange = t.Tuple[datetime.date, datetime.date]
-def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
+def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
"""
Get the date range for a DATE_TRUNC equality comparison:
@@ -761,7 +781,7 @@ def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
Returns:
tuple of [min, max) or None if a value can never be equal to `date` for `unit`
"""
- floor = date_floor(date, unit)
+ floor = date_floor(date, unit, dialect)
if date != floor:
# This will always be False, except for NULL values.
@@ -780,9 +800,9 @@ def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Exp
def _datetrunc_eq(
- left: exp.Expression, date: datetime.date, unit: str
+ left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
- drange = _datetrunc_range(date, unit)
+ drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
@@ -790,9 +810,9 @@ def _datetrunc_eq(
def _datetrunc_neq(
- left: exp.Expression, date: datetime.date, unit: str
+ left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
- drange = _datetrunc_range(date, unit)
+ drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
@@ -803,41 +823,39 @@ def _datetrunc_neq(
)
-DateTruncBinaryTransform = t.Callable[
- [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
-]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
- exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
- exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
- exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
- exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
+ exp.LT: lambda l, dt, u, d: l
+ < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
+ exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
+ exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
+ exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
exp.EQ: _datetrunc_eq,
exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
+DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
- return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
+ return isinstance(left, DATETRUNCS) and _is_date_literal(right)
@catch(ModuleNotFoundError, UnsupportedUnit)
-def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
+def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
"""Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
comparison = expression.__class__
- if comparison not in DATETRUNC_COMPARISONS:
+ if isinstance(expression, DATETRUNCS):
+ date = extract_date(expression.this)
+ if date and expression.unit:
+ return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
+ elif comparison not in DATETRUNC_COMPARISONS:
return expression
if isinstance(expression, exp.Binary):
l, r = expression.left, expression.right
- if _is_datetrunc_predicate(l, r):
- pass
- elif _is_datetrunc_predicate(r, l):
- comparison = INVERSE_COMPARISONS.get(comparison, comparison)
- l, r = r, l
- else:
+ if not _is_datetrunc_predicate(l, r):
return expression
l = t.cast(exp.DateTrunc, l)
@@ -847,7 +865,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
if not date:
return expression
- return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
+ return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
elif isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
@@ -861,7 +879,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
date = extract_date(r)
if not date:
return expression
- drange = _datetrunc_range(date, unit)
+ drange = _datetrunc_range(date, unit, dialect)
if drange:
ranges.append(drange)
@@ -875,6 +893,23 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
return expression
+def sort_comparison(expression: exp.Expression) -> exp.Expression:
+ if expression.__class__ in COMPLEMENT_COMPARISONS:
+ l, r = expression.this, expression.expression
+ l_column = isinstance(l, exp.Column)
+ r_column = isinstance(r, exp.Column)
+ l_const = _is_constant(l)
+ r_const = _is_constant(r)
+
+ if (l_column and not r_column) or (r_const and not l_const):
+ return expression
+ if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
+ return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
+ this=r, expression=l
+ )
+ return expression
+
+
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
@@ -1034,7 +1069,7 @@ def interval(unit: str, n: int = 1):
raise UnsupportedUnit(f"Unsupported unit: {unit}")
-def date_floor(d: datetime.date, unit: str) -> datetime.date:
+def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
if unit == "year":
return d.replace(month=1, day=1)
if unit == "quarter":
@@ -1050,15 +1085,15 @@ def date_floor(d: datetime.date, unit: str) -> datetime.date:
return d.replace(month=d.month, day=1)
if unit == "week":
# Assuming week starts on Monday (0) and ends on Sunday (6)
- return d - datetime.timedelta(days=d.weekday())
+ return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
if unit == "day":
return d
raise UnsupportedUnit(f"Unsupported unit: {unit}")
-def date_ceil(d: datetime.date, unit: str) -> datetime.date:
- floor = date_floor(d, unit)
+def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
+ floor = date_floor(d, unit, dialect)
if floor == d:
return d
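A sketch of two of the new rewrites introduced in this file (outputs are assumptions): `simplify_not` folds `NOT` over plain comparisons into their complements via `COMPLEMENT_COMPARISONS`, and `sort_comparison` normalizes literal-vs-column operand order.

```python
import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("NOT x > 1")).sql())  # Assumed: x <= 1
print(simplify(sqlglot.parse_one("1 < x")).sql())      # Assumed: x > 1
```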
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 242fc87..4d35175 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -65,6 +65,8 @@ def unnest(select, parent_select, next_alias_name):
)
):
column = exp.Max(this=column)
+ elif not isinstance(select.parent, exp.Subquery):
+ return
_replace(select.parent, column)
parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index c7e27a3..3d01a84 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -568,6 +568,7 @@ class Parser(metaclass=_Parser):
exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
exp.Table: lambda self: self._parse_table_parts(),
exp.TableAlias: lambda self: self._parse_table_alias(),
+ exp.When: lambda self: seq_get(self._parse_when_matched(), 0),
exp.Where: lambda self: self._parse_where(),
exp.Window: lambda self: self._parse_named_window(),
exp.With: lambda self: self._parse_with(),
@@ -635,6 +636,11 @@ class Parser(metaclass=_Parser):
TokenType.HEREDOC_STRING: lambda self, token: self.expression(
exp.RawString, this=token.text
),
+ TokenType.UNICODE_STRING: lambda self, token: self.expression(
+ exp.UnicodeString,
+ this=token.text,
+ escape=self._match_text_seq("UESCAPE") and self._parse_string(),
+ ),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
@@ -907,7 +913,7 @@ class Parser(metaclass=_Parser):
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
CLONE_KEYWORDS = {"CLONE", "COPY"}
- CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
+ HISTORICAL_DATA_KIND = {"TIMESTAMP", "OFFSET", "STATEMENT", "STREAM"}
OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"}
OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN}
@@ -947,6 +953,10 @@ class Parser(metaclass=_Parser):
# Whether the TRIM function expects the characters to trim as its first argument
TRIM_PATTERN_FIRST = False
+ # Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
+ MODIFIERS_ATTACHED_TO_UNION = True
+ UNION_MODIFIERS = {"order", "limit", "offset"}
+
__slots__ = (
"error_level",
"error_message_context",
@@ -1162,6 +1172,9 @@ class Parser(metaclass=_Parser):
def _find_sql(self, start: Token, end: Token) -> str:
return self.sql[start.start : end.end + 1]
+ def _is_connected(self) -> bool:
+ return bool(self._prev and self._curr and self._prev.end + 1 == self._curr.start)
+
def _advance(self, times: int = 1) -> None:
self._index += times
self._curr = seq_get(self._tokens, self._index)
@@ -1404,23 +1417,8 @@ class Parser(metaclass=_Parser):
if self._match_texts(self.CLONE_KEYWORDS):
copy = self._prev.text.lower() == "copy"
- clone = self._parse_table(schema=True)
- when = self._match_texts(("AT", "BEFORE")) and self._prev.text.upper()
- clone_kind = (
- self._match(TokenType.L_PAREN)
- and self._match_texts(self.CLONE_KINDS)
- and self._prev.text.upper()
- )
- clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise()
- self._match(TokenType.R_PAREN)
clone = self.expression(
- exp.Clone,
- this=clone,
- when=when,
- kind=clone_kind,
- shallow=shallow,
- expression=clone_expression,
- copy=copy,
+ exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy
)
return self.expression(
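With the historical-data clause (AT/BEFORE plus TIMESTAMP/OFFSET/STATEMENT/STREAM, per the `HISTORICAL_DATA_KIND` rename above) moved out of `_parse_create`, CLONE parsing reduces to the target table, the `shallow` flag, and the CLONE/COPY keyword. A hedged round-trip, assuming Snowflake still accepts the basic form:

```python
import sqlglot

# Hypothetical table names; with identity=True (the default), the write
# dialect of transpile() falls back to the read dialect.
print(sqlglot.transpile("CREATE TABLE t2 CLONE t1", read="snowflake")[0])
```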
@@ -2471,13 +2469,7 @@ class Parser(metaclass=_Parser):
pattern = None
define = (
- self._parse_csv(
- lambda: self.expression(
- exp.Alias,
- alias=self._parse_id_var(any_token=True),
- this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
- )
- )
+ self._parse_csv(self._parse_name_as_expression)
if self._match_text_seq("DEFINE")
else None
)
@@ -3124,6 +3116,18 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Connect, start=start, connect=connect)
+ def _parse_name_as_expression(self) -> exp.Alias:
+ return self.expression(
+ exp.Alias,
+ alias=self._parse_id_var(any_token=True),
+ this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
+ )
+
+ def _parse_interpolate(self) -> t.Optional[t.List[exp.Expression]]:
+ if self._match_text_seq("INTERPOLATE"):
+ return self._parse_wrapped_csv(self._parse_name_as_expression)
+ return None
+
def _parse_order(
self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
) -> t.Optional[exp.Expression]:
@@ -3131,7 +3135,10 @@ class Parser(metaclass=_Parser):
return this
return self.expression(
- exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
+ exp.Order,
+ this=this,
+ expressions=self._parse_csv(self._parse_ordered),
+ interpolate=self._parse_interpolate(),
)
def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:
@@ -3161,7 +3168,21 @@ class Parser(metaclass=_Parser):
):
nulls_first = True
- return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first)
+ if self._match_text_seq("WITH", "FILL"):
+ with_fill = self.expression(
+ exp.WithFill,
+ **{ # type: ignore
+ "from": self._match(TokenType.FROM) and self._parse_bitwise(),
+ "to": self._match_text_seq("TO") and self._parse_bitwise(),
+ "step": self._match_text_seq("STEP") and self._parse_bitwise(),
+ },
+ )
+ else:
+ with_fill = None
+
+ return self.expression(
+ exp.Ordered, this=this, desc=desc, nulls_first=nulls_first, with_fill=with_fill
+ )
def _parse_limit(
self, this: t.Optional[exp.Expression] = None, top: bool = False
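Together, the `_parse_name_as_expression` helper and the WITH FILL branch above cover ClickHouse-style ORDER BY modifiers: a fill range on each ordered expression plus an INTERPOLATE list of name-as-expression pairs. A hedged round-trip, assuming the generator emits both clauses for ClickHouse:

```python
import sqlglot

# Hypothetical table and columns; WITH FILL generates the missing values of
# the ordered column, INTERPOLATE derives the other columns for filled rows.
sql = """
SELECT d, x FROM t
ORDER BY d WITH FILL FROM 1 TO 10 STEP 2
INTERPOLATE (x AS x + 1)
"""
print(sqlglot.transpile(sql, read="clickhouse")[0])
```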
@@ -3253,28 +3274,40 @@ class Parser(metaclass=_Parser):
return locks
def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
- if not self._match_set(self.SET_OPERATIONS):
- return this
+ while this and self._match_set(self.SET_OPERATIONS):
+ token_type = self._prev.token_type
- token_type = self._prev.token_type
+ if token_type == TokenType.UNION:
+ operation = exp.Union
+ elif token_type == TokenType.EXCEPT:
+ operation = exp.Except
+ else:
+ operation = exp.Intersect
- if token_type == TokenType.UNION:
- expression = exp.Union
- elif token_type == TokenType.EXCEPT:
- expression = exp.Except
- else:
- expression = exp.Intersect
+ comments = self._prev.comments
+ distinct = self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL)
+ by_name = self._match_text_seq("BY", "NAME")
+ expression = self._parse_select(nested=True, parse_set_operation=False)
- return self.expression(
- expression,
- comments=self._prev.comments,
- this=this,
- distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
- by_name=self._match_text_seq("BY", "NAME"),
- expression=self._parse_set_operations(
- self._parse_select(nested=True, parse_set_operation=False)
- ),
- )
+ this = self.expression(
+ operation,
+ comments=comments,
+ this=this,
+ distinct=distinct,
+ by_name=by_name,
+ expression=expression,
+ )
+
+ if isinstance(this, exp.Union) and self.MODIFIERS_ATTACHED_TO_UNION:
+ expression = this.expression
+
+ if expression:
+ for arg in self.UNION_MODIFIERS:
+ expr = expression.args.get(arg)
+ if expr:
+ this.set(arg, expr.pop())
+
+ return this
def _parse_expression(self) -> t.Optional[exp.Expression]:
return self._parse_alias(self._parse_conjunction())
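The rewritten loop builds set operations left-associatively instead of recursing, and when `MODIFIERS_ATTACHED_TO_UNION` is set, trailing ORDER BY / LIMIT / OFFSET clauses are popped from the right-hand SELECT and re-attached to the Union node itself. A small sketch of the resulting tree shape under the default dialect:

```python
import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT 1 UNION SELECT 2 LIMIT 1")
assert isinstance(ast, exp.Union)
assert ast.args.get("limit") is not None          # LIMIT lives on the Union
assert ast.expression.args.get("limit") is None   # ...not on SELECT 2
```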
@@ -3595,7 +3628,7 @@ class Parser(metaclass=_Parser):
exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span
)
else:
- this = self.expression(exp.Interval, unit=unit)
+ this = self.expression(exp.DataType, this=self.expression(exp.Interval, unit=unit))
if maybe_func and check_func:
index2 = self._index
@@ -4891,8 +4924,8 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder()
- def _advance_any(self) -> t.Optional[Token]:
- if self._curr and self._curr.token_type not in self.RESERVED_TOKENS:
+ def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]:
+ if self._curr and (ignore_reserved or self._curr.token_type not in self.RESERVED_TOKENS):
self._advance()
return self._prev
return None
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 07ee739..bbc52ab 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -425,16 +425,27 @@ class SetOperation(Step):
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
assert isinstance(expression, exp.Union)
+
left = Step.from_expression(expression.left, ctes)
+ # SELECT 1 UNION SELECT 2 <-- these subqueries don't have names
+ left.name = left.name or "left"
right = Step.from_expression(expression.right, ctes)
+ right.name = right.name or "right"
step = cls(
op=expression.__class__,
left=left.name,
right=right.name,
distinct=bool(expression.args.get("distinct")),
)
+
step.add_dependency(left)
step.add_dependency(right)
+
+ limit = expression.args.get("limit")
+
+ if limit:
+ step.limit = int(limit.text("expression"))
+
return step
def _to_s(self, indent: str) -> t.List[str]:
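A hedged sketch of what these planner changes enable, assuming the Python executor handles literal-only SELECTs: unnamed set-operation branches now get "left"/"right" fallback step names, and a top-level LIMIT is carried onto the SetOperation step rather than being dropped.

```python
from sqlglot.executor import execute

# The motivating case from the inline comment above: neither operand has a
# name, and the LIMIT must apply to the union as a whole.
result = execute("SELECT 1 AS x UNION ALL SELECT 2 AS x LIMIT 1")
print(result.rows)  # only one row survives the limit
```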
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index e4c3204..de9d4c4 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -1,9 +1,10 @@
from __future__ import annotations
+import os
import typing as t
from enum import auto
-from sqlglot.errors import TokenError
+from sqlglot.errors import SqlglotError, TokenError
from sqlglot.helper import AutoName
from sqlglot.trie import TrieResult, in_trie, new_trie
@@ -11,6 +12,19 @@ if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
+try:
+ from sqlglotrs import ( # type: ignore
+ Tokenizer as RsTokenizer,
+ TokenizerDialectSettings as RsTokenizerDialectSettings,
+ TokenizerSettings as RsTokenizerSettings,
+ TokenTypeSettings as RsTokenTypeSettings,
+ )
+
+ USE_RS_TOKENIZER = os.environ.get("SQLGLOTRS_TOKENIZER", "1") == "1"
+except ImportError:
+ USE_RS_TOKENIZER = False
+
+
class TokenType(AutoName):
L_PAREN = auto()
R_PAREN = auto()
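The Rust tokenizer (the new `sqlglotrs` extension) is opt-out rather than opt-in: the flag is evaluated once at import time, so the environment variable must be set before `sqlglot` is first imported.

```python
import os

# Force the pure-Python tokenizer even when sqlglotrs is installed; without
# this, USE_RS_TOKENIZER is True whenever the extension is importable.
os.environ["SQLGLOTRS_TOKENIZER"] = "0"

from sqlglot import tokens

print(tokens.USE_RS_TOKENIZER)  # False: the env var disables the Rust path
```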
@@ -83,6 +97,7 @@ class TokenType(AutoName):
NATIONAL_STRING = auto()
RAW_STRING = auto()
HEREDOC_STRING = auto()
+ UNICODE_STRING = auto()
# types
BIT = auto()
@@ -347,6 +362,10 @@ class TokenType(AutoName):
TIMESTAMP_SNAPSHOT = auto()
+_ALL_TOKEN_TYPES = list(TokenType)
+_TOKEN_TYPE_TO_INDEX = {token_type: i for i, token_type in enumerate(_ALL_TOKEN_TYPES)}
+
+
class Token:
__slots__ = ("token_type", "text", "line", "col", "start", "end", "comments")
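These two module-level tables give `TokenType` a stable integer encoding for crossing the Python/Rust boundary, since enum members cannot be passed to `sqlglotrs` directly. They are internal (underscore-prefixed), but the invariant is simple: index and lookup are exact inverses.

```python
from sqlglot.tokens import _ALL_TOKEN_TYPES, _TOKEN_TYPE_TO_INDEX, TokenType

i = _TOKEN_TYPE_TO_INDEX[TokenType.SELECT]
assert _ALL_TOKEN_TYPES[i] is TokenType.SELECT
assert all(_ALL_TOKEN_TYPES[_TOKEN_TYPE_TO_INDEX[t]] is t for t in TokenType)
```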
@@ -432,6 +451,7 @@ class _Tokenizer(type):
**_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
**_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
**_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS),
+ **_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS),
}
klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
@@ -455,6 +475,46 @@ class _Tokenizer(type):
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
+ if USE_RS_TOKENIZER:
+ settings = RsTokenizerSettings(
+ white_space={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.WHITE_SPACE.items()},
+ single_tokens={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.SINGLE_TOKENS.items()},
+ keywords={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.KEYWORDS.items()},
+ numeric_literals=klass.NUMERIC_LITERALS,
+ identifiers=klass._IDENTIFIERS,
+ identifier_escapes=klass._IDENTIFIER_ESCAPES,
+ string_escapes=klass._STRING_ESCAPES,
+ quotes=klass._QUOTES,
+ format_strings={
+ k: (v1, _TOKEN_TYPE_TO_INDEX[v2])
+ for k, (v1, v2) in klass._FORMAT_STRINGS.items()
+ },
+ has_bit_strings=bool(klass.BIT_STRINGS),
+ has_hex_strings=bool(klass.HEX_STRINGS),
+ comments=klass._COMMENTS,
+ var_single_tokens=klass.VAR_SINGLE_TOKENS,
+ commands={_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMANDS},
+ command_prefix_tokens={
+ _TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
+ },
+ )
+ token_types = RsTokenTypeSettings(
+ bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
+ break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
+ dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
+ heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
+ hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING],
+ identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER],
+ number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER],
+ parameter=_TOKEN_TYPE_TO_INDEX[TokenType.PARAMETER],
+ semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON],
+ string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING],
+ var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR],
+ )
+ klass._RS_TOKENIZER = RsTokenizer(settings, token_types)
+ else:
+ klass._RS_TOKENIZER = None
+
return klass
@@ -499,6 +559,7 @@ class Tokenizer(metaclass=_Tokenizer):
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = []
+ UNICODE_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
IDENTIFIER_ESCAPES = ['"']
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
@@ -513,6 +574,7 @@ class Tokenizer(metaclass=_Tokenizer):
_QUOTES: t.Dict[str, str] = {}
_STRING_ESCAPES: t.Set[str] = set()
_KEYWORD_TRIE: t.Dict = {}
+ _RS_TOKENIZER: t.Optional[t.Any] = None
KEYWORDS: t.Dict[str, TokenType] = {
**{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
@@ -804,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer):
# handle numeric literals like in hive (3L = BIGINT)
NUMERIC_LITERALS: t.Dict[str, str] = {}
- ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/")]
@@ -822,12 +883,20 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
+ "_rs_dialect_settings",
)
def __init__(self, dialect: DialectType = None) -> None:
from sqlglot.dialects import Dialect
self.dialect = Dialect.get_or_raise(dialect)
+
+ if USE_RS_TOKENIZER:
+ self._rs_dialect_settings = RsTokenizerDialectSettings(
+ escape_sequences=self.dialect.ESCAPE_SEQUENCES,
+ identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT,
+ )
+
self.reset()
def reset(self) -> None:
@@ -847,6 +916,9 @@ class Tokenizer(metaclass=_Tokenizer):
def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
+ if USE_RS_TOKENIZER:
+ return self.tokenize_rs(sql)
+
self.reset()
self.sql = sql
self.size = len(sql)
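Callers keep using `tokenize()` unchanged; the dispatch to the Rust backend is transparent, and both paths return the same list of `Token` objects.

```python
from sqlglot.tokens import Tokenizer

# Works identically whether the Rust or Python tokenizer handles the input.
for token in Tokenizer().tokenize("SELECT 1 /* note */"):
    print(token.token_type, repr(token.text))
```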
@@ -910,6 +982,7 @@ class Tokenizer(metaclass=_Tokenizer):
# Ensures we don't count an extra line if we get a \r\n line break sequence
if self._char == "\r" and self._peek == "\n":
i = 2
+ self._start += 1
self._col = 1
self._line += 1
@@ -1184,8 +1257,6 @@ class Tokenizer(metaclass=_Tokenizer):
raise TokenError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
- else:
- text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(token_type, text)
return True
@@ -1254,3 +1325,15 @@ class Tokenizer(metaclass=_Tokenizer):
text += self.sql[current : self._current - 1]
return text
+
+ def tokenize_rs(self, sql: str) -> t.List[Token]:
+ if not self._RS_TOKENIZER:
+ raise SqlglotError("Rust tokenizer is not available")
+
+ try:
+ tokens = self._RS_TOKENIZER.tokenize(sql, self._rs_dialect_settings)
+ for token in tokens:
+ token.token_type = _ALL_TOKEN_TYPES[token.token_type_index]
+ return tokens
+ except Exception as e:
+ raise TokenError(str(e))
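Since `tokenize_rs` wraps any Rust-side failure in a `TokenError`, both backends surface the same exception type and callers need no backend-specific handling. A hedged example, assuming an unterminated string is rejected by either tokenizer:

```python
from sqlglot.errors import TokenError
from sqlglot.tokens import Tokenizer

try:
    Tokenizer().tokenize("'unterminated string")
except TokenError as e:
    print("tokenization failed:", e)
```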