Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                          6
-rw-r--r--  sqlglot/__main__.py                         18
-rw-r--r--  sqlglot/dataframe/__init__.py                3
-rw-r--r--  sqlglot/dataframe/sql/_typing.pyi           20
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py           2
-rw-r--r--  sqlglot/dialects/bigquery.py                20
-rw-r--r--  sqlglot/dialects/clickhouse.py              10
-rw-r--r--  sqlglot/dialects/dialect.py                 14
-rw-r--r--  sqlglot/dialects/drill.py                   14
-rw-r--r--  sqlglot/dialects/duckdb.py                  18
-rw-r--r--  sqlglot/dialects/hive.py                    18
-rw-r--r--  sqlglot/dialects/mysql.py                   14
-rw-r--r--  sqlglot/dialects/oracle.py                   6
-rw-r--r--  sqlglot/dialects/postgres.py                72
-rw-r--r--  sqlglot/dialects/presto.py                  15
-rw-r--r--  sqlglot/dialects/redshift.py                 1
-rw-r--r--  sqlglot/dialects/snowflake.py               56
-rw-r--r--  sqlglot/dialects/spark.py                   12
-rw-r--r--  sqlglot/dialects/sqlite.py                   6
-rw-r--r--  sqlglot/dialects/starrocks.py                2
-rw-r--r--  sqlglot/dialects/tableau.py                  2
-rw-r--r--  sqlglot/dialects/tsql.py                    10
-rw-r--r--  sqlglot/diff.py                              4
-rw-r--r--  sqlglot/executor/context.py                  4
-rw-r--r--  sqlglot/executor/env.py                      4
-rw-r--r--  sqlglot/executor/python.py                  13
-rw-r--r--  sqlglot/expressions.py                     104
-rw-r--r--  sqlglot/generator.py                        92
-rw-r--r--  sqlglot/optimizer/canonicalize.py            3
-rw-r--r--  sqlglot/optimizer/eliminate_joins.py        49
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py    4
-rw-r--r--  sqlglot/optimizer/normalize.py               5
-rw-r--r--  sqlglot/optimizer/optimizer.py               2
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py    14
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py    6
-rw-r--r--  sqlglot/optimizer/qualify_columns.py         3
-rw-r--r--  sqlglot/optimizer/quote_identities.py       25
-rw-r--r--  sqlglot/optimizer/scope.py                  30
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py      27
-rw-r--r--  sqlglot/parser.py                          243
-rw-r--r--  sqlglot/schema.py                            4
-rw-r--r--  sqlglot/tokens.py                           27
42 files changed, 727 insertions(+), 275 deletions(-)
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index e829517..04c3195 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -1,4 +1,6 @@
-"""## Python SQL parser, transpiler and optimizer."""
+"""
+.. include:: ../README.md
+"""
from __future__ import annotations
@@ -30,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.2.9"
+__version__ = "10.4.2"
pretty = False
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py
index 42a54bc..f9613b2 100644
--- a/sqlglot/__main__.py
+++ b/sqlglot/__main__.py
@@ -1,9 +1,15 @@
import argparse
+import sys
import sqlglot
parser = argparse.ArgumentParser(description="Transpile SQL")
-parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile")
+parser.add_argument(
+ "sql",
+ metavar="sql",
+ type=str,
+ help="SQL statement(s) to transpile, or - to parse stdin.",
+)
parser.add_argument(
"--read",
dest="read",
@@ -48,14 +54,20 @@ parser.add_argument(
args = parser.parse_args()
error_level = sqlglot.ErrorLevel[args.error_level.upper()]
+sql = sys.stdin.read() if args.sql == "-" else args.sql
+
if args.parse:
sqls = [
repr(expression)
- for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
+ for expression in sqlglot.parse(
+ sql,
+ read=args.read,
+ error_level=error_level,
+ )
]
else:
sqls = sqlglot.transpile(
- args.sql,
+ sql,
read=args.read,
write=args.write,
identify=args.identify,
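[note] The new `-` argument reads the SQL from stdin. A quick sanity check of both entry points (invocation sketch, using only the flags visible above):

    # echo "SELECT 1" | python -m sqlglot - --write duckdb
    # or, equivalently, from Python:
    import sqlglot

    print(sqlglot.transpile("SELECT 1", write="duckdb")[0])  # SELECT 1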
diff --git a/sqlglot/dataframe/__init__.py b/sqlglot/dataframe/__init__.py
index e69de29..a57e990 100644
--- a/sqlglot/dataframe/__init__.py
+++ b/sqlglot/dataframe/__init__.py
@@ -0,0 +1,3 @@
+"""
+.. include:: ./README.md
+"""
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
index 67c8c09..1682ec1 100644
--- a/sqlglot/dataframe/sql/_typing.pyi
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -9,18 +9,8 @@ if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.types import StructType
-ColumnLiterals = t.TypeVar(
- "ColumnLiterals",
- bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
-)
-ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
-ColumnOrLiteral = t.TypeVar(
- "ColumnOrLiteral",
- bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
-)
-SchemaInput = t.TypeVar(
- "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
-)
-OutputExpressionContainer = t.TypeVar(
- "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
-)
+ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ColumnOrName = t.Union[Column, str]
+ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
+OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
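[note] These names were declared as bound TypeVars but only ever used as annotations, where a TypeVar adds nothing unless it links several parts of a generic signature; plain Union aliases express the intent directly. A minimal sketch of the usage pattern (the `col` helper is hypothetical):

    import typing as t

    ColumnOrName = t.Union["Column", str]  # a plain alias, not a type parameter

    def col(name: ColumnOrName) -> "Column":
        ...  # accepts either a Column or a bare string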
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 3c45741..a17bb9d 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -634,7 +634,7 @@ class DataFrame:
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
if isinstance(value, dict):
- values = value.values()
+ values = list(value.values())
columns = self._ensure_and_normalize_cols(list(value))
if not columns:
columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 6be68ac..d10cc54 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,11 +1,15 @@
+"""Supports BigQuery Standard SQL."""
+
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ datestrtodate_sql,
inline_array_sql,
no_ilike_sql,
rename_func,
+ timestrtotime_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -120,13 +124,12 @@ class BigQuery(Dialect):
"NOT DETERMINISTIC": TokenType.VOLATILE,
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
- "WINDOW": TokenType.WINDOW,
}
KEYWORDS.pop("DIV")
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": _date_trunc,
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
@@ -144,31 +147,33 @@ class BigQuery(Dialect):
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS, # type: ignore
"ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
}
FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = {
- **parser.Parser.NO_PAREN_FUNCTIONS,
+ **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
- *parser.Parser.NESTED_TYPE_TOKENS,
+ *parser.Parser.NESTED_TYPE_TOKENS, # type: ignore
TokenType.TABLE,
}
class Generator(generator.Generator):
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
+ exp.DateStrToDate: datestrtodate_sql,
+ exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@@ -176,6 +181,7 @@ class BigQuery(Dialect):
exp.TimeSub: _date_add_sql("TIME", "SUB"),
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
+ exp.TimeStrToTime: timestrtotime_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
@@ -188,7 +194,7 @@ class BigQuery(Dialect):
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index cbed72e..7136340 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -35,13 +35,13 @@ class ClickHouse(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"MAP": parse_var_map,
}
- JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
+ JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
- TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
+ TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
def _parse_table(self, schema=False):
this = super()._parse_table(schema)
@@ -55,7 +55,7 @@ class ClickHouse(Dialect):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map",
@@ -70,7 +70,7 @@ class ClickHouse(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index c87f8d8..e788852 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -198,7 +198,7 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
def _rename(self, expression):
args = flatten(expression.args.values())
- return f"{name}({self.format_args(*args)})"
+ return f"{self.normalize_func(name)}({self.format_args(*args)})"
return _rename
@@ -217,11 +217,11 @@ def if_sql(self, expression):
def arrow_json_extract_sql(self, expression):
- return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}"
+ return self.binary(expression, "->")
def arrow_json_extract_scalar_sql(self, expression):
- return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}"
+ return self.binary(expression, "->>")
def inline_array_sql(self, expression):
@@ -373,3 +373,11 @@ def strposition_to_local_sql(self, expression):
expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"
+
+
+def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
+ return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
+
+
+def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
+ return f"CAST({self.sql(expression, 'this')} AS DATE)"
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 358eced..4e3c0e1 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -6,13 +6,14 @@ from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
+ datestrtodate_sql,
format_time_lambda,
no_pivot_sql,
no_trycast_sql,
rename_func,
str_position_sql,
+ timestrtotime_sql,
)
-from sqlglot.dialects.postgres import _lateral_sql
def _to_timestamp(args):
@@ -117,14 +118,14 @@ class Drill(Dialect):
STRICT_CAST = False
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
@@ -139,14 +140,13 @@ class Drill(Dialect):
ROOT_PROPERTIES = {exp.PartitionedByProperty}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
- exp.Lateral: _lateral_sql,
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: create_with_partitions_sql,
exp.DateAdd: _date_add_sql("ADD"),
- exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
@@ -160,7 +160,7 @@ class Drill(Dialect):
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
- exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+ exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f1da72b..81941f7 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
+ datestrtodate_sql,
format_time_lambda,
no_pivot_sql,
no_properties_sql,
@@ -13,6 +14,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
rename_func,
str_position_sql,
+ timestrtotime_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -83,11 +85,12 @@ class DuckDB(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ,
+ "CHARACTER VARYING": TokenType.VARCHAR,
}
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -119,16 +122,18 @@ class DuckDB(Dialect):
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: approx_count_distinct_sql,
- exp.Array: rename_func("LIST_VALUE"),
+ exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+ if isinstance(seq_get(e.expressions, 0), exp.Select)
+ else rename_func("LIST_VALUE")(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
- exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
exp.Explode: rename_func("UNNEST"),
@@ -136,6 +141,7 @@ class DuckDB(Dialect):
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
+ exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@@ -150,7 +156,7 @@ class DuckDB(Dialect):
exp.Struct: _struct_pack_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
- exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+ exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
@@ -163,7 +169,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
}
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 8d6e1ae..088555c 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
rename_func,
strposition_to_local_sql,
struct_extract_sql,
+ timestrtotime_sql,
var_map_sql,
)
from sqlglot.helper import seq_get
@@ -197,7 +198,7 @@ class Hive(Dialect):
STRICT_CAST = False
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
@@ -217,7 +218,12 @@ class Hive(Dialect):
),
unit=exp.Literal.string("DAY"),
),
- "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
+ "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
+ [
+ exp.TimeStrToTime(this=seq_get(args, 0)),
+ seq_get(args, 1),
+ ]
+ ),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
@@ -240,7 +246,7 @@ class Hive(Dialect):
}
PROPERTY_PARSERS = {
- **parser.Parser.PROPERTY_PARSERS,
+ **parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property)
),
@@ -248,14 +254,14 @@ class Hive(Dialect):
class Generator(generator.Generator):
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
@@ -294,7 +300,7 @@ class Hive(Dialect):
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
- exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+ exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 7627b6e..0fd7992 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -161,8 +161,6 @@ class MySQL(Dialect):
"_UCS2": TokenType.INTRODUCER,
"_UJIS": TokenType.INTRODUCER,
# https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
- "N": TokenType.INTRODUCER,
- "n": TokenType.INTRODUCER,
"_UTF8": TokenType.INTRODUCER,
"_UTF16": TokenType.INTRODUCER,
"_UTF16LE": TokenType.INTRODUCER,
@@ -175,10 +173,10 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
- FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
+ FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} # type: ignore
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
@@ -190,7 +188,7 @@ class MySQL(Dialect):
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS, # type: ignore
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -199,12 +197,12 @@ class MySQL(Dialect):
}
PROPERTY_PARSERS = {
- **parser.Parser.PROPERTY_PARSERS,
+ **parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
}
STATEMENT_PARSERS = {
- **parser.Parser.STATEMENT_PARSERS,
+ **parser.Parser.STATEMENT_PARSERS, # type: ignore
TokenType.SHOW: lambda self: self._parse_show(),
TokenType.SET: lambda self: self._parse_set(),
}
@@ -429,7 +427,7 @@ class MySQL(Dialect):
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index f507513..af3d353 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -39,13 +39,13 @@ class Oracle(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"DECODE": exp.Matches.from_arg_list,
}
class Generator(generator.Generator):
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@@ -60,7 +60,7 @@ class Oracle(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index f276af1..a092cad 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -11,9 +11,19 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
str_position_sql,
)
+from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
+DATE_DIFF_FACTOR = {
+ "MICROSECOND": " * 1000000",
+ "MILLISECOND": " * 1000",
+ "SECOND": "",
+ "MINUTE": " / 60",
+ "HOUR": " / 3600",
+ "DAY": " / 86400",
+}
+
def _date_add_sql(kind):
def func(self, expression):
@@ -34,16 +44,30 @@ def _date_add_sql(kind):
return func
-def _lateral_sql(self, expression):
- this = self.sql(expression, "this")
- if isinstance(expression.this, exp.Subquery):
- return f"LATERAL{self.sep()}{this}"
- alias = expression.args["alias"]
- table = alias.name
- table = f" {table}" if table else table
- columns = self.expressions(alias, key="columns", flat=True)
- columns = f" AS {columns}" if columns else ""
- return f"LATERAL{self.sep()}{this}{table}{columns}"
+def _date_diff_sql(self, expression):
+ unit = expression.text("unit").upper()
+ factor = DATE_DIFF_FACTOR.get(unit)
+
+ end = f"CAST({expression.this} AS TIMESTAMP)"
+ start = f"CAST({expression.expression} AS TIMESTAMP)"
+
+ if factor is not None:
+ return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)"
+
+ age = f"AGE({end}, {start})"
+
+ if unit == "WEEK":
+ extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
+ elif unit == "MONTH":
+ extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
+ elif unit == "QUARTER":
+ extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
+ elif unit == "YEAR":
+ extract = f"EXTRACT(year FROM {age})"
+ else:
+ self.unsupported(f"Unsupported DATEDIFF unit {unit}")
+
+ return f"CAST({extract} AS BIGINT)"
def _substring_sql(self, expression):
@@ -141,7 +165,7 @@ def _serial_to_generated(expression):
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
- if len(args) == 1 and args[0].is_number:
+ if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
# https://www.postgresql.org/docs/current/functions-formatting.html
@@ -211,11 +235,16 @@ class Postgres(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "~~": TokenType.LIKE,
+ "~~*": TokenType.ILIKE,
+ "~*": TokenType.IRLIKE,
+ "~": TokenType.RLIKE,
"ALWAYS": TokenType.ALWAYS,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT,
+ "CHARACTER VARYING": TokenType.VARCHAR,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
@@ -233,6 +262,7 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
+ "CSTRING": TokenType.PSEUDO_TYPE,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
@@ -244,17 +274,16 @@ class Postgres(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
- LATERAL_FUNCTION_AS_VIEW = True
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
@@ -264,7 +293,7 @@ class Postgres(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.ColumnDef: preprocess(
[
_auto_increment_to_serial,
@@ -274,13 +303,16 @@ class Postgres(Dialect):
),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
- exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
- exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}",
+ exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
+ exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
+ exp.JSONBContains: lambda self, e: self.binary(e, "?"),
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
- exp.Lateral: _lateral_sql,
+ exp.DateDiff: _date_diff_sql,
+ exp.RegexpLike: lambda self, e: self.binary(e, "~"),
+ exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
@@ -291,5 +323,7 @@ class Postgres(Dialect):
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
- exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
+ exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+ if isinstance(seq_get(e.expressions, 0), exp.Select)
+ else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
}
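[note] `_date_diff_sql` lowers DATEDIFF onto EXTRACT(epoch FROM ...) for the sub-month units in DATE_DIFF_FACTOR and onto AGE(...) otherwise. A hedged round-trip (T-SQL parses DATEDIFF per the tsql.py hunk below; the expected shape assumes the DAY factor):

    import sqlglot

    print(sqlglot.transpile("SELECT DATEDIFF(day, start_ts, end_ts)", read="tsql", write="postgres")[0])
    # expected shape:
    # SELECT CAST(EXTRACT(epoch FROM CAST(end_ts AS TIMESTAMP) - CAST(start_ts AS TIMESTAMP)) / 86400 AS BIGINT)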
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 1a09037..e16ea1d 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
struct_extract_sql,
+ timestrtotime_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
@@ -38,10 +39,6 @@ def _datatype_sql(self, expression):
return sql
-def _date_parse_sql(self, expression):
- return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')"
-
-
def _explode_to_unnest_sql(self, expression):
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql(
@@ -137,7 +134,7 @@ class Presto(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
@@ -174,7 +171,7 @@ class Presto(Dialect):
ROOT_PROPERTIES = {exp.SchemaCommentProperty}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@@ -184,7 +181,7 @@ class Presto(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
@@ -224,8 +221,8 @@ class Presto(Dialect):
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
- exp.TimeStrToDate: _date_parse_sql,
- exp.TimeStrToTime: _date_parse_sql,
+ exp.TimeStrToDate: timestrtotime_sql,
+ exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 55ed0a6..27dfb93 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -36,7 +36,6 @@ class Redshift(Postgres):
"TIMETZ": TokenType.TIMESTAMPTZ,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
- "SIMILAR TO": TokenType.SIMILAR_TO,
}
class Generator(Postgres.Generator):
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 75dc9dc..77b09e9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -3,13 +3,15 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ datestrtodate_sql,
format_time_lambda,
inline_array_sql,
rename_func,
+ timestrtotime_sql,
var_map_sql,
)
from sqlglot.expressions import Literal
-from sqlglot.helper import seq_get
+from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@@ -183,7 +185,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
- ESCAPES = ["\\"]
+ ESCAPES = ["\\", "'"]
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@@ -206,9 +208,10 @@ class Snowflake(Dialect):
CREATE_TRANSIENT = True
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
+ exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -218,13 +221,14 @@ class Snowflake(Dialect):
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql,
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@@ -246,3 +250,47 @@ class Snowflake(Dialect):
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
+
+ def values_sql(self, expression: exp.Values) -> str:
+ """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
+
+ We also want to make sure that after we find matches where we need to unquote a column that we prevent users
+ from adding quotes to the column by using the `identify` argument when generating the SQL.
+ """
+ alias = expression.args.get("alias")
+ if alias and alias.args.get("columns"):
+ expression = expression.transform(
+ lambda node: exp.Identifier(**{**node.args, "quoted": False})
+ if isinstance(node, exp.Identifier)
+ and isinstance(node.parent, exp.TableAlias)
+ and node.arg_key == "columns"
+ else node,
+ )
+ return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
+ return super().values_sql(expression)
+
+ def select_sql(self, expression: exp.Select) -> str:
+ """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
+ that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
+ to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
+ generating the SQL.
+
+ Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
+ expression. This might not be true in a case where the same column name can be sourced from another table that can
+ properly quote but should be true in most cases.
+ """
+ values_expressions = expression.find_all(exp.Values)
+ values_identifiers = set(
+ flatten(
+ v.args.get("alias", exp.Alias()).args.get("columns", [])
+ for v in values_expressions
+ )
+ )
+ if values_identifiers:
+ expression = expression.transform(
+ lambda node: exp.Identifier(**{**node.args, "quoted": False})
+ if isinstance(node, exp.Identifier) and node in values_identifiers
+ else node,
+ )
+ return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
+ return super().select_sql(expression)
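[note] Both overrides rely on Expression.transform to rewrite only the matching Identifier nodes. A minimal standalone sketch of that pattern:

    from sqlglot import exp, parse_one

    tree = parse_one('SELECT "a" FROM t')
    unquoted = tree.transform(
        lambda node: exp.Identifier(this=node.name, quoted=False)
        if isinstance(node, exp.Identifier)
        else node
    )
    print(unquoted.sql())  # SELECT a FROM t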
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 16083d1..7f05dea 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -76,7 +76,7 @@ class Spark(Hive):
}
FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS, # type: ignore
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -87,6 +87,16 @@ class Spark(Hive):
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
+ def _parse_add_column(self):
+ return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
+
+ def _parse_drop_column(self):
+ return self._match_text_seq("DROP", "COLUMNS") and self.expression(
+ exp.Drop,
+ this=self._parse_schema(),
+ kind="COLUMNS",
+ )
+
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index bbb752b..a0c4942 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -42,13 +42,13 @@ class SQLite(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
class Generator(generator.Generator):
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@@ -70,7 +70,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 3519c09..01e6357 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -8,7 +8,7 @@ from sqlglot.dialects.mysql import MySQL
class StarRocks(MySQL):
class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = {
- **MySQL.Generator.TYPE_MAPPING,
+ **MySQL.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 63e7275..36c085f 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -30,7 +30,7 @@ class Tableau(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index a552e7b..7f0f2d7 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -224,11 +224,7 @@ class TSQL(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
- QUOTES = [
- (prefix + quote, quote) if prefix else quote
- for quote in ["'", '"']
- for prefix in ["", "n", "N"]
- ]
+ QUOTES = ["'", '"']
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -253,7 +249,7 @@ class TSQL(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@@ -314,7 +310,7 @@ class TSQL(Dialect):
class Generator(generator.Generator):
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 758ad1b..fa8bc1b 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -1,3 +1,7 @@
+"""
+.. include:: ../posts/sql_diff.md
+"""
+
from __future__ import annotations
import typing as t
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index e9ff75b..8a58287 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -29,10 +29,10 @@ class Context:
self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
- self.env = {**(env or {}), "scope": self.row_readers}
+ self.env = {**ENV, **(env or {}), "scope": self.row_readers}
def eval(self, code):
- return eval(code, ENV, self.env)
+ return eval(code, self.env)
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index ad9397e..04dc938 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -127,14 +127,16 @@ def interval(this, unit):
ENV = {
"exp": exp,
# aggs
- "SUM": filter_nulls(sum),
+ "ARRAYAGG": list,
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
"COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
"MAX": filter_nulls(max),
"MIN": filter_nulls(min),
+ "SUM": filter_nulls(sum),
# scalar functions
"ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this),
+ "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 9f22c45..29848c6 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -394,6 +394,18 @@ def _case_sql(self, expression):
return chain
+def _lambda_sql(self, e: exp.Lambda) -> str:
+ names = {e.name.lower() for e in e.expressions}
+
+ e = e.transform(
+ lambda n: exp.Var(this=n.name)
+ if isinstance(n, exp.Identifier) and n.name.lower() in names
+ else n
+ )
+
+ return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
+
+
class Python(Dialect):
class Tokenizer(tokens.Tokenizer):
ESCAPES = ["\\"]
@@ -414,6 +426,7 @@ class Python(Dialect):
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"),
+ exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index aeed218..711ec4b 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1,6 +1,11 @@
+"""
+.. include:: ../pdoc/docs/expressions.md
+"""
+
from __future__ import annotations
import datetime
+import math
import numbers
import re
import typing as t
@@ -682,6 +687,10 @@ class CharacterSet(Expression):
class With(Expression):
arg_types = {"expressions": True, "recursive": False}
+ @property
+ def recursive(self) -> bool:
+ return bool(self.args.get("recursive"))
+
class WithinGroup(Expression):
arg_types = {"this": True, "expression": False}
@@ -724,6 +733,18 @@ class ColumnDef(Expression):
"this": True,
"kind": True,
"constraints": False,
+ "exists": False,
+ }
+
+
+class AlterColumn(Expression):
+ arg_types = {
+ "this": True,
+ "dtype": False,
+ "collate": False,
+ "using": False,
+ "default": False,
+ "drop": False,
}
@@ -877,6 +898,11 @@ class Introducer(Expression):
arg_types = {"this": True, "expression": True}
+# national char, like n'utf8'
+class National(Expression):
+ pass
+
+
class LoadData(Expression):
arg_types = {
"this": True,
@@ -894,7 +920,7 @@ class Partition(Expression):
class Fetch(Expression):
- arg_types = {"direction": False, "count": True}
+ arg_types = {"direction": False, "count": False}
class Group(Expression):
@@ -1316,7 +1342,7 @@ QUERY_MODIFIERS = {
"group": False,
"having": False,
"qualify": False,
- "window": False,
+ "windows": False,
"distribute": False,
"sort": False,
"cluster": False,
@@ -1353,7 +1379,7 @@ class Union(Subqueryable):
Example:
>>> select("1").union(select("1")).limit(1).sql()
- 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
+ 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
Args:
expression (str | int | Expression): the SQL code string to parse.
@@ -1889,6 +1915,18 @@ class Select(Subqueryable):
**opts,
)
+ def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+ return _apply_list_builder(
+ *expressions,
+ instance=self,
+ arg="windows",
+ append=append,
+ into=Window,
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
def distinct(self, distinct=True, copy=True) -> Select:
"""
Set the DISTINCT flag.
@@ -2140,6 +2178,11 @@ class DataType(Expression):
)
+# https://www.postgresql.org/docs/15/datatype-pseudo.html
+class PseudoType(Expression):
+ pass
+
+
class StructKwarg(Expression):
arg_types = {"this": True, "expression": True}
@@ -2167,18 +2210,26 @@ class Command(Expression):
arg_types = {"this": True, "expression": False}
-class Transaction(Command):
+class Transaction(Expression):
arg_types = {"this": False, "modes": False}
-class Commit(Command):
+class Commit(Expression):
arg_types = {"chain": False}
-class Rollback(Command):
+class Rollback(Expression):
arg_types = {"savepoint": False}
+class AlterTable(Expression):
+ arg_types = {
+ "this": True,
+ "actions": True,
+ "exists": False,
+ }
+
+
# Binary expressions like (ADD a b)
class Binary(Expression):
arg_types = {"this": True, "expression": True}
@@ -2312,6 +2363,10 @@ class SimilarTo(Binary, Predicate):
pass
+class Slice(Binary):
+ arg_types = {"this": False, "expression": False}
+
+
class Sub(Binary):
pass
@@ -2392,7 +2447,7 @@ class TimeUnit(Expression):
class Interval(TimeUnit):
- arg_types = {"this": True, "unit": False}
+ arg_types = {"this": False, "unit": False}
class IgnoreNulls(Expression):
@@ -2730,8 +2785,11 @@ class Initcap(Func):
pass
-class JSONExtract(Func):
- arg_types = {"this": True, "path": True}
+class JSONBContains(Binary):
+ _sql_names = ["JSONB_CONTAINS"]
+
+
+class JSONExtract(Binary, Func):
_sql_names = ["JSON_EXTRACT"]
@@ -2776,6 +2834,10 @@ class Log10(Func):
pass
+class LogicalOr(AggFunc):
+ _sql_names = ["LOGICAL_OR", "BOOL_OR"]
+
+
class Lower(Func):
_sql_names = ["LOWER", "LCASE"]
@@ -2846,6 +2908,10 @@ class RegexpLike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
+class RegexpILike(Func):
+ arg_types = {"this": True, "expression": True, "flag": False}
+
+
class RegexpSplit(Func):
arg_types = {"this": True, "expression": True}
@@ -3388,11 +3454,17 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U
],
)
if from_:
- update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
+ update.set(
+ "from",
+ maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts),
+ )
if isinstance(where, Condition):
where = Where(this=where)
if where:
- update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
+ update.set(
+ "where",
+ maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
+ )
return update
@@ -3522,7 +3594,7 @@ def paren(expression) -> Paren:
return Paren(this=expression)
-SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
+SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
@@ -3724,6 +3796,8 @@ def convert(value) -> Expression:
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
+ if isinstance(value, float) and math.isnan(value):
+ return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, tuple):
@@ -3732,11 +3806,13 @@ def convert(value) -> Expression:
return Array(expressions=[convert(v) for v in value])
if isinstance(value, dict):
return Map(
- keys=[convert(k) for k in value.keys()],
+ keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
if isinstance(value, datetime.datetime):
- datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z"))
+ datetime_literal = Literal.string(
+ (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
+ )
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
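[note] Two behavioral notes on `convert`: NaN floats now map to NULL rather than an invalid number literal, and naive datetimes are assumed UTC and serialized via isoformat(). A quick check:

    import datetime
    from sqlglot import exp

    print(exp.convert(float("nan")).sql())                    # NULL
    print(exp.convert(datetime.datetime(2022, 1, 1)).sql())
    # TIME_STR_TO_TIME('2022-01-01T00:00:00+00:00')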
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 2b4c575..0c1578a 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -361,10 +361,11 @@ class Generator:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
+ exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
if not constraints:
- return f"{column} {kind}"
- return f"{column} {kind} {constraints}"
+ return f"{exists}{column} {kind}"
+ return f"{exists}{column} {kind} {constraints}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
@@ -549,6 +550,9 @@ class Generator:
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
+ def national_sql(self, expression: exp.National) -> str:
+ return f"N{self.sql(expression, 'this')}"
+
def partition_sql(self, expression: exp.Partition) -> str:
keys = csv(
*[
@@ -633,6 +637,9 @@ class Generator:
def introducer_sql(self, expression: exp.Introducer) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
+ def pseudotype_sql(self, expression: exp.PseudoType) -> str:
+ return expression.name.upper()
+
def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
fields = expression.args.get("fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
@@ -793,19 +800,17 @@ class Generator:
if isinstance(expression.this, exp.Subquery):
return f"LATERAL {this}"
- alias = expression.args["alias"]
- table = alias.name
- columns = self.expressions(alias, key="columns", flat=True)
-
if expression.args.get("view"):
- table = f" {table}" if table else table
+ alias = expression.args["alias"]
+ columns = self.expressions(alias, key="columns", flat=True)
+ table = f" {alias.name}" if alias.name else ""
columns = f" AS {columns}" if columns else ""
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
return f"{op_sql}{self.sep()}{this}{table}{columns}"
- table = f" AS {table}" if table else table
- columns = f"({columns})" if columns else ""
- return f"LATERAL {this}{table}{columns}"
+ alias = self.sql(expression, "alias")
+ alias = f" AS {alias}" if alias else ""
+ return f"LATERAL {this}{alias}"
def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
@@ -891,13 +896,15 @@ class Generator:
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
- *[self.sql(sql) for sql in expression.args.get("joins", [])],
- *[self.sql(sql) for sql in expression.args.get("laterals", [])],
+ *[self.sql(sql) for sql in expression.args.get("joins") or []],
+ *[self.sql(sql) for sql in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
self.sql(expression, "qualify"),
- self.sql(expression, "window"),
+ self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
+ if expression.args.get("windows")
+ else "",
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
@@ -1008,11 +1015,7 @@ class Generator:
spec_sql = " " + self.window_spec_sql(spec) if spec else ""
alias = self.sql(expression, "alias")
-
- if expression.arg_key == "window":
- this = this = f"{self.seg('WINDOW')} {this} AS"
- else:
- this = f"{this} OVER"
+ this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
if not partition and not order and not spec and alias:
return f"{this} {alias}"
@@ -1141,9 +1144,11 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
unit = self.sql(expression, "unit")
unit = f" {unit}" if unit else ""
- return f"INTERVAL {self.sql(expression, 'this')}{unit}"
+ return f"INTERVAL{this}{unit}"
def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this")
@@ -1245,6 +1250,43 @@ class Generator:
savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}"
+ def altercolumn_sql(self, expression: exp.AlterColumn) -> str:
+ this = self.sql(expression, "this")
+
+ dtype = self.sql(expression, "dtype")
+ if dtype:
+ collate = self.sql(expression, "collate")
+ collate = f" COLLATE {collate}" if collate else ""
+ using = self.sql(expression, "using")
+ using = f" USING {using}" if using else ""
+ return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}"
+
+ default = self.sql(expression, "default")
+ if default:
+ return f"ALTER COLUMN {this} SET DEFAULT {default}"
+
+ if not expression.args.get("drop"):
+ self.unsupported("Unsupported ALTER COLUMN syntax")
+
+ return f"ALTER COLUMN {this} DROP DEFAULT"
+
+ def altertable_sql(self, expression: exp.AlterTable) -> str:
+ actions = expression.args["actions"]
+
+ if isinstance(actions[0], exp.ColumnDef):
+ actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
+ elif isinstance(actions[0], exp.Schema):
+ actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
+ elif isinstance(actions[0], exp.Drop):
+ actions = self.expressions(expression, "actions")
+ elif isinstance(actions[0], exp.AlterColumn):
+ actions = self.sql(actions[0])
+ else:
+ self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
+
+ exists = " IF EXISTS" if expression.args.get("exists") else ""
+ return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
+
def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@@ -1327,6 +1369,9 @@ class Generator:
def or_sql(self, expression: exp.Or) -> str:
return self.connector_sql(expression, "OR")
+ def slice_sql(self, expression: exp.Slice) -> str:
+ return self.binary(expression, ":")
+
def sub_sql(self, expression: exp.Sub) -> str:
return self.binary(expression, "-")
@@ -1369,6 +1414,7 @@ class Generator:
flat: bool = False,
indent: bool = True,
sep: str = ", ",
+ prefix: str = "",
) -> str:
expressions = expression.args.get(key or "expressions")
@@ -1391,11 +1437,13 @@ class Generator:
if self.pretty:
if self._leading_comma:
- result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
+ result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
else:
- result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
+ result_sqls.append(
+ f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}"
+ )
else:
- result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
+ result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sql, skip_first=False) if indent else result_sql
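[note] `altertable_sql` pairs with the ALTER TABLE parsing added in parser.py (see the diffstat). A hedged round-trip, assuming the parser accepts the postgres-style ALTER COLUMN ... TYPE form generated here:

    import sqlglot

    print(sqlglot.transpile("ALTER TABLE IF EXISTS t ALTER COLUMN c TYPE BIGINT")[0])
    # expected: ALTER TABLE IF EXISTS t ALTER COLUMN c TYPE BIGINT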
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index 33529a5..fc37a54 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
+ if isinstance(expression, exp.Identifier):
+ expression.set("quoted", True)
+
return expression
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index de4e011..3b40710 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -129,10 +129,23 @@ def join_condition(join):
"""
name = join.this.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
- on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
+ def extract_condition(condition):
+ left, right = condition.unnest_operands()
+ left_tables = exp.column_table_names(left)
+ right_tables = exp.column_table_names(right)
+
+ if name in left_tables and name not in right_tables:
+ join_key.append(left)
+ source_key.append(right)
+ condition.replace(exp.true())
+ elif name in right_tables and name not in left_tables:
+ join_key.append(right)
+ source_key.append(left)
+ condition.replace(exp.true())
+
# find the join keys
# SELECT
# FROM x
@@ -141,20 +154,30 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
+ on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
+
for condition in on.flatten():
if isinstance(condition, exp.EQ):
- left, right = condition.unnest_operands()
- left_tables = exp.column_table_names(left)
- right_tables = exp.column_table_names(right)
-
- if name in left_tables and name not in right_tables:
- join_key.append(left)
- source_key.append(right)
- condition.replace(exp.true())
- elif name in right_tables and name not in left_tables:
- join_key.append(right)
- source_key.append(left)
- condition.replace(exp.true())
+ extract_condition(condition)
+ elif normalized(on, dnf=True):
+ conditions = None
+
+ for condition in on.flatten():
+ parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
+ if conditions is None:
+ conditions = parts
+ else:
+ temp = []
+ for p in parts:
+ cs = [c for c in conditions if p == c]
+
+ if cs:
+ temp.append(p)
+ temp.extend(cs)
+ conditions = temp
+
+ for condition in conditions:
+ extract_condition(condition)
on = simplify(on)
remaining_condition = None if on == exp.true() else on
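[note] The new DNF branch extracts a join key only when the same equality appears in every OR-branch of the condition. A sketch (the exact return layout of join_condition follows the body above):

    from sqlglot import parse_one
    from sqlglot.optimizer.eliminate_joins import join_condition

    join = parse_one(
        "SELECT x.a FROM x JOIN y ON (x.b = y.b AND y.c = 1) OR (x.b = y.b AND y.c = 2)"
    ).args["joins"][0]

    # x.b = y.b occurs in both OR-branches, so y.b / x.b are extracted as the
    # join and source keys; y.c = 1 OR y.c = 2 remains as the residual condition
    keys = join_condition(join)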
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 39e252c..2245cc2 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -58,7 +58,9 @@ def eliminate_subqueries(expression):
existing_ctes = {}
with_ = root.expression.args.get("with")
+ recursive = False
if with_:
+ recursive = with_.args.get("recursive")
for cte in with_.expressions:
existing_ctes[cte.this] = cte.alias
new_ctes = []
@@ -88,7 +90,7 @@ def eliminate_subqueries(expression):
new_ctes.append(new_cte)
if new_ctes:
- expression.set("with", exp.With(expressions=new_ctes))
+ expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
return expression
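
Previously, building a fresh exp.With dropped the RECURSIVE flag of any pre-existing WITH clause; it is now carried over. A sketch:

    import sqlglot
    from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

    sql = (
        "WITH RECURSIVE t AS ("
        "SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 3"
        ") SELECT * FROM t"
    )
    # the output should still start with WITH RECURSIVE
    print(eliminate_subqueries(sqlglot.parse_one(sql)).sql())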
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index db538ef..f16f519 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf):
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
- x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
- return x
+ return [
+ a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
+ ]
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 6819717..72e67d4 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
-from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
@@ -34,7 +33,6 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
- quote_identities,
)
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index f92e5c3..ba5c8b5 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -27,7 +27,14 @@ def pushdown_predicates(expression):
select = scope.expression
where = select.args.get("where")
if where:
- pushdown(where.this, scope.selected_sources, scope_ref_count)
+ selected_sources = scope.selected_sources
+ # a right join can only push predicates down to itself, not to the source FROM table
+ for k, (node, source) in selected_sources.items():
+ parent = node.find_ancestor(exp.Join, exp.From)
+ if isinstance(parent, exp.Join) and parent.side == "RIGHT":
+ selected_sources = {k: (node, source)}
+ break
+ pushdown(where.this, selected_sources, scope_ref_count)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
# a node can reference a CTE which should be pushed down
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
+ with_ = source.parent.expression.args.get("with")
+ if with_ and with_.recursive:
+ return {}
node = source.expression
if isinstance(node, exp.Join):
- if node.side:
+ if node.side and node.side != "RIGHT":
return {}
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
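
Two guards are added here: a WHERE predicate may now push into the right side of a RIGHT JOIN (previously any sided join blocked it), but pushdown is then restricted to that source alone, and recursive CTEs are never pushed into. A sketch:

    import sqlglot
    from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

    # The filter on y may move into the right-hand subquery, but nothing is
    # allowed to leak into the preserved FROM table x.
    sql = "SELECT * FROM x RIGHT JOIN (SELECT * FROM y) AS y ON x.a = y.a WHERE y.b = 1"
    print(pushdown_predicates(sqlglot.parse_one(sql)).sql())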
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index abd9492..49789ac 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
-# SELECTION TO USE IF SELECTION LIST IS EMPTY
+# Selection to use if selection list is empty
DEFAULT_SELECTION = alias("1", "_")
@@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections.append(DEFAULT_SELECTION)
+ new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
return removed_indexes
@@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
- new_selections.append(DEFAULT_SELECTION)
+ new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
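
DEFAULT_SELECTION is a module-level AST node; inserting the same instance into several query trees would alias them, so a later mutation of one tree could corrupt another. Copying on every use gives each tree its own node. A minimal illustration, mirroring the module's alias helper with exp.alias_:

    from sqlglot import exp

    DEFAULT_SELECTION = exp.alias_("1", "_")
    a = DEFAULT_SELECTION.copy()
    b = DEFAULT_SELECTION.copy()
    # distinct objects: mutating one cannot affect the other
    assert a is not b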
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index e6e6dc9..e16a635 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -311,6 +311,9 @@ def _qualify_outputs(scope):
alias_ = alias(exp.column(""), alias=selection.name)
alias_.set("this", selection)
selection = alias_
+ elif isinstance(selection, exp.Subquery):
+ if not selection.alias:
+ selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias):
alias_ = alias(exp.column(""), f"_col_{i}")
alias_.set("this", selection)
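
Unaliased scalar subqueries in the projection list now receive a table alias of the same _col_i shape as other unaliased selections, so downstream rules can reference them. A sketch:

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns

    sql = "SELECT (SELECT MAX(b) FROM y), a FROM x"
    schema = {"x": {"a": "int"}, "y": {"b": "int"}}
    # the subquery should come back aliased, e.g. (...) AS _col_0
    print(qualify_columns(sqlglot.parse_one(sql), schema).sql())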
diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py
deleted file mode 100644
index 17623cc..0000000
--- a/sqlglot/optimizer/quote_identities.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from sqlglot import exp
-
-
-def quote_identities(expression):
- """
- Rewrite sqlglot AST to ensure all identities are quoted.
-
- Example:
- >>> import sqlglot
- >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
- >>> quote_identities(expression).sql()
- 'SELECT "x"."a" AS "a" FROM "db"."x"'
-
- Args:
- expression (sqlglot.Expression): expression to quote
- Returns:
- sqlglot.Expression: quoted expression
- """
-
- def qualify(node):
- if isinstance(node, exp.Identifier):
- node.set("quoted", True)
- return node
-
- return expression.transform(qualify, copy=False)
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 18848f3..6125e4e 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -511,9 +511,20 @@ def _traverse_union(scope):
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
+ is_cte = scope_type == ScopeType.CTE
for derived_table in derived_tables:
- top = None
+ recursive_scope = None
+
+ # If the scope is a recursive CTE, it must have the form
+ # base_case UNION recursive_part, so the recursive scope is the first
+ # section of the union.
+ if is_cte and scope.expression.args["with"].recursive:
+ union = derived_table.this
+
+ if isinstance(union, exp.Union):
+ recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
+
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
@@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
- top = child_scope
+
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins).
- sources[derived_table.alias] = child_scope
- if scope_type == ScopeType.CTE:
- scope.cte_scopes.append(top)
+ alias = derived_table.alias
+ sources[alias] = child_scope
+
+ if recursive_scope:
+ child_scope.add_source(alias, recursive_scope)
+
+ # append the final child_scope yielded by the traversal
+ if is_cte:
+ scope.cte_scopes.append(child_scope)
else:
- scope.derived_table_scopes.append(top)
+ scope.derived_table_scopes.append(child_scope)
+
scope.sources.update(sources)
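
For WITH RECURSIVE t AS (base UNION recursive_part), the recursive part refers to t before t is fully defined; the traversal now branches a scope off the base case and registers it as the source for that self-reference. A sketch:

    import sqlglot
    from sqlglot.optimizer.scope import traverse_scope

    sql = (
        "WITH RECURSIVE t AS ("
        "SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 3"
        ") SELECT n FROM t"
    )
    for scope in traverse_scope(sqlglot.parse_one(sql)):
        print(scope.expression.sql())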
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 2046917..8d78294 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -16,7 +16,7 @@ def unnest_subqueries(expression):
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
- AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
+ AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
Args:
expression (sqlglot.Expression): expression to unnest
@@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
table_alias = _alias(sequence)
keys = []
- # for all external columns in the where statement,
- # split out the relevant data to convert it into a join
+ # for all external columns in the where statement, find the relevant predicate
+ # keys to convert it into a join
for column in external_columns:
if column.find_ancestor(exp.Where) is not where:
return
@@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence):
if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
return
+ is_subquery_projection = any(
+ node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
+ )
+
value = select.selects[0]
key_aliases = {}
group_by = []
@@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence):
parent_predicate = select.find_ancestor(exp.Predicate)
# if the value of the subquery is not an agg or a key, we need to collect it into an array
- # so that it can be grouped
+ # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
+ agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
if not value.find(exp.AggFunc) and value.this not in group_by:
- select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
+ select.select(
+ exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
+ append=False,
+ copy=False,
+ )
# exists queries should not have any selects as it only checks if there are any rows
# all selects will be added by the optimizer and only used for join keys
@@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
if isinstance(parent_predicate, exp.Exists) or key != value.this:
select.select(f"{key} AS {alias}", copy=False)
else:
- select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
+ select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate)
@@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence):
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
)
else:
+ if is_subquery_projection:
+ alias = exp.alias_(alias, select.parent.alias)
select.parent.replace(alias)
for key, column, predicate in keys:
predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
+ if is_subquery_projection:
+ key.replace(nested)
+ continue
+
if key in group_by:
key.replace(nested)
parent_predicate = _replace(
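
Correlated scalar subqueries appearing as projections are now decorrelated with MAX(...) rather than ARRAY_AGG(...), since a projection must yield a single value per join key, and the alias of the original subquery is preserved on the rewritten column. A sketch, following the aliasing style of the docstring above:

    import sqlglot
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT (SELECT y.b AS b FROM y AS y WHERE x.a = y.a) AS b FROM x AS x"
    # expected shape: a LEFT JOIN on a grouped subquery selecting MAX(y.b)
    print(unnest_subqueries(sqlglot.parse_one(sql)).sql())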
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 29bc9c0..308f363 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -5,7 +5,7 @@ import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
-from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
+from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@@ -117,6 +117,7 @@ class Parser(metaclass=_Parser):
TokenType.GEOMETRY,
TokenType.HLLSKETCH,
TokenType.HSTORE,
+ TokenType.PSEUDO_TYPE,
TokenType.SUPER,
TokenType.SERIAL,
TokenType.SMALLSERIAL,
@@ -153,6 +154,7 @@ class Parser(metaclass=_Parser):
TokenType.CACHE,
TokenType.CASCADE,
TokenType.COLLATE,
+ TokenType.COLUMN,
TokenType.COMMAND,
TokenType.COMMIT,
TokenType.COMPOUND,
@@ -169,6 +171,7 @@ class Parser(metaclass=_Parser):
TokenType.ESCAPE,
TokenType.FALSE,
TokenType.FIRST,
+ TokenType.FILTER,
TokenType.FOLLOWING,
TokenType.FORMAT,
TokenType.FUNCTION,
@@ -188,6 +191,7 @@ class Parser(metaclass=_Parser):
TokenType.MERGE,
TokenType.NATURAL,
TokenType.NEXT,
+ TokenType.OFFSET,
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
@@ -222,12 +226,18 @@ class Parser(metaclass=_Parser):
TokenType.PROPERTIES,
TokenType.PROCEDURE,
TokenType.VOLATILE,
+ TokenType.WINDOW,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
*NO_PAREN_FUNCTIONS,
}
- TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
+ TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
+ TokenType.APPLY,
+ TokenType.NATURAL,
+ TokenType.OFFSET,
+ TokenType.WINDOW,
+ }
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
@@ -257,6 +267,7 @@ class Parser(metaclass=_Parser):
TokenType.TABLE,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
+ TokenType.WINDOW,
*TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@@ -351,22 +362,27 @@ class Parser(metaclass=_Parser):
TokenType.ARROW: lambda self, this, path: self.expression(
exp.JSONExtract,
this=this,
- path=path,
+ expression=path,
),
TokenType.DARROW: lambda self, this, path: self.expression(
exp.JSONExtractScalar,
this=this,
- path=path,
+ expression=path,
),
TokenType.HASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtract,
this=this,
- path=path,
+ expression=path,
),
TokenType.DHASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtractScalar,
this=this,
- path=path,
+ expression=path,
+ ),
+ TokenType.PLACEHOLDER: lambda self, this, key: self.expression(
+ exp.JSONBContains,
+ this=this,
+ expression=key,
),
}
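
The JSON extraction operators now store their path under the conventional "expression" arg instead of "path", and per the PLACEHOLDER entry a bare ? after an expression parses as JSONB containment. A sketch with Postgres-flavored input:

    import sqlglot

    # -> should produce a JSONExtract whose args are (this, expression)
    print(repr(sqlglot.parse_one("SELECT x -> 'a' FROM t", read="postgres")))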
@@ -392,25 +408,27 @@ class Parser(metaclass=_Parser):
exp.Ordered: lambda self: self._parse_ordered(),
exp.Having: lambda self: self._parse_having(),
exp.With: lambda self: self._parse_with(),
+ exp.Window: lambda self: self._parse_named_window(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
STATEMENT_PARSERS = {
+ TokenType.ALTER: lambda self: self._parse_alter(),
+ TokenType.BEGIN: lambda self: self._parse_transaction(),
+ TokenType.CACHE: lambda self: self._parse_cache(),
+ TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.CREATE: lambda self: self._parse_create(),
+ TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
+ TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
- TokenType.UPDATE: lambda self: self._parse_update(),
- TokenType.DELETE: lambda self: self._parse_delete(),
- TokenType.CACHE: lambda self: self._parse_cache(),
+ TokenType.MERGE: lambda self: self._parse_merge(),
+ TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
+ TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
- TokenType.BEGIN: lambda self: self._parse_transaction(),
- TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
- TokenType.END: lambda self: self._parse_commit_or_rollback(),
- TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
- TokenType.MERGE: lambda self: self._parse_merge(),
}
UNARY_PARSERS = {
@@ -441,6 +459,7 @@ class Parser(metaclass=_Parser):
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
+ TokenType.NATIONAL: lambda self, token: self._parse_national(token),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
@@ -454,6 +473,9 @@ class Parser(metaclass=_Parser):
TokenType.ILIKE: lambda self, this: self._parse_escape(
self.expression(exp.ILike, this=this, expression=self._parse_bitwise())
),
+ TokenType.IRLIKE: lambda self, this: self.expression(
+ exp.RegexpILike, this=this, expression=self._parse_bitwise()
+ ),
TokenType.RLIKE: lambda self, this: self.expression(
exp.RegexpLike, this=this, expression=self._parse_bitwise()
),
@@ -535,8 +557,7 @@ class Parser(metaclass=_Parser):
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
- "window": lambda self: self._match(TokenType.WINDOW)
- and self._parse_window(self._parse_id_var(), alias=True),
+ "windows": lambda self: self._parse_window_clause(),
"distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
"sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
"cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
@@ -551,18 +572,18 @@ class Parser(metaclass=_Parser):
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
CREATABLES = {
- TokenType.TABLE,
- TokenType.VIEW,
+ TokenType.COLUMN,
TokenType.FUNCTION,
TokenType.INDEX,
TokenType.PROCEDURE,
TokenType.SCHEMA,
+ TokenType.TABLE,
+ TokenType.VIEW,
}
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
STRICT_CAST = True
- LATERAL_FUNCTION_AS_VIEW = False
__slots__ = (
"error_level",
@@ -782,13 +803,16 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(expression)
return expression
- def _parse_drop(self):
+ def _parse_drop(self, default_kind=None):
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
- self.raise_error(f"Expected {self.CREATABLES}")
- return
+ if default_kind:
+ kind = default_kind
+ else:
+ self.raise_error(f"Expected {self.CREATABLES}")
+ return
return self.expression(
exp.Drop,
@@ -876,7 +900,7 @@ class Parser(metaclass=_Parser):
) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
if assignment:
- key = self._parse_var() or self._parse_string()
+ key = self._parse_var_or_string()
self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column())
@@ -1152,18 +1176,32 @@ class Parser(metaclass=_Parser):
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
self._parse_query_modifiers(this)
+ this = self._parse_set_operations(this)
self._match_r_paren()
- this = self._parse_subquery(this)
+ # Return early so that subquery unions aren't parsed again. In
+ # SELECT * FROM (SELECT 1) UNION ALL SELECT 1, the UNION ALL belongs
+ # to the top-level select node, not to the subquery.
+ return self._parse_subquery(this)
elif self._match(TokenType.VALUES):
+ if self._curr.token_type == TokenType.L_PAREN:
+ # We don't consume the left paren because it's consumed in _parse_value
+ expressions = self._parse_csv(self._parse_value)
+ else:
+ # In Presto, VALUES 1, 2 produces one column and two rows.
+ # Source: https://prestodb.io/docs/current/sql/values.html
+ expressions = self._parse_csv(
+ lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
+ )
+
this = self.expression(
exp.Values,
- expressions=self._parse_csv(self._parse_value),
+ expressions=expressions,
alias=self._parse_table_alias(),
)
else:
this = None
- return self._parse_set_operations(this) if this else None
+ return self._parse_set_operations(this)
def _parse_with(self, skip_with_token=False):
if not skip_with_token and not self._match(TokenType.WITH):
@@ -1201,11 +1239,12 @@ class Parser(metaclass=_Parser):
alias = self._parse_id_var(
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
)
- columns = None
if self._match(TokenType.L_PAREN):
- columns = self._parse_csv(lambda: self._parse_id_var(any_token))
+ columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var()))
self._match_r_paren()
+ else:
+ columns = None
if not alias and not columns:
return None
@@ -1295,26 +1334,19 @@ class Parser(metaclass=_Parser):
expression=self._parse_function() or self._parse_id_var(any_token=False),
)
- columns = None
- table_alias = None
- if view or self.LATERAL_FUNCTION_AS_VIEW:
- table_alias = self._parse_id_var(any_token=False)
- if self._match(TokenType.ALIAS):
- columns = self._parse_csv(self._parse_id_var)
+ if view:
+ table = self._parse_id_var(any_token=False)
+ columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else []
+ table_alias = self.expression(exp.TableAlias, this=table, columns=columns)
else:
- self._match(TokenType.ALIAS)
- table_alias = self._parse_id_var(any_token=False)
-
- if self._match(TokenType.L_PAREN):
- columns = self._parse_csv(self._parse_id_var)
- self._match_r_paren()
+ table_alias = self._parse_table_alias()
expression = self.expression(
exp.Lateral,
this=this,
view=view,
outer=outer,
- alias=self.expression(exp.TableAlias, this=table_alias, columns=columns),
+ alias=table_alias,
)
if outer_apply or cross_apply:
@@ -1693,6 +1725,9 @@ class Parser(metaclass=_Parser):
if negate:
this = self.expression(exp.Not, this=this)
+ if self._match(TokenType.IS):
+ this = self._parse_is(this)
+
return this
def _parse_is(self, this):
@@ -1796,6 +1831,10 @@ class Parser(metaclass=_Parser):
return None
type_token = self._prev.token_type
+
+ if type_token == TokenType.PSEUDO_TYPE:
+ return self.expression(exp.PseudoType, this=self._prev.text)
+
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token == TokenType.STRUCT
expressions = None
@@ -1851,6 +1890,8 @@ class Parser(metaclass=_Parser):
if value is None:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
+ elif type_token == TokenType.INTERVAL:
+ value = self.expression(exp.Interval, unit=self._parse_var())
if maybe_func and check_func:
index2 = self._index
@@ -1924,7 +1965,16 @@ class Parser(metaclass=_Parser):
def _parse_primary(self):
if self._match_set(self.PRIMARY_PARSERS):
- return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
+ token_type = self._prev.token_type
+ primary = self.PRIMARY_PARSERS[token_type](self, self._prev)
+
+ if token_type == TokenType.STRING:
+ expressions = [primary]
+ while self._match(TokenType.STRING):
+ expressions.append(exp.Literal.string(self._prev.text))
+ if len(expressions) > 1:
+ return self.expression(exp.Concat, expressions=expressions)
+ return primary
if self._match_pair(TokenType.DOT, TokenType.NUMBER):
return exp.Literal.number(f"0.{self._prev.text}")
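
Adjacent string literals now fold into a single CONCAT node, the SQL-standard implicit concatenation. A sketch:

    import sqlglot

    # 'foo' 'bar' should parse and render as CONCAT('foo', 'bar')
    print(sqlglot.parse_one("SELECT 'foo' 'bar'").sql())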
@@ -2027,6 +2077,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=token.text)
+ def _parse_national(self, token):
+ return self.expression(exp.National, this=exp.Literal.string(token.text))
+
def _parse_session_parameter(self):
kind = None
this = self._parse_id_var() or self._parse_primary()
@@ -2051,7 +2104,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_id_var)
- self._match(TokenType.R_PAREN)
+
+ if not self._match(TokenType.R_PAREN):
+ self._retreat(index)
else:
expressions = [self._parse_id_var()]
@@ -2065,14 +2120,14 @@ class Parser(metaclass=_Parser):
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
)
else:
- this = self._parse_conjunction()
+ this = self._parse_select_or_expression()
if self._match(TokenType.IGNORE_NULLS):
this = self.expression(exp.IgnoreNulls, this=this)
else:
self._match(TokenType.RESPECT_NULLS)
- return self._parse_alias(self._parse_limit(self._parse_order(this)))
+ return self._parse_limit(self._parse_order(this))
def _parse_schema(self, this=None):
index = self._index
@@ -2081,7 +2136,8 @@ class Parser(metaclass=_Parser):
return this
args = self._parse_csv(
- lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))
+ lambda: self._parse_constraint()
+ or self._parse_column_def(self._parse_field(any_token=True))
)
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@@ -2120,7 +2176,7 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.ENCODE):
kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
- kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
+ kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise())
elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.NULL):
@@ -2211,7 +2267,10 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_BRACKET):
return this
- expressions = self._parse_csv(self._parse_conjunction)
+ if self._match(TokenType.COLON):
+ expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
+ else:
+ expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))
if not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
@@ -2225,6 +2284,11 @@ class Parser(metaclass=_Parser):
this.comments = self._prev_comments
return self._parse_bracket(this)
+ def _parse_slice(self, this):
+ if self._match(TokenType.COLON):
+ return self.expression(exp.Slice, this=this, expression=self._parse_conjunction())
+ return this
+
def _parse_case(self):
ifs = []
default = None
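
Brackets now understand slice syntax: a leading colon yields an open-ended exp.Slice, and a colon between two expressions becomes a bounded one. A sketch showing both parse paths:

    import sqlglot

    # both should parse into Slice nodes inside the bracket expression
    print(repr(sqlglot.parse_one("SELECT x[1:3] FROM t")))
    print(repr(sqlglot.parse_one("SELECT x[:3] FROM t")))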
@@ -2386,6 +2450,12 @@ class Parser(metaclass=_Parser):
collation=collation,
)
+ def _parse_window_clause(self):
+ return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)
+
+ def _parse_named_window(self):
+ return self._parse_window(self._parse_id_var(), alias=True)
+
def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER):
where = self._parse_wrapped(self._parse_where)
@@ -2501,11 +2571,9 @@ class Parser(metaclass=_Parser):
if identifier:
return identifier
- if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
- self._advance()
- elif not self._match_set(tokens or self.ID_VAR_TOKENS):
- return None
- return exp.Identifier(this=self._prev.text, quoted=False)
+ if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
+ return exp.Identifier(this=self._prev.text, quoted=False)
+ return None
def _parse_string(self):
if self._match(TokenType.STRING):
@@ -2522,11 +2590,17 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
return self._parse_placeholder()
- def _parse_var(self):
- if self._match(TokenType.VAR):
+ def _parse_var(self, any_token=False):
+ if (any_token and self._advance_any()) or self._match(TokenType.VAR):
return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder()
+ def _advance_any(self):
+ if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
+ self._advance()
+ return self._prev
+ return None
+
def _parse_var_or_string(self):
return self._parse_var() or self._parse_string()
@@ -2551,8 +2625,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.PLACEHOLDER):
return self.expression(exp.Placeholder)
elif self._match(TokenType.COLON):
- self._advance()
- return self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match_set((TokenType.NUMBER, TokenType.VAR)):
+ return self.expression(exp.Placeholder, this=self._prev.text)
+ self._advance(-1)
return None
def _parse_except(self):
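
Colon placeholders are now only recognized when a number or identifier follows; otherwise the parser backs up one token so a lone colon (for example inside a slice) is left for other rules. A sketch:

    import sqlglot

    # :name parses as a named placeholder
    print(repr(sqlglot.parse_one("SELECT :name")))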
@@ -2647,6 +2722,54 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit, chain=chain)
+ def _parse_add_column(self):
+ if not self._match_text_seq("ADD"):
+ return None
+
+ self._match(TokenType.COLUMN)
+ exists_column = self._parse_exists(not_=True)
+ expression = self._parse_column_def(self._parse_field(any_token=True))
+ expression.set("exists", exists_column)
+ return expression
+
+ def _parse_drop_column(self):
+ return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
+
+ def _parse_alter(self):
+ if not self._match(TokenType.TABLE):
+ return None
+
+ exists = self._parse_exists()
+ this = self._parse_table(schema=True)
+
+ actions = None
+ if self._match_text_seq("ADD", advance=False):
+ actions = self._parse_csv(self._parse_add_column)
+ elif self._match_text_seq("DROP", advance=False):
+ actions = self._parse_csv(self._parse_drop_column)
+ elif self._match_text_seq("ALTER"):
+ self._match(TokenType.COLUMN)
+ column = self._parse_field(any_token=True)
+
+ if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
+ actions = self.expression(exp.AlterColumn, this=column, drop=True)
+ elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
+ actions = self.expression(
+ exp.AlterColumn, this=column, default=self._parse_conjunction()
+ )
+ else:
+ self._match_text_seq("SET", "DATA")
+ actions = self.expression(
+ exp.AlterColumn,
+ this=column,
+ dtype=self._match_text_seq("TYPE") and self._parse_types(),
+ collate=self._match(TokenType.COLLATE) and self._parse_term(),
+ using=self._match(TokenType.USING) and self._parse_conjunction(),
+ )
+
+ actions = ensure_list(actions)
+ return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)
+
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
if parser:
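
ALTER TABLE is promoted from an opaque command to a structured exp.AlterTable with typed actions (ADD COLUMN, DROP, and the ALTER COLUMN variants); other ALTER statements keep falling through to COMMAND via the tokenizer entries below. A sketch:

    import sqlglot

    print(repr(sqlglot.parse_one("ALTER TABLE t ADD COLUMN c INT")))
    # should round-trip through the structured node
    print(sqlglot.parse_one("ALTER TABLE t ALTER COLUMN c SET DATA TYPE TEXT").sql())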
@@ -2782,7 +2905,7 @@ class Parser(metaclass=_Parser):
return True
return False
- def _match_text_seq(self, *texts):
+ def _match_text_seq(self, *texts, advance=True):
index = self._index
for text in texts:
if self._curr and self._curr.text.upper() == text:
@@ -2790,6 +2913,10 @@ class Parser(metaclass=_Parser):
else:
self._retreat(index)
return False
+
+ if not advance:
+ self._retreat(index)
+
return True
def _replace_columns_with_dots(self, this):
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index c223ee0..d9a4004 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -160,9 +160,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
super().__init__(schema)
self.visible = visible or {}
self.dialect = dialect
- self._type_mapping_cache: t.Dict[str, exp.DataType] = {
- "STR": exp.DataType.build("text"),
- }
+ self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
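
The special-case "STR" mapping moves out of MappingSchema and into the tokenizer (see the tokens.py hunk below), so it now works anywhere a type is parsed, not only in schema lookups. A sketch:

    import sqlglot

    # STR should tokenize as a TEXT type everywhere
    print(sqlglot.parse_one("CAST(a AS STR)").sql())  # e.g. CAST(a AS TEXT)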
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index b25ef8d..0efa7d0 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -48,6 +48,7 @@ class TokenType(AutoName):
DOLLAR = auto()
PARAMETER = auto()
SESSION_PARAMETER = auto()
+ NATIONAL = auto()
BLOCK_START = auto()
BLOCK_END = auto()
@@ -111,6 +112,7 @@ class TokenType(AutoName):
# keywords
ALIAS = auto()
+ ALTER = auto()
ALWAYS = auto()
ALL = auto()
ANTI = auto()
@@ -196,6 +198,7 @@ class TokenType(AutoName):
INTERVAL = auto()
INTO = auto()
INTRODUCER = auto()
+ IRLIKE = auto()
IS = auto()
ISNULL = auto()
JOIN = auto()
@@ -241,6 +244,7 @@ class TokenType(AutoName):
PRIMARY_KEY = auto()
PROCEDURE = auto()
PROPERTIES = auto()
+ PSEUDO_TYPE = auto()
QUALIFY = auto()
QUOTE = auto()
RANGE = auto()
@@ -346,7 +350,11 @@ class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs): # type: ignore
klass = super().__new__(cls, clsname, bases, attrs)
- klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
+ klass._QUOTES = {
+ f"{prefix}{s}": e
+ for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items()
+ for prefix in (("",) if s[0].isalpha() else ("", "n", "N"))
+ }
klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
@@ -470,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
"CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
+ "COLUMN": TokenType.COLUMN,
"COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND,
@@ -587,6 +596,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SEMI": TokenType.SEMI,
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
+ "SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY,
@@ -614,6 +624,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VOLATILE": TokenType.VOLATILE,
"WHEN": TokenType.WHEN,
"WHERE": TokenType.WHERE,
+ "WINDOW": TokenType.WINDOW,
"WITH": TokenType.WITH,
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
@@ -652,6 +663,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR": TokenType.NVARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
+ "STR": TokenType.TEXT,
"STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT,
@@ -667,7 +679,16 @@ class Tokenizer(metaclass=_Tokenizer):
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
- "ALTER": TokenType.COMMAND,
+ "ALTER": TokenType.ALTER,
+ "ALTER AGGREGATE": TokenType.COMMAND,
+ "ALTER DEFAULT": TokenType.COMMAND,
+ "ALTER DOMAIN": TokenType.COMMAND,
+ "ALTER ROLE": TokenType.COMMAND,
+ "ALTER RULE": TokenType.COMMAND,
+ "ALTER SEQUENCE": TokenType.COMMAND,
+ "ALTER TYPE": TokenType.COMMAND,
+ "ALTER USER": TokenType.COMMAND,
+ "ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
"EXPLAIN": TokenType.COMMAND,
@@ -967,7 +988,7 @@ class Tokenizer(metaclass=_Tokenizer):
text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
text = text.replace("\\\\", "\\") if self._replace_backslash else text
- self._add(TokenType.STRING, text)
+ self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
# X'1234, b'0110', E'\\\\\' etc.
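
Finally, the n/N quote prefixes are registered for every non-alphabetic quote delimiter, and such strings are emitted as a NATIONAL token that the parser turns into exp.National (see _parse_national above). A sketch:

    import sqlglot

    # N'...' should parse into an exp.National node instead of a plain string
    print(repr(sqlglot.parse_one("SELECT N'niño'")))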