Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dialects/__init__.py    |   1
-rw-r--r--  sqlglot/dialects/clickhouse.py  |  13
-rw-r--r--  sqlglot/dialects/dialect.py     |  25
-rw-r--r--  sqlglot/dialects/doris.py       |  65
-rw-r--r--  sqlglot/dialects/duckdb.py      |  26
-rw-r--r--  sqlglot/dialects/hive.py        |  14
-rw-r--r--  sqlglot/dialects/mysql.py       |   8
-rw-r--r--  sqlglot/dialects/postgres.py    |   5
-rw-r--r--  sqlglot/dialects/presto.py      |  10
-rw-r--r--  sqlglot/dialects/redshift.py    |   7
-rw-r--r--  sqlglot/dialects/spark.py       |   3
-rw-r--r--  sqlglot/dialects/starrocks.py   |   5
-rw-r--r--  sqlglot/executor/__init__.py    |  10
-rw-r--r--  sqlglot/executor/table.py       |  34
-rw-r--r--  sqlglot/expressions.py          |  58
-rw-r--r--  sqlglot/generator.py            |  28
-rw-r--r--  sqlglot/optimizer/simplify.py   | 108
-rw-r--r--  sqlglot/parser.py               | 118
-rw-r--r--  sqlglot/schema.py               |  41
-rw-r--r--  sqlglot/tokens.py               |  13
20 files changed, 465 insertions(+), 127 deletions(-)
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index fc34262..8212669 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -60,6 +60,7 @@ from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
+from sqlglot.dialects.doris import Doris
from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index e6b7743..cfde5fd 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -37,17 +37,22 @@ class ClickHouse(Dialect):
"ATTACH": TokenType.COMMAND,
"DATETIME64": TokenType.DATETIME64,
"DICTIONARY": TokenType.DICTIONARY,
+ "ENUM": TokenType.ENUM,
+ "ENUM8": TokenType.ENUM8,
+ "ENUM16": TokenType.ENUM16,
"FINAL": TokenType.FINAL,
+ "FIXEDSTRING": TokenType.FIXEDSTRING,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"GLOBAL": TokenType.GLOBAL,
- "INT128": TokenType.INT128,
"INT16": TokenType.SMALLINT,
"INT256": TokenType.INT256,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"INT8": TokenType.TINYINT,
+ "LOWCARDINALITY": TokenType.LOWCARDINALITY,
"MAP": TokenType.MAP,
+ "NESTED": TokenType.NESTED,
"TUPLE": TokenType.STRUCT,
"UINT128": TokenType.UINT128,
"UINT16": TokenType.USMALLINT,
@@ -294,11 +299,17 @@ class ClickHouse(Dialect):
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.DATETIME64: "DateTime64",
exp.DataType.Type.DOUBLE: "Float64",
+ exp.DataType.Type.ENUM: "Enum",
+ exp.DataType.Type.ENUM8: "Enum8",
+ exp.DataType.Type.ENUM16: "Enum16",
+ exp.DataType.Type.FIXEDSTRING: "FixedString",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.INT: "Int32",
exp.DataType.Type.INT128: "Int128",
exp.DataType.Type.INT256: "Int256",
+ exp.DataType.Type.LOWCARDINALITY: "LowCardinality",
exp.DataType.Type.MAP: "Map",
+ exp.DataType.Type.NESTED: "Nested",
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.STRUCT: "Tuple",
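Illustration (not part of the diff): with the new ClickHouse type tokens, types such as LowCardinality, FixedString and the Enum variants should parse and round-trip. A minimal sketch, assuming this revision of sqlglot:

    import sqlglot

    # LowCardinality is a nested type; FixedString takes a size modifier
    sql = "CREATE TABLE t (x LowCardinality(FixedString(10)))"
    print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
    # expected to round-trip unchanged
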
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 1d0584c..132496f 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -39,6 +39,7 @@ class Dialects(str, Enum):
TERADATA = "teradata"
TRINO = "trino"
TSQL = "tsql"
+ DORIS = "doris"
class _Dialect(type):
@@ -121,7 +122,7 @@ class _Dialect(type):
if hasattr(subclass, name):
setattr(subclass, name, value)
- if not klass.STRICT_STRING_CONCAT:
+ if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
klass.generator_class.can_identify = klass.can_identify
@@ -146,6 +147,9 @@ class Dialect(metaclass=_Dialect):
# Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False
+ # Determines whether or not the DPIPE token ('||') is a string concatenation operator
+ DPIPE_IS_STRING_CONCAT = True
+
# Determines whether or not CONCAT's arguments must be strings
STRICT_STRING_CONCAT = False
@@ -460,6 +464,20 @@ def format_time_lambda(
return _format_time
+def time_format(
+ dialect: DialectType = None,
+) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
+ def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
+ """
+ Returns the time format for a given expression, unless it's equivalent
+ to the default time format of the dialect of interest.
+ """
+ time_format = self.format_time(expression)
+ return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
+
+ return _time_format
+
+
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
"""
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
@@ -699,3 +717,8 @@ def simplify_literal(expression: E) -> E:
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
+
+
+# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
+def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
+ return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
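Example (not part of the diff): parse_timestamp_trunc normalizes DATE_TRUNC's unit-first argument order into exp.TimestampTrunc, whose this argument is the timestamp. A minimal sketch:

    from sqlglot import exp
    from sqlglot.dialects.dialect import parse_timestamp_trunc

    node = parse_timestamp_trunc([exp.Literal.string("day"), exp.column("ts")])
    assert isinstance(node, exp.TimestampTrunc)
    assert node.this.name == "ts" and node.args["unit"].name == "day"
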
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
new file mode 100644
index 0000000..160c23c
--- /dev/null
+++ b/sqlglot/dialects/doris.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from sqlglot import exp
+from sqlglot.dialects.dialect import (
+ approx_count_distinct_sql,
+ arrow_json_extract_sql,
+ parse_timestamp_trunc,
+ rename_func,
+ time_format,
+)
+from sqlglot.dialects.mysql import MySQL
+
+
+class Doris(MySQL):
+ DATE_FORMAT = "'yyyy-MM-dd'"
+ DATEINT_FORMAT = "'yyyyMMdd'"
+ TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
+
+ class Parser(MySQL.Parser):
+ FUNCTIONS = {
+ **MySQL.Parser.FUNCTIONS,
+ "DATE_TRUNC": parse_timestamp_trunc,
+ "REGEXP": exp.RegexpLike.from_arg_list,
+ }
+
+ class Generator(MySQL.Generator):
+ CAST_MAPPING = {}
+
+ TYPE_MAPPING = {
+ **MySQL.Generator.TYPE_MAPPING,
+ exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.TIMESTAMP: "DATETIME",
+ exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
+ }
+
+ TRANSFORMS = {
+ **MySQL.Generator.TRANSFORMS,
+ exp.ApproxDistinct: approx_count_distinct_sql,
+ exp.ArrayAgg: rename_func("COLLECT_LIST"),
+ exp.Coalesce: rename_func("NVL"),
+ exp.CurrentTimestamp: lambda *_: "NOW()",
+ exp.DateTrunc: lambda self, e: self.func(
+ "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
+ ),
+ exp.JSONExtractScalar: arrow_json_extract_sql,
+ exp.JSONExtract: arrow_json_extract_sql,
+ exp.RegexpLike: rename_func("REGEXP"),
+ exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
+ exp.SetAgg: rename_func("COLLECT_SET"),
+ exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.Split: rename_func("SPLIT_BY_STRING"),
+ exp.TimeStrToDate: rename_func("TO_DATE"),
+ exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level
+ exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
+ exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TimestampTrunc: lambda self, e: self.func(
+ "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
+ ),
+ exp.UnixToStr: lambda self, e: self.func(
+ "FROM_UNIXTIME", e.this, time_format("doris")(self, e)
+ ),
+ exp.UnixToTime: rename_func("FROM_UNIXTIME"),
+ exp.Map: rename_func("ARRAY_MAP"),
+ }
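Illustration (not part of the diff): transpiling into the new Doris dialect. A sketch; the exact output assumes the TimestampTrunc transform above:

    import sqlglot

    # Postgres parses DATE_TRUNC('day', ts) unit-first; Doris regenerates it
    # timestamp-first via the TimestampTrunc transform.
    print(sqlglot.transpile("SELECT DATE_TRUNC('day', ts)", read="postgres", write="doris")[0])
    # expected: SELECT DATE_TRUNC(ts, 'day')
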
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 5428e86..8253b52 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -89,6 +89,11 @@ def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
+
+ # Type TIMESTAMP / TIME WITH TIME ZONE does not support any modifiers
+ if expression.is_type("timestamptz", "timetz"):
+ return expression.this.value
+
return self.datatype_sql(expression)
@@ -110,14 +115,14 @@ class DuckDB(Dialect):
"//": TokenType.DIV,
"ATTACH": TokenType.COMMAND,
"BINARY": TokenType.VARBINARY,
- "BPCHAR": TokenType.TEXT,
"BITSTRING": TokenType.BIT,
+ "BPCHAR": TokenType.TEXT,
"CHAR": TokenType.TEXT,
"CHARACTER VARYING": TokenType.TEXT,
"EXCLUDE": TokenType.EXCEPT,
+ "HUGEINT": TokenType.INT128,
"INT1": TokenType.TINYINT,
"LOGICAL": TokenType.BOOLEAN,
- "NUMERIC": TokenType.DOUBLE,
"PIVOT_WIDER": TokenType.PIVOT,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
@@ -186,6 +191,22 @@ class DuckDB(Dialect):
TokenType.UTINYINT,
}
+ def _parse_types(
+ self, check_func: bool = False, schema: bool = False
+ ) -> t.Optional[exp.Expression]:
+ this = super()._parse_types(check_func=check_func, schema=schema)
+
+ # DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3)
+ # See: https://duckdb.org/docs/sql/data_types/numeric
+ if (
+ isinstance(this, exp.DataType)
+ and this.is_type("numeric", "decimal")
+ and not this.expressions
+ ):
+ return exp.DataType.build("DECIMAL(18, 3)")
+
+ return this
+
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
if len(aggregations) == 1:
return super()._pivot_column_names(aggregations)
@@ -231,6 +252,7 @@ class DuckDB(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
+ exp.IsNan: rename_func("ISNAN"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONFormat: _json_format_sql,
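A quick sanity check (not part of the patch) of the new DECIMAL defaulting:

    import sqlglot

    # A bare NUMERIC/DECIMAL now parses as DuckDB's documented default precision
    print(sqlglot.parse_one("CAST(x AS NUMERIC)", read="duckdb").sql("duckdb"))
    # expected: CAST(x AS DECIMAL(18, 3))
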
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index aa4d845..584acc6 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import (
right_to_substring_sql,
strposition_to_locate_sql,
struct_extract_sql,
+ time_format,
timestrtotime_sql,
var_map_sql,
)
@@ -113,7 +114,7 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
- return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
+ return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression))
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
@@ -132,15 +133,6 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st
return f"CAST({this} AS TIMESTAMP)"
-def _time_format(
- self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
-) -> t.Optional[str]:
- time_format = self.format_time(expression)
- if time_format == Hive.TIME_FORMAT:
- return None
- return time_format
-
-
def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
@@ -439,7 +431,7 @@ class Hive(Dialect):
exp.TsOrDsToDate: _to_date_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToStr: lambda self, e: self.func(
- "FROM_UNIXTIME", e.this, _time_format(self, e)
+ "FROM_UNIXTIME", e.this, time_format("hive")(self, e)
),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
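Illustration (not part of the diff): the shared time_format helper keeps suppressing a format argument that matches the dialect's default, exactly as the removed _time_format did:

    import sqlglot

    sql = "SELECT FROM_UNIXTIME(ts, 'yyyy-MM-dd HH:mm:ss')"
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])
    # expected: SELECT FROM_UNIXTIME(ts), since the format equals Hive's default
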
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 3cd99e7..9ab4ce8 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -94,6 +94,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
class MySQL(Dialect):
TIME_FORMAT = "'%Y-%m-%d %T'"
+ DPIPE_IS_STRING_CONCAT = False
# https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
TIME_MAPPING = {
@@ -103,7 +104,6 @@ class MySQL(Dialect):
"%h": "%I",
"%i": "%M",
"%s": "%S",
- "%S": "%S",
"%u": "%W",
"%k": "%-H",
"%l": "%-I",
@@ -196,8 +196,14 @@ class MySQL(Dialect):
**parser.Parser.CONJUNCTION,
TokenType.DAMP: exp.And,
TokenType.XOR: exp.Xor,
+ TokenType.DPIPE: exp.Or,
}
+ # MySQL uses || as a synonym for the logical OR operator
+ # https://dev.mysql.com/doc/refman/8.0/en/logical-operators.html#operator_or
+ BITWISE = parser.Parser.BITWISE.copy()
+ BITWISE.pop(TokenType.DPIPE)
+
TABLE_ALIAS_TOKENS = (
parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
)
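Illustration (not part of the diff): with DPIPE_IS_STRING_CONCAT disabled, || is parsed as logical OR under MySQL:

    import sqlglot
    from sqlglot import exp

    assert isinstance(sqlglot.parse_one("a || b", read="mysql"), exp.Or)
    print(sqlglot.transpile("SELECT a || b", read="mysql", write="postgres")[0])
    # expected: SELECT a OR b
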
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index ca44b70..73ca4e5 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -16,6 +16,7 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
+ parse_timestamp_trunc,
rename_func,
simplify_literal,
str_position_sql,
@@ -286,9 +287,7 @@ class Postgres(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
- "DATE_TRUNC": lambda args: exp.TimestampTrunc(
- this=seq_get(args, 1), unit=seq_get(args, 0)
- ),
+ "DATE_TRUNC": parse_timestamp_trunc,
"GENERATE_SERIES": _generate_series,
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 291b478..078da0b 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -32,13 +32,6 @@ def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistin
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
- sql = self.datatype_sql(expression)
- if expression.is_type("timestamptz"):
- sql = f"{sql} WITH TIME ZONE"
- return sql
-
-
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
expression = expression.copy()
@@ -231,6 +224,7 @@ class Presto(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
IS_BOOL_ALLOWED = False
+ TZ_TO_WITH_TIME_ZONE = True
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
@@ -245,6 +239,7 @@ class Presto(Dialect):
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
exp.DataType.Type.TEXT: "VARCHAR",
+ exp.DataType.Type.TIMETZ: "TIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.STRUCT: "ROW",
}
@@ -265,7 +260,6 @@ class Presto(Dialect):
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
- exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
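Illustration (not part of the diff): the removed ad-hoc _datatype_sql is subsumed by the generator-level TZ_TO_WITH_TIME_ZONE flag (see the generator.py hunk below):

    import sqlglot

    # TIMESTAMPTZ maps to TIMESTAMP; the flag re-attaches the suffix
    print(sqlglot.parse_one("CAST(x AS TIMESTAMPTZ)").sql("presto"))
    # expected: CAST(x AS TIMESTAMP WITH TIME ZONE)
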
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index cdb8d0d..30731e1 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -85,8 +85,6 @@ class Redshift(Postgres):
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
- "TIME": TokenType.TIMESTAMP,
- "TIMETZ": TokenType.TIMESTAMPTZ,
"TOP": TokenType.TOP,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
@@ -101,12 +99,15 @@ class Redshift(Postgres):
RENAME_TABLE_WITH_DB = False
QUERY_HINTS = False
VALUES_AS_TABLE = False
+ TZ_TO_WITH_TIME_ZONE = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "VARBYTE",
- exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
+ exp.DataType.Type.TIMETZ: "TIME",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+ exp.DataType.Type.VARBINARY: "VARBYTE",
}
PROPERTIES_LOCATION = {
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index b9aaa66..7c8982b 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -52,6 +52,9 @@ class Spark(Spark2):
TRANSFORMS = {
**Spark2.Generator.TRANSFORMS,
exp.StartsWith: rename_func("STARTSWITH"),
+ exp.TimestampAdd: lambda self, e: self.func(
+ "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
+ ),
}
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
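Example (not part of the diff): the new Spark TimestampAdd transform, exercised by building the node directly, since the parse-side entry point depends on the source dialect:

    from sqlglot import exp

    node = exp.TimestampAdd(
        this=exp.column("ts"), expression=exp.Literal.number(1), unit=exp.var("HOUR")
    )
    print(node.sql("spark"))
    # expected: DATEADD(HOUR, 1, ts)
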
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 4f6183c..2dba1c1 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -4,6 +4,7 @@ from sqlglot import exp
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_sql,
+ parse_timestamp_trunc,
rename_func,
)
from sqlglot.dialects.mysql import MySQL
@@ -14,9 +15,7 @@ class StarRocks(MySQL):
class Parser(MySQL.Parser):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
- "DATE_TRUNC": lambda args: exp.TimestampTrunc(
- this=seq_get(args, 1), unit=seq_get(args, 0)
- ),
+ "DATE_TRUNC": parse_timestamp_trunc,
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
),
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index 017d5bc..304981b 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -28,6 +28,11 @@ if t.TYPE_CHECKING:
from sqlglot.schema import Schema
+PYTHON_TYPE_TO_SQLGLOT = {
+ "dict": "MAP",
+}
+
+
def execute(
sql: str | Expression,
schema: t.Optional[t.Dict | Schema] = None,
@@ -50,7 +55,7 @@ def execute(
Returns:
Simple columnar data structure.
"""
- tables_ = ensure_tables(tables)
+ tables_ = ensure_tables(tables, dialect=read)
if not schema:
schema = {}
@@ -61,7 +66,8 @@ def execute(
assert table is not None
for column in table.columns:
- nested_set(schema, [*keys, column], type(table[0][column]).__name__)
+ py_type = type(table[0][column]).__name__
+ nested_set(schema, [*keys, column], PYTHON_TYPE_TO_SQLGLOT.get(py_type) or py_type)
schema = ensure_schema(schema, dialect=read)
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 27e3e5e..74b9b7c 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -2,8 +2,9 @@ from __future__ import annotations
import typing as t
+from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import dict_depth
-from sqlglot.schema import AbstractMappingSchema
+from sqlglot.schema import AbstractMappingSchema, normalize_name
class Table:
@@ -108,26 +109,37 @@ class Tables(AbstractMappingSchema[Table]):
pass
-def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
- return Tables(_ensure_tables(d))
+def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables:
+ return Tables(_ensure_tables(d, dialect=dialect))
-def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
+def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict:
if not d:
return {}
depth = dict_depth(d)
-
if depth > 1:
- return {k: _ensure_tables(v) for k, v in d.items()}
+ return {
+ normalize_name(k, dialect=dialect, is_table=True): _ensure_tables(v, dialect=dialect)
+ for k, v in d.items()
+ }
result = {}
- for name, table in d.items():
+ for table_name, table in d.items():
+ table_name = normalize_name(table_name, dialect=dialect)
+
if isinstance(table, Table):
- result[name] = table
+ result[table_name] = table
else:
- columns = tuple(table[0]) if table else ()
- rows = [tuple(row[c] for c in columns) for row in table]
- result[name] = Table(columns=columns, rows=rows)
+ table = [
+ {
+ normalize_name(column_name, dialect=dialect): value
+ for column_name, value in row.items()
+ }
+ for row in table
+ ]
+ column_names = tuple(column_name for column_name in table[0]) if table else ()
+ rows = [tuple(row[name] for name in column_names) for row in table]
+ result[table_name] = Table(columns=column_names, rows=rows)
return result
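Illustration (not part of the diff): executor table and column names are now normalized through the dialect, so the casing of the tables mapping no longer has to match the query exactly. A sketch, assuming the default dialect lower-cases unquoted identifiers:

    from sqlglot.executor import execute

    result = execute("SELECT A FROM T", tables={"t": [{"a": 1}]})
    print(result.rows)
    # expected: [(1,)]
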
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index c207751..57b8bfa 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -3309,6 +3309,7 @@ class Pivot(Expression):
"using": False,
"group": False,
"columns": False,
+ "include_nulls": False,
}
@@ -3397,23 +3398,16 @@ class DataType(Expression):
BOOLEAN = auto()
CHAR = auto()
DATE = auto()
+ DATEMULTIRANGE = auto()
+ DATERANGE = auto()
DATETIME = auto()
DATETIME64 = auto()
- ENUM = auto()
- INT4RANGE = auto()
- INT4MULTIRANGE = auto()
- INT8RANGE = auto()
- INT8MULTIRANGE = auto()
- NUMRANGE = auto()
- NUMMULTIRANGE = auto()
- TSRANGE = auto()
- TSMULTIRANGE = auto()
- TSTZRANGE = auto()
- TSTZMULTIRANGE = auto()
- DATERANGE = auto()
- DATEMULTIRANGE = auto()
DECIMAL = auto()
DOUBLE = auto()
+ ENUM = auto()
+ ENUM8 = auto()
+ ENUM16 = auto()
+ FIXEDSTRING = auto()
FLOAT = auto()
GEOGRAPHY = auto()
GEOMETRY = auto()
@@ -3421,23 +3415,31 @@ class DataType(Expression):
HSTORE = auto()
IMAGE = auto()
INET = auto()
- IPADDRESS = auto()
- IPPREFIX = auto()
INT = auto()
INT128 = auto()
INT256 = auto()
+ INT4MULTIRANGE = auto()
+ INT4RANGE = auto()
+ INT8MULTIRANGE = auto()
+ INT8RANGE = auto()
INTERVAL = auto()
+ IPADDRESS = auto()
+ IPPREFIX = auto()
JSON = auto()
JSONB = auto()
LONGBLOB = auto()
LONGTEXT = auto()
+ LOWCARDINALITY = auto()
MAP = auto()
MEDIUMBLOB = auto()
MEDIUMTEXT = auto()
MONEY = auto()
NCHAR = auto()
+ NESTED = auto()
NULL = auto()
NULLABLE = auto()
+ NUMMULTIRANGE = auto()
+ NUMRANGE = auto()
NVARCHAR = auto()
OBJECT = auto()
ROWVERSION = auto()
@@ -3450,19 +3452,24 @@ class DataType(Expression):
SUPER = auto()
TEXT = auto()
TIME = auto()
+ TIMETZ = auto()
TIMESTAMP = auto()
- TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
+ TIMESTAMPTZ = auto()
TINYINT = auto()
+ TSMULTIRANGE = auto()
+ TSRANGE = auto()
+ TSTZMULTIRANGE = auto()
+ TSTZRANGE = auto()
UBIGINT = auto()
UINT = auto()
- USMALLINT = auto()
- UTINYINT = auto()
- UNKNOWN = auto() # Sentinel value, useful for type annotation
UINT128 = auto()
UINT256 = auto()
UNIQUEIDENTIFIER = auto()
+ UNKNOWN = auto() # Sentinel value, useful for type annotation
USERDEFINED = "USER-DEFINED"
+ USMALLINT = auto()
+ UTINYINT = auto()
UUID = auto()
VARBINARY = auto()
VARCHAR = auto()
@@ -3495,6 +3502,7 @@ class DataType(Expression):
TEMPORAL_TYPES = {
Type.TIME,
+ Type.TIMETZ,
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
@@ -3858,6 +3866,18 @@ class TimeUnit(Expression):
super().__init__(**args)
+# https://www.oracletutorial.com/oracle-basics/oracle-interval/
+# https://trino.io/docs/current/language/types.html#interval-year-to-month
+class IntervalYearToMonthSpan(Expression):
+ arg_types = {}
+
+
+# https://www.oracletutorial.com/oracle-basics/oracle-interval/
+# https://trino.io/docs/current/language/types.html#interval-day-to-second
+class IntervalDayToSecondSpan(Expression):
+ arg_types = {}
+
+
class Interval(TimeUnit):
arg_types = {"this": False, "unit": False}
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 95db795..f8d7d68 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -71,6 +71,8 @@ class Generator:
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.HeapProperty: lambda self, e: "HEAP",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
+ exp.IntervalDayToSecondSpan: "DAY TO SECOND",
+ exp.IntervalYearToMonthSpan: "YEAR TO MONTH",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
@@ -166,6 +168,9 @@ class Generator:
# Whether or not to generate an unquoted value for EXTRACT's date part argument
EXTRACT_ALLOWS_QUOTES = True
+ # Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
+ TZ_TO_WITH_TIME_ZONE = False
+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
@@ -271,10 +276,12 @@ class Generator:
# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
+ exp.Create,
exp.Delete,
exp.Drop,
exp.From,
exp.Insert,
+ exp.Join,
exp.Select,
exp.Update,
exp.Where,
@@ -831,14 +838,17 @@ class Generator:
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
+
type_sql = (
self.TYPE_MAPPING.get(type_value, type_value.value)
if isinstance(type_value, exp.DataType.Type)
else type_value
)
+
nested = ""
interior = self.expressions(expression, flat=True)
values = ""
+
if interior:
if expression.args.get("nested"):
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
@@ -846,10 +856,19 @@ class Generator:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
values = self.expressions(expression, key="values", flat=True)
values = f"{delimiters[0]}{values}{delimiters[1]}"
+ elif type_value == exp.DataType.Type.INTERVAL:
+ nested = f" {interior}"
else:
nested = f"({interior})"
- return f"{type_sql}{nested}{values}"
+ type_sql = f"{type_sql}{nested}{values}"
+ if self.TZ_TO_WITH_TIME_ZONE and type_value in (
+ exp.DataType.Type.TIMETZ,
+ exp.DataType.Type.TIMESTAMPTZ,
+ ):
+ type_sql = f"{type_sql} WITH TIME ZONE"
+
+ return type_sql
def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
@@ -1288,7 +1307,12 @@ class Generator:
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
field = self.sql(expression, "field")
- return f"{direction}({expressions} FOR {field}){alias}"
+ include_nulls = expression.args.get("include_nulls")
+ if include_nulls is not None:
+ nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS "
+ else:
+ nulls = ""
+ return f"{direction}{nulls}({expressions} FOR {field}){alias}"
def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index e247f58..e550603 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -54,11 +54,17 @@ def simplify(expression):
def _simplify(expression, root=True):
if expression.meta.get(FINAL):
return expression
+
+ # Pre-order transformations
node = expression
node = rewrite_between(node)
node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
+ node = simplify_concat(node)
+
exp.replace_children(node, lambda e: _simplify(e, False))
+
+ # Post-order transformations
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node, root)
@@ -66,8 +72,11 @@ def simplify(expression):
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_parens(node)
+ node = simplify_coalesce(node)
+
if root:
expression.replace(node)
+
return node
expression = while_changing(expression, _simplify)
@@ -184,6 +193,7 @@ COMPARISONS = (
*GT_GTE,
exp.EQ,
exp.NEQ,
+ exp.Is,
)
INVERSE_COMPARISONS = {
@@ -430,6 +440,103 @@ def simplify_parens(expression):
return expression
+CONSTANTS = (
+ exp.Literal,
+ exp.Boolean,
+ exp.Null,
+)
+
+
+def simplify_coalesce(expression):
+ # COALESCE(x) -> x
+ if (
+ isinstance(expression, exp.Coalesce)
+ and not expression.expressions
+ # COALESCE is also used as a Spark partitioning hint
+ and not isinstance(expression.parent, exp.Hint)
+ ):
+ return expression.this
+
+ if not isinstance(expression, COMPARISONS):
+ return expression
+
+ if isinstance(expression.left, exp.Coalesce):
+ coalesce = expression.left
+ other = expression.right
+ elif isinstance(expression.right, exp.Coalesce):
+ coalesce = expression.right
+ other = expression.left
+ else:
+ return expression
+
+ # This transformation is valid for non-constants,
+ # but it really only does anything if they are both constants.
+ if not isinstance(other, CONSTANTS):
+ return expression
+
+ # Find the first constant arg
+ for arg_index, arg in enumerate(coalesce.expressions):
+ if isinstance(arg, CONSTANTS):
+ break
+ else:
+ return expression
+
+ coalesce.set("expressions", coalesce.expressions[:arg_index])
+
+ # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
+ # since we already remove COALESCE at the top of this function.
+ coalesce = coalesce if coalesce.expressions else coalesce.this
+
+ # This expression is more complex than when we started, but it will get simplified further
+ return exp.or_(
+ exp.and_(
+ coalesce.is_(exp.null()).not_(copy=False),
+ expression.copy(),
+ copy=False,
+ ),
+ exp.and_(
+ coalesce.is_(exp.null()),
+ type(expression)(this=arg.copy(), expression=other.copy()),
+ copy=False,
+ ),
+ copy=False,
+ )
+
+
+CONCATS = (exp.Concat, exp.DPipe)
+SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
+
+
+def simplify_concat(expression):
+ """Reduces all groups that contain string literals by concatenating them."""
+ if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
+ return expression
+
+ new_args = []
+ for is_string_group, group in itertools.groupby(
+ expression.expressions or expression.flatten(), lambda e: e.is_string
+ ):
+ if is_string_group:
+ new_args.append(exp.Literal.string("".join(string.name for string in group)))
+ else:
+ new_args.extend(group)
+
+ # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
+ concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+ return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
+
+
+# CROSS joins result in an empty table if the right table is empty.
+# So we can only simplify certain types of joins to CROSS.
+# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
+JOINS = {
+ ("", ""),
+ ("", "INNER"),
+ ("RIGHT", ""),
+ ("RIGHT", "OUTER"),
+}
+
+
def remove_where_true(expression):
for where in expression.find_all(exp.Where):
if always_true(where.this):
@@ -439,6 +546,7 @@ def remove_where_true(expression):
always_true(join.args.get("on"))
and not join.args.get("using")
and not join.args.get("method")
+ and (join.side, join.kind) in JOINS
):
join.set("on", None)
join.set("side", None)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 35a1744..3db4453 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -102,15 +102,23 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_USER: exp.CurrentUser,
}
+ STRUCT_TYPE_TOKENS = {
+ TokenType.NESTED,
+ TokenType.STRUCT,
+ }
+
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
+ TokenType.LOWCARDINALITY,
TokenType.MAP,
TokenType.NULLABLE,
- TokenType.STRUCT,
+ *STRUCT_TYPE_TOKENS,
}
ENUM_TYPE_TOKENS = {
TokenType.ENUM,
+ TokenType.ENUM8,
+ TokenType.ENUM16,
}
TYPE_TOKENS = {
@@ -128,6 +136,7 @@ class Parser(metaclass=_Parser):
TokenType.UINT128,
TokenType.INT256,
TokenType.UINT256,
+ TokenType.FIXEDSTRING,
TokenType.FLOAT,
TokenType.DOUBLE,
TokenType.CHAR,
@@ -145,6 +154,7 @@ class Parser(metaclass=_Parser):
TokenType.JSONB,
TokenType.INTERVAL,
TokenType.TIME,
+ TokenType.TIMETZ,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
@@ -187,7 +197,7 @@ class Parser(metaclass=_Parser):
TokenType.INET,
TokenType.IPADDRESS,
TokenType.IPPREFIX,
- TokenType.ENUM,
+ *ENUM_TYPE_TOKENS,
*NESTED_TYPE_TOKENS,
}
@@ -384,11 +394,16 @@ class Parser(metaclass=_Parser):
TokenType.STAR: exp.Mul,
}
- TIMESTAMPS = {
+ TIMES = {
TokenType.TIME,
+ TokenType.TIMETZ,
+ }
+
+ TIMESTAMPS = {
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
+ *TIMES,
}
SET_OPERATIONS = {
@@ -1165,6 +1180,8 @@ class Parser(metaclass=_Parser):
def _parse_create(self) -> exp.Create | exp.Command:
# Note: this can't be None because we've matched a statement parser
start = self._prev
+ comments = self._prev_comments
+
replace = start.text.upper() == "REPLACE" or self._match_pair(
TokenType.OR, TokenType.REPLACE
)
@@ -1273,6 +1290,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Create,
+ comments=comments,
this=this,
kind=create_token.text,
replace=replace,
@@ -2338,7 +2356,8 @@ class Parser(metaclass=_Parser):
kwargs["this"].set("joins", joins)
- return self.expression(exp.Join, **kwargs)
+ comments = [c for token in (method, side, kind) if token for c in token.comments]
+ return self.expression(exp.Join, comments=comments, **kwargs)
def _parse_index(
self,
@@ -2619,11 +2638,18 @@ class Parser(metaclass=_Parser):
def _parse_pivot(self) -> t.Optional[exp.Pivot]:
index = self._index
+ include_nulls = None
if self._match(TokenType.PIVOT):
unpivot = False
elif self._match(TokenType.UNPIVOT):
unpivot = True
+
+ # https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot.html#syntax
+ if self._match_text_seq("INCLUDE", "NULLS"):
+ include_nulls = True
+ elif self._match_text_seq("EXCLUDE", "NULLS"):
+ include_nulls = False
else:
return None
@@ -2654,7 +2680,13 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
- pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
+ pivot = self.expression(
+ exp.Pivot,
+ expressions=expressions,
+ field=field,
+ unpivot=unpivot,
+ include_nulls=include_nulls,
+ )
if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
pivot.set("alias", self._parse_table_alias())
@@ -3096,7 +3128,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.PseudoType, this=self._prev.text)
nested = type_token in self.NESTED_TYPE_TOKENS
- is_struct = type_token == TokenType.STRUCT
+ is_struct = type_token in self.STRUCT_TYPE_TOKENS
expressions = None
maybe_func = False
@@ -3108,7 +3140,7 @@ class Parser(metaclass=_Parser):
lambda: self._parse_types(check_func=check_func, schema=schema)
)
elif type_token in self.ENUM_TYPE_TOKENS:
- expressions = self._parse_csv(self._parse_primary)
+ expressions = self._parse_csv(self._parse_equality)
else:
expressions = self._parse_csv(self._parse_type_size)
@@ -3118,29 +3150,9 @@ class Parser(metaclass=_Parser):
maybe_func = True
- if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
- this = exp.DataType(
- this=exp.DataType.Type.ARRAY,
- expressions=[
- exp.DataType(
- this=exp.DataType.Type[type_token.value],
- expressions=expressions,
- nested=nested,
- )
- ],
- nested=True,
- )
-
- while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
- this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
-
- return this
-
- if self._match(TokenType.L_BRACKET):
- self._retreat(index)
- return None
-
+ this: t.Optional[exp.Expression] = None
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
+
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
@@ -3156,23 +3168,35 @@ class Parser(metaclass=_Parser):
values = self._parse_csv(self._parse_conjunction)
self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN))
- value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
if self._match_text_seq("WITH", "TIME", "ZONE"):
maybe_func = False
- value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
+ tz_type = (
+ exp.DataType.Type.TIMETZ
+ if type_token in self.TIMES
+ else exp.DataType.Type.TIMESTAMPTZ
+ )
+ this = exp.DataType(this=tz_type, expressions=expressions)
elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"):
maybe_func = False
- value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
+ this = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
maybe_func = False
elif type_token == TokenType.INTERVAL:
- unit = self._parse_var()
+ if self._match_text_seq("YEAR", "TO", "MONTH"):
+ span: t.Optional[t.List[exp.Expression]] = [exp.IntervalYearToMonthSpan()]
+ elif self._match_text_seq("DAY", "TO", "SECOND"):
+ span = [exp.IntervalDayToSecondSpan()]
+ else:
+ span = None
+ unit = not span and self._parse_var()
if not unit:
- value = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL)
+ this = self.expression(
+ exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span
+ )
else:
- value = self.expression(exp.Interval, unit=unit)
+ this = self.expression(exp.Interval, unit=unit)
if maybe_func and check_func:
index2 = self._index
@@ -3184,16 +3208,19 @@ class Parser(metaclass=_Parser):
self._retreat(index2)
- if value:
- return value
+ if not this:
+ this = exp.DataType(
+ this=exp.DataType.Type[type_token.value],
+ expressions=expressions,
+ nested=nested,
+ values=values,
+ prefix=prefix,
+ )
- return exp.DataType(
- this=exp.DataType.Type[type_token.value],
- expressions=expressions,
- nested=nested,
- values=values,
- prefix=prefix,
- )
+ while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+ this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
+
+ return this
def _parse_struct_types(self) -> t.Optional[exp.Expression]:
this = self._parse_type() or self._parse_id_var()
@@ -3738,6 +3765,7 @@ class Parser(metaclass=_Parser):
ifs = []
default = None
+ comments = self._prev_comments
expression = self._parse_conjunction()
while self._match(TokenType.WHEN):
@@ -3753,7 +3781,7 @@ class Parser(metaclass=_Parser):
self.raise_error("Expected END after CASE", self._prev)
return self._parse_window(
- self.expression(exp.Case, this=expression, ifs=ifs, default=default)
+ self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default)
)
def _parse_if(self) -> t.Optional[exp.Expression]:
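Illustration (not part of the diff): the new INTERVAL spans and TIMETZ handling in _parse_types:

    import sqlglot

    print(sqlglot.parse_one("CAST(x AS INTERVAL DAY TO SECOND)").sql())
    # expected: CAST(x AS INTERVAL DAY TO SECOND)

    # TIME WITH TIME ZONE now yields TIMETZ instead of TIMESTAMPTZ
    print(sqlglot.parse_one("CAST(x AS TIME WITH TIME ZONE)").sql())
    # expected: CAST(x AS TIMETZ)
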
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 7a3c88b..f028f5a 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -372,21 +372,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
is_table: bool = False,
normalize: t.Optional[bool] = None,
) -> str:
- dialect = dialect or self.dialect
- normalize = self.normalize if normalize is None else normalize
-
- try:
- identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
- except ParseError:
- return name if isinstance(name, str) else name.name
-
- name = identifier.name
- if not normalize:
- return name
-
- # This can be useful for normalize_identifier
- identifier.meta["is_table"] = is_table
- return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
+ return normalize_name(
+ name,
+ dialect=dialect or self.dialect,
+ is_table=is_table,
+ normalize=self.normalize if normalize is None else normalize,
+ )
def depth(self) -> int:
if not self.empty and not self._depth:
@@ -418,6 +409,26 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return self._type_mapping_cache[schema_type]
+def normalize_name(
+ name: str | exp.Identifier,
+ dialect: DialectType = None,
+ is_table: bool = False,
+ normalize: t.Optional[bool] = True,
+) -> str:
+ try:
+ identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
+ except ParseError:
+ return name if isinstance(name, str) else name.name
+
+ name = identifier.name
+ if not normalize:
+ return name
+
+ # This can be useful for normalize_identifier
+ identifier.meta["is_table"] = is_table
+ return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
+
+
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
if isinstance(schema, Schema):
return schema
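Example (not part of the diff): the extracted normalize_name helper, now reusable outside MappingSchema:

    from sqlglot.schema import normalize_name

    # Unquoted identifiers are normalized per the dialect's resolution rules
    print(normalize_name("Foo", dialect="snowflake"))    # expected: FOO
    print(normalize_name('"Foo"', dialect="snowflake"))  # expected: Foo
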
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 81bcc0b..d278dbf 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -110,6 +110,7 @@ class TokenType(AutoName):
JSON = auto()
JSONB = auto()
TIME = auto()
+ TIMETZ = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@@ -151,6 +152,11 @@ class TokenType(AutoName):
IPADDRESS = auto()
IPPREFIX = auto()
ENUM = auto()
+ ENUM8 = auto()
+ ENUM16 = auto()
+ FIXEDSTRING = auto()
+ LOWCARDINALITY = auto()
+ NESTED = auto()
# keywords
ALIAS = auto()
@@ -659,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer):
"TINYINT": TokenType.TINYINT,
"SHORT": TokenType.SMALLINT,
"SMALLINT": TokenType.SMALLINT,
+ "INT128": TokenType.INT128,
"INT2": TokenType.SMALLINT,
"INTEGER": TokenType.INT,
"INT": TokenType.INT,
@@ -699,6 +706,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BYTEA": TokenType.VARBINARY,
"VARBINARY": TokenType.VARBINARY,
"TIME": TokenType.TIME,
+ "TIMETZ": TokenType.TIMETZ,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@@ -879,6 +887,11 @@ class Tokenizer(metaclass=_Tokenizer):
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
+
+ if self._comments and token_type == TokenType.SEMICOLON and self.tokens:
+ self.tokens[-1].comments.extend(self._comments)
+ self._comments = []
+
self.tokens.append(
Token(
token_type,
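
Illustration (not part of the diff): comments that directly precede a terminating semicolon now attach to the last real token instead of being dropped:

    import sqlglot

    print(sqlglot.transpile("SELECT 1 /* keep me */;")[0])
    # expected: SELECT 1 /* keep me */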