diff options
Diffstat (limited to 'sqlglot')
45 files changed, 1983 insertions, 866 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 35feaad..6cf9949 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -22,6 +22,7 @@ from sqlglot.expressions import ( Expression as Expression, alias_ as alias, and_ as and_, + case as case, cast as cast, column as column, condition as condition, @@ -82,8 +83,7 @@ def parse( Returns: The resulting syntax tree collection. """ - dialect = Dialect.get_or_raise(read or dialect)() - return dialect.parse(sql, **opts) + return Dialect.get_or_raise(read or dialect).parse(sql, **opts) @t.overload @@ -117,7 +117,7 @@ def parse_one( The syntax tree for the first parsed statement. """ - dialect = Dialect.get_or_raise(read or dialect)() + dialect = Dialect.get_or_raise(read or dialect) if into: result = dialect.parse_into(into, sql, **opts) @@ -157,7 +157,8 @@ def transpile( The list of transpiled SQL statements. """ write = (read if write is None else write) if identity else write + write = Dialect.get_or_raise(write) return [ - Dialect.get_or_raise(write)().generate(expression, copy=False, **opts) if expression else "" + write.generate(expression, copy=False, **opts) if expression else "" for expression in parse(sql, read, error_level=error_level) ] diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py index 4a2820b..5a77409 100644 --- a/sqlglot/__main__.py +++ b/sqlglot/__main__.py @@ -81,7 +81,7 @@ if args.parse: ) ] elif args.tokenize: - objs = sqlglot.Dialect.get_or_raise(args.read)().tokenize(sql) + objs = sqlglot.Dialect.get_or_raise(args.read).tokenize(sql) else: objs = sqlglot.transpile( sql, diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index f515608..68d36fe 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -297,27 +297,26 @@ class DataFrame: select_expressions.append(expression_select_pair) # type: ignore return select_expressions - def sql( - self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs - ) -> t.List[str]: + def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]: from sqlglot.dataframe.sql.session import SparkSession - if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect: - logger.warning( - f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern." - ) + dialect = Dialect.get_or_raise(dialect or SparkSession().dialect) + df = self._resolve_pending_hints() select_expressions = df._get_select_expressions() output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} + for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - quote_identifiers(select_expression) + quote_identifiers(select_expression, dialect=dialect) select_expression = t.cast( - exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect) + exp.Select, optimize_func(select_expression, dialect=dialect) ) + select_expression = df._replace_cte_names_with_hashes(select_expression) + expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: cache_table_name = df._create_hash_from_expression(select_expression) @@ -330,13 +329,12 @@ class DataFrame: sqlglot.schema.add_table( cache_table_name, { - expression.alias_or_name: expression.type.sql( - dialect=SparkSession().dialect - ) + expression.alias_or_name: expression.type.sql(dialect=dialect) for expression in select_expression.expressions }, - dialect=SparkSession().dialect, + dialect=dialect, ) + cache_storage_level = select_expression.args["cache_storage_level"] options = [ exp.Literal.string("storageLevel"), @@ -345,6 +343,7 @@ class DataFrame: expression = exp.Cache( this=cache_table, expression=select_expression, lazy=True, options=options ) + # We will drop the "view" if it exists before running the cache table output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) elif expression_type == exp.Create: @@ -355,18 +354,17 @@ class DataFrame: select_without_ctes = select_expression.copy() select_without_ctes.set("with", None) expression.set("expression", select_without_ctes) + if select_expression.ctes: expression.set("with", exp.With(expressions=select_expression.ctes)) elif expression_type == exp.Select: expression = select_expression else: raise ValueError(f"Invalid expression type: {expression_type}") + output_expressions.append(expression) - return [ - expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) - for expression in output_expressions - ] + return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions] def copy(self, **kwargs) -> DataFrame: return DataFrame(**object_to_dict(self, **kwargs)) @@ -542,12 +540,7 @@ class DataFrame: """ columns = self._ensure_and_normalize_cols(cols) pre_ordered_col_indexes = [ - x - for x in [ - i if isinstance(col.expression, exp.Ordered) else None - for i, col in enumerate(columns) - ] - if x is not None + i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered) ] if ascending is None: ascending = [True] * len(columns) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index a424ea4..6671c5b 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -306,7 +306,7 @@ def collect_list(col: ColumnOrName) -> Column: def collect_set(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.SetAgg) + return Column.invoke_expression_over_column(col, expression.ArrayUniqueAgg) def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index 531ee17..4a33ef9 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -28,7 +28,7 @@ class SparkSession: self.known_sequence_ids = set() self.name_to_sequence_id_mapping = defaultdict(list) self.incrementing_id = 1 - self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)() + self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT) def __new__(cls, *args, **kwargs) -> SparkSession: if cls._instance is None: @@ -182,7 +182,7 @@ class SparkSession: def getOrCreate(self) -> SparkSession: spark = SparkSession() - spark.dialect = Dialect.get_or_raise(self.dialect)() + spark.dialect = Dialect.get_or_raise(self.dialect) return spark @classproperty diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index fc9a3ae..2a9dde9 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, arg_max_or_min_no_count, binary_from_function, date_add_interval_sql, @@ -23,6 +24,7 @@ from sqlglot.dialects.dialect import ( regexp_replace_sql, rename_func, timestrtotime_sql, + ts_or_ds_add_cast, ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get, split_num_words @@ -174,6 +176,44 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5: return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg) +def _array_contains_sql(self: BigQuery.Generator, expression: exp.ArrayContains) -> str: + return self.sql( + exp.Exists( + this=exp.select("1") + .from_(exp.Unnest(expressions=[expression.left]).as_("_unnest", table=["_col"])) + .where(exp.column("_col").eq(expression.right)) + ) + ) + + +def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> str: + return date_add_interval_sql("DATE", "ADD")(self, ts_or_ds_add_cast(expression)) + + +def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str: + expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True)) + expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True)) + unit = expression.args.get("unit") or "DAY" + return self.func("DATE_DIFF", expression.this, expression.expression, unit) + + +def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str: + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale in (None, exp.UnixToTime.SECONDS): + return f"TIMESTAMP_SECONDS({timestamp})" + if scale == exp.UnixToTime.MILLIS: + return f"TIMESTAMP_MILLIS({timestamp})" + if scale == exp.UnixToTime.MICROS: + return f"TIMESTAMP_MICROS({timestamp})" + if scale == exp.UnixToTime.NANOS: + # We need to cast to INT64 because that's what BQ expects + return f"TIMESTAMP_MICROS(CAST({timestamp} / 1000 AS INT64))" + + self.unsupported(f"Unsupported scale for timestamp: {scale}.") + return "" + + class BigQuery(Dialect): UNNEST_COLUMN_ONLY = True SUPPORTS_USER_DEFINED_TYPES = False @@ -181,7 +221,7 @@ class BigQuery(Dialect): LOG_BASE_FIRST = False # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE # bigquery udfs are case sensitive NORMALIZE_FUNCTIONS = False @@ -220,8 +260,7 @@ class BigQuery(Dialect): # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"} - @classmethod - def normalize_identifier(cls, expression: E) -> E: + def normalize_identifier(self, expression: E) -> E: if isinstance(expression, exp.Identifier): parent = expression.parent while isinstance(parent, exp.Dot): @@ -265,7 +304,6 @@ class BigQuery(Dialect): "DECLARE": TokenType.COMMAND, "FLOAT64": TokenType.DOUBLE, "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, - "INT64": TokenType.BIGINT, "MODEL": TokenType.MODEL, "NOT DETERMINISTIC": TokenType.VOLATILE, "RECORD": TokenType.STRUCT, @@ -316,6 +354,15 @@ class BigQuery(Dialect): "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub), "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd), "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub), + "TIMESTAMP_MICROS": lambda args: exp.UnixToTime( + this=seq_get(args, 0), scale=exp.UnixToTime.MICROS + ), + "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime( + this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS + ), + "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime( + this=seq_get(args, 0), scale=exp.UnixToTime.SECONDS + ), "TO_JSON_STRING": exp.JSONFormat.from_arg_list, } @@ -358,6 +405,24 @@ class BigQuery(Dialect): NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN} + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, + TokenType.END: lambda self: self._parse_as_command(self._prev), + TokenType.FOR: lambda self: self._parse_for_in(), + } + + BRACKET_OFFSETS = { + "OFFSET": (0, False), + "ORDINAL": (1, False), + "SAFE_OFFSET": (0, True), + "SAFE_ORDINAL": (1, True), + } + + def _parse_for_in(self) -> exp.ForIn: + this = self._parse_range() + self._match_text_seq("DO") + return self.expression(exp.ForIn, this=this, expression=self._parse_statement()) + def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: this = super()._parse_table_part(schema=schema) or self._parse_number() @@ -419,6 +484,26 @@ class BigQuery(Dialect): return json_object + def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + bracket = super()._parse_bracket(this) + + if this is bracket: + return bracket + + if isinstance(bracket, exp.Bracket): + for expression in bracket.expressions: + name = expression.name.upper() + + if name not in self.BRACKET_OFFSETS: + break + + offset, safe = self.BRACKET_OFFSETS[name] + bracket.set("offset", offset) + bracket.set("safe", safe) + expression.replace(expression.expressions[0]) + + return bracket + class Generator(generator.Generator): EXPLICIT_UNION = True INTERVAL_ALLOWS_PLURAL_FORM = False @@ -430,12 +515,14 @@ class BigQuery(Dialect): NVL2_SUPPORTED = False UNNEST_WITH_ORDINALITY = False COLLATE_IS_FUNC = True + LIMIT_ONLY_LITERALS = True TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArgMax: arg_max_or_min_no_count("MAX_BY"), exp.ArgMin: arg_max_or_min_no_count("MIN_BY"), + exp.ArrayContains: _array_contains_sql, exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}" @@ -498,10 +585,13 @@ class BigQuery(Dialect): exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"), exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, + exp.TimeToStr: lambda self, e: f"FORMAT_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), - exp.TsOrDsAdd: date_add_interval_sql("DATE", "ADD"), + exp.TsOrDsAdd: _ts_or_ds_add_sql, + exp.TsOrDsDiff: _ts_or_ds_diff_sql, exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), exp.Unhex: rename_func("FROM_HEX"), + exp.UnixToTime: _unix_to_time_sql, exp.Values: _derived_table_values_to_unnest, exp.VariancePop: rename_func("VAR_POP"), } @@ -671,6 +761,23 @@ class BigQuery(Dialect): return inline_array_sql(self, expression) + def bracket_sql(self, expression: exp.Bracket) -> str: + expressions = expression.expressions + expressions_sql = ", ".join(self.sql(e) for e in expressions) + offset = expression.args.get("offset") + + if offset == 0: + expressions_sql = f"OFFSET({expressions_sql})" + elif offset == 1: + expressions_sql = f"ORDINAL({expressions_sql})" + else: + self.unsupported(f"Unsupported array offset: {offset}") + + if expression.args.get("safe"): + expressions_sql = f"SAFE_{expressions_sql}" + + return f"{self.sql(expression, 'this')}[{expressions_sql}]" + def transaction_sql(self, *_) -> str: return "BEGIN TRANSACTION" diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 394a922..da182aa 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -35,8 +35,8 @@ def _quantile_sql(self, e): class ClickHouse(Dialect): NORMALIZE_FUNCTIONS: bool | str = False NULL_ORDERING = "nulls_are_last" - STRICT_STRING_CONCAT = True SUPPORTS_USER_DEFINED_TYPES = False + SAFE_DIVISION = True ESCAPE_SEQUENCES = { "\\0": "\0", @@ -63,11 +63,7 @@ class ClickHouse(Dialect): "FLOAT32": TokenType.FLOAT, "FLOAT64": TokenType.DOUBLE, "GLOBAL": TokenType.GLOBAL, - "INT16": TokenType.SMALLINT, "INT256": TokenType.INT256, - "INT32": TokenType.INT, - "INT64": TokenType.BIGINT, - "INT8": TokenType.TINYINT, "LOWCARDINALITY": TokenType.LOWCARDINALITY, "MAP": TokenType.MAP, "NESTED": TokenType.NESTED, @@ -112,6 +108,7 @@ class ClickHouse(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, + "ARRAYJOIN": lambda self: self.expression(exp.Explode, this=self._parse_expression()), "QUANTILE": lambda self: self._parse_quantile(), } @@ -223,12 +220,13 @@ class ClickHouse(Dialect): except ParseError: # WITH <expression> AS <identifier> self._retreat(index) - statement = self._parse_statement() - if statement and isinstance(statement.this, exp.Alias): - self.raise_error("Expected CTE to have alias") - - return self.expression(exp.CTE, this=statement, alias=statement and statement.this) + return self.expression( + exp.CTE, + this=self._parse_field(), + alias=self._parse_table_alias(), + scalar=True, + ) def _parse_join_parts( self, @@ -385,9 +383,11 @@ class ClickHouse(Dialect): exp.DateDiff: lambda self, e: self.func( "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this ), + exp.Explode: rename_func("arrayJoin"), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.IsNan: rename_func("isNaN"), exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), + exp.Nullif: rename_func("nullIf"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, @@ -459,19 +459,11 @@ class ClickHouse(Dialect): return super().datatype_sql(expression) - def safeconcat_sql(self, expression: exp.SafeConcat) -> str: - # Clickhouse errors out if we try to cast a NULL value to TEXT - return self.func( - "CONCAT", - *[ - exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text")) - for e in t.cast(t.List[exp.Condition], expression.expressions) - ], - ) - def cte_sql(self, expression: exp.CTE) -> str: - if isinstance(expression.this, exp.Alias): - return self.sql(expression, "this") + if expression.args.get("scalar"): + this = self.sql(expression, "this") + alias = self.sql(expression, "alias") + return f"{this} AS {alias}" return super().cte_sql(expression) diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index b777db0..1c10a8b 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -1,13 +1,18 @@ from __future__ import annotations from sqlglot import exp, transforms -from sqlglot.dialects.dialect import parse_date_delta, timestamptrunc_sql +from sqlglot.dialects.dialect import ( + date_delta_sql, + parse_date_delta, + timestamptrunc_sql, +) from sqlglot.dialects.spark import Spark -from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql from sqlglot.tokens import TokenType class Databricks(Spark): + SAFE_DIVISION = False + class Parser(Spark.Parser): LOG_DEFAULTS_TO_LN = True STRICT_CAST = True @@ -27,8 +32,8 @@ class Databricks(Spark): class Generator(Spark.Generator): TRANSFORMS = { **Spark.Generator.TRANSFORMS, - exp.DateAdd: generate_date_delta_with_unit_sql, - exp.DateDiff: generate_date_delta_with_unit_sql, + exp.DateAdd: date_delta_sql("DATEADD"), + exp.DateDiff: date_delta_sql("DATEDIFF"), exp.DatetimeAdd: lambda self, e: self.func( "TIMESTAMPADD", e.text("unit"), e.expression, e.this ), diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 21e7889..c7cea64 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -1,14 +1,14 @@ from __future__ import annotations import typing as t -from enum import Enum +from enum import Enum, auto from functools import reduce from sqlglot import exp from sqlglot._typing import E from sqlglot.errors import ParseError from sqlglot.generator import Generator -from sqlglot.helper import flatten, seq_get +from sqlglot.helper import AutoName, flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import TIMEZONES, format_time from sqlglot.tokens import Token, Tokenizer, TokenType @@ -16,6 +16,9 @@ from sqlglot.trie import new_trie B = t.TypeVar("B", bound=exp.Binary) +DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] +DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] + class Dialects(str, Enum): DIALECT = "" @@ -43,6 +46,15 @@ class Dialects(str, Enum): Doris = "doris" +class NormalizationStrategy(str, AutoName): + """Specifies the strategy according to which identifiers should be normalized.""" + + LOWERCASE = auto() # Unquoted identifiers are lowercased + UPPERCASE = auto() # Unquoted identifiers are uppercased + CASE_SENSITIVE = auto() # Always case-sensitive, regardless of quotes + CASE_INSENSITIVE = auto() # Always case-insensitive, regardless of quotes + + class _Dialect(type): classes: t.Dict[str, t.Type[Dialect]] = {} @@ -106,26 +118,8 @@ class _Dialect(type): klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) - dialect_properties = { - **{ - k: v - for k, v in vars(klass).items() - if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__") - }, - "TOKENIZER_CLASS": klass.tokenizer_class, - } - if enum not in ("", "bigquery"): - dialect_properties["SELECT_KINDS"] = () - - # Pass required dialect properties to the tokenizer, parser and generator classes - for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class): - for name, value in dialect_properties.items(): - if hasattr(subclass, name): - setattr(subclass, name, value) - - if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT: - klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe + klass.generator_class.SELECT_KINDS = () if not klass.SUPPORTS_SEMI_ANTI_JOIN: klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { @@ -133,8 +127,6 @@ class _Dialect(type): TokenType.SEMI, } - klass.generator_class.can_identify = klass.can_identify - return klass @@ -148,9 +140,8 @@ class Dialect(metaclass=_Dialect): # Determines whether or not the table alias comes after tablesample ALIAS_POST_TABLESAMPLE = False - # Determines whether or not unquoted identifiers are resolved as uppercase - # When set to None, it means that the dialect treats all identifiers as case-insensitive - RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False + # Specifies the strategy according to which identifiers should be normalized. + NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE # Determines whether or not an unquoted identifier can start with a digit IDENTIFIERS_CAN_START_WITH_DIGIT = False @@ -177,6 +168,18 @@ class Dialect(metaclass=_Dialect): # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" NULL_ORDERING = "nulls_are_small" + # Whether the behavior of a / b depends on the types of a and b. + # False means a / b is always float division. + # True means a / b is integer division if both a and b are integers. + TYPED_DIVISION = False + + # False means 1 / 0 throws an error. + # True means 1 / 0 returns null. + SAFE_DIVISION = False + + # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string + CONCAT_COALESCE = False + DATE_FORMAT = "'%Y-%m-%d'" DATEINT_FORMAT = "'%Y%m%d'" TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" @@ -197,7 +200,8 @@ class Dialect(metaclass=_Dialect): # Such columns may be excluded from SELECT * queries, for example PSEUDOCOLUMNS: t.Set[str] = set() - # Autofilled + # --- Autofilled --- + tokenizer_class = Tokenizer parser_class = Parser generator_class = Generator @@ -211,26 +215,61 @@ class Dialect(metaclass=_Dialect): INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} - def __eq__(self, other: t.Any) -> bool: - return type(self) == other + # Delimiters for quotes, identifiers and the corresponding escape characters + QUOTE_START = "'" + QUOTE_END = "'" + IDENTIFIER_START = '"' + IDENTIFIER_END = '"' - def __hash__(self) -> int: - return hash(type(self)) + # Delimiters for bit, hex and byte literals + BIT_START: t.Optional[str] = None + BIT_END: t.Optional[str] = None + HEX_START: t.Optional[str] = None + HEX_END: t.Optional[str] = None + BYTE_START: t.Optional[str] = None + BYTE_END: t.Optional[str] = None @classmethod - def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: + def get_or_raise(cls, dialect: DialectType) -> Dialect: + """ + Look up a dialect in the global dialect registry and return it if it exists. + + Args: + dialect: The target dialect. If this is a string, it can be optionally followed by + additional key-value pairs that are separated by commas and are used to specify + dialect settings, such as whether the dialect's identifiers are case-sensitive. + + Example: + >>> dialect = dialect_class = get_or_raise("duckdb") + >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") + + Returns: + The corresponding Dialect instance. + """ + if not dialect: - return cls + return cls() if isinstance(dialect, _Dialect): - return dialect + return dialect() if isinstance(dialect, Dialect): - return dialect.__class__ + return dialect + if isinstance(dialect, str): + try: + dialect_name, *kv_pairs = dialect.split(",") + kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} + except ValueError: + raise ValueError( + f"Invalid dialect format: '{dialect}'. " + "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." + ) + + result = cls.get(dialect_name.strip()) + if not result: + raise ValueError(f"Unknown dialect '{dialect_name}'.") - result = cls.get(dialect) - if not result: - raise ValueError(f"Unknown dialect '{dialect}'") + return result(**kwargs) - return result + raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") @classmethod def format_time( @@ -247,36 +286,71 @@ class Dialect(metaclass=_Dialect): return expression - @classmethod - def normalize_identifier(cls, expression: E) -> E: + def __init__(self, **kwargs) -> None: + normalization_strategy = kwargs.get("normalization_strategy") + + if normalization_strategy is None: + self.normalization_strategy = self.NORMALIZATION_STRATEGY + else: + self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) + + def __eq__(self, other: t.Any) -> bool: + # Does not currently take dialect state into account + return type(self) == other + + def __hash__(self) -> int: + # Does not currently take dialect state into account + return hash(type(self)) + + def normalize_identifier(self, expression: E) -> E: """ - Normalizes an unquoted identifier to either lower or upper case, thus essentially - making it case-insensitive. If a dialect treats all identifiers as case-insensitive, - they will be normalized to lowercase regardless of being quoted or not. + Transforms an identifier in a way that resembles how it'd be resolved by this dialect. + + For example, an identifier like FoO would be resolved as foo in Postgres, because it + lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so + it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, + and so any normalization would be prohibited in order to avoid "breaking" the identifier. + + There are also dialects like Spark, which are case-insensitive even when quotes are + present, and dialects like MySQL, whose resolution rules match those employed by the + underlying operating system, for example they may always be case-sensitive in Linux. + + Finally, the normalization behavior of some engines can even be controlled through flags, + like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. + + SQLGlot aims to understand and handle all of these different behaviors gracefully, so + that it can analyze queries in the optimizer and successfully capture their semantics. """ - if isinstance(expression, exp.Identifier) and ( - not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None + if ( + isinstance(expression, exp.Identifier) + and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE + and ( + not expression.quoted + or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE + ) ): expression.set( "this", expression.this.upper() - if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE + if self.normalization_strategy is NormalizationStrategy.UPPERCASE else expression.this.lower(), ) return expression - @classmethod - def case_sensitive(cls, text: str) -> bool: + def case_sensitive(self, text: str) -> bool: """Checks if text contains any case sensitive characters, based on the dialect's rules.""" - if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: + if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: return False - unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper + unsafe = ( + str.islower + if self.normalization_strategy is NormalizationStrategy.UPPERCASE + else str.isupper + ) return any(unsafe(char) for char in text) - @classmethod - def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: + def can_identify(self, text: str, identify: str | bool = "safe") -> bool: """Checks if text can be identified given an identify option. Args: @@ -292,17 +366,16 @@ class Dialect(metaclass=_Dialect): return True if identify == "safe": - return not cls.case_sensitive(text) + return not self.case_sensitive(text) return False - @classmethod - def quote_identifier(cls, expression: E, identify: bool = True) -> E: + def quote_identifier(self, expression: E, identify: bool = True) -> E: if isinstance(expression, exp.Identifier): name = expression.this expression.set( "quoted", - identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), + identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), ) return expression @@ -330,14 +403,14 @@ class Dialect(metaclass=_Dialect): @property def tokenizer(self) -> Tokenizer: if not hasattr(self, "_tokenizer"): - self._tokenizer = self.tokenizer_class() + self._tokenizer = self.tokenizer_class(dialect=self) return self._tokenizer def parser(self, **opts) -> Parser: - return self.parser_class(**opts) + return self.parser_class(dialect=self, **opts) def generator(self, **opts) -> Generator: - return self.generator_class(**opts) + return self.generator_class(dialect=self, **opts) DialectType = t.Union[str, Dialect, t.Type[Dialect], None] @@ -713,7 +786,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: return _ts_or_ds_to_date_sql -def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: +def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) @@ -821,3 +894,28 @@ def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | ex return self.func(name, expression.this, expression.expression) return _arg_max_or_min_sql + + +def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: + this = expression.this.copy() + + return_type = expression.return_type + if return_type.is_type(exp.DataType.Type.DATE): + # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we + # can truncate timestamp strings, because some dialects can't cast them to DATE + this = exp.cast(this, exp.DataType.Type.TIMESTAMP) + + expression.this.replace(exp.cast(this, return_type)) + return expression + + +def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: + def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: + if cast and isinstance(expression, exp.TsOrDsAdd): + expression = ts_or_ds_add_cast(expression) + + return self.func( + name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this + ) + + return _delta_sql diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py index bd7e0f2..11af17b 100644 --- a/sqlglot/dialects/doris.py +++ b/sqlglot/dialects/doris.py @@ -19,6 +19,7 @@ class Doris(MySQL): class Parser(MySQL.Parser): FUNCTIONS = { **MySQL.Parser.FUNCTIONS, + "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list, "DATE_TRUNC": parse_timestamp_trunc, "REGEXP": exp.RegexpLike.from_arg_list, } @@ -47,7 +48,7 @@ class Doris(MySQL): exp.JSONExtract: arrow_json_extract_sql, exp.RegexpLike: rename_func("REGEXP"), exp.RegexpSplit: rename_func("SPLIT_BY_STRING"), - exp.SetAgg: rename_func("COLLECT_SET"), + exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Split: rename_func("SPLIT_BY_STRING"), exp.TimeStrToDate: rename_func("TO_DATE"), diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 42453fd..70c96f8 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -43,6 +43,8 @@ class Drill(Dialect): TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'" SUPPORTS_USER_DEFINED_TYPES = False SUPPORTS_SEMI_ANTI_JOIN = False + TYPED_DIVISION = True + CONCAT_COALESCE = True TIME_MAPPING = { "y": "%Y", @@ -83,7 +85,6 @@ class Drill(Dialect): class Parser(parser.Parser): STRICT_CAST = False - CONCAT_NULL_OUTPUTS_STRING = True FUNCTIONS = { **parser.Parser.FUNCTIONS, diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index d8d9f90..b94e3a6 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -2,9 +2,10 @@ from __future__ import annotations import typing as t -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, approx_count_distinct_sql, arg_max_or_min_no_count, arrow_json_extract_scalar_sql, @@ -36,7 +37,8 @@ from sqlglot.tokens import TokenType def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" - return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" + interval = self.sql(exp.Interval(this=expression.expression, unit=unit)) + return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}" def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str: @@ -84,7 +86,8 @@ def _parse_date_diff(args: t.List) -> exp.Expression: def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str: args = [ - f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions + f"'{e.name or e.this.name}': {self.sql(e.expressions[0]) if isinstance(e, exp.Bracket) else self.sql(e, 'expression')}" + for e in expression.expressions ] return f"{{{', '.join(args)}}}" @@ -105,17 +108,35 @@ def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str: return f"CAST({sql} AS TEXT)" +def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str: + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale in (None, exp.UnixToTime.SECONDS): + return f"TO_TIMESTAMP({timestamp})" + if scale == exp.UnixToTime.MILLIS: + return f"EPOCH_MS({timestamp})" + if scale == exp.UnixToTime.MICROS: + return f"MAKE_TIMESTAMP({timestamp})" + if scale == exp.UnixToTime.NANOS: + return f"TO_TIMESTAMP({timestamp} / 1000000000)" + + self.unsupported(f"Unsupported scale for timestamp: {scale}.") + return "" + + class DuckDB(Dialect): NULL_ORDERING = "nulls_are_last" SUPPORTS_USER_DEFINED_TYPES = False + SAFE_DIVISION = True + INDEX_OFFSET = 1 + CONCAT_COALESCE = True # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - ":=": TokenType.EQ, "//": TokenType.DIV, "ATTACH": TokenType.COMMAND, "BINARY": TokenType.VARBINARY, @@ -124,8 +145,6 @@ class DuckDB(Dialect): "CHAR": TokenType.TEXT, "CHARACTER VARYING": TokenType.TEXT, "EXCLUDE": TokenType.EXCEPT, - "HUGEINT": TokenType.INT128, - "INT1": TokenType.TINYINT, "LOGICAL": TokenType.BOOLEAN, "PIVOT_WIDER": TokenType.PIVOT, "SIGNED": TokenType.INT, @@ -141,8 +160,6 @@ class DuckDB(Dialect): } class Parser(parser.Parser): - CONCAT_NULL_OUTPUTS_STRING = True - BITWISE = { **parser.Parser.BITWISE, TokenType.TILDA: exp.RegexpLike, @@ -150,6 +167,7 @@ class DuckDB(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "ARRAY_HAS": exp.ArrayContains.from_arg_list, "ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list, "ARRAY_REVERSE_SORT": _sort_array_reverse, @@ -157,13 +175,23 @@ class DuckDB(Dialect): "DATE_DIFF": _parse_date_diff, "DATE_TRUNC": date_trunc_to_time, "DATETRUNC": date_trunc_to_time, + "DECODE": lambda args: exp.Decode( + this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + ), + "ENCODE": lambda args: exp.Encode( + this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + ), "EPOCH": exp.TimeToUnix.from_arg_list, "EPOCH_MS": lambda args: exp.UnixToTime( - this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000)) + this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS ), + "LIST_HAS": exp.ArrayContains.from_arg_list, "LIST_REVERSE_SORT": _sort_array_reverse, "LIST_SORT": exp.SortArray.from_arg_list, "LIST_VALUE": exp.Array.from_arg_list, + "MAKE_TIMESTAMP": lambda args: exp.UnixToTime( + this=seq_get(args, 0), scale=exp.UnixToTime.MICROS + ), "MEDIAN": lambda args: exp.PercentileCont( this=seq_get(args, 0), expression=exp.Literal.number(0.5) ), @@ -192,15 +220,8 @@ class DuckDB(Dialect): "XOR": binary_from_function(exp.BitwiseXor), } - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, - "DECODE": lambda self: self.expression( - exp.Decode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8") - ), - "ENCODE": lambda self: self.expression( - exp.Encode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8") - ), - } + FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() + FUNCTION_PARSERS.pop("DECODE", None) TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { TokenType.SEMI, @@ -277,6 +298,7 @@ class DuckDB(Dialect): exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False), exp.Explode: rename_func("UNNEST"), exp.IntDiv: lambda self, e: self.binary(e, "//"), + exp.IsInf: rename_func("ISINF"), exp.IsNan: rename_func("ISNAN"), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, @@ -294,6 +316,9 @@ class DuckDB(Dialect): exp.ParseJSON: rename_func("JSON"), exp.PercentileCont: rename_func("QUANTILE_CONT"), exp.PercentileDisc: rename_func("QUANTILE_DISC"), + # DuckDB doesn't allow qualified columns inside of PIVOT expressions. + # See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62 + exp.Pivot: transforms.preprocess([transforms.unqualify_columns]), exp.Properties: no_properties_sql, exp.RegexpExtract: regexp_extract_sql, exp.RegexpReplace: lambda self, e: self.func( @@ -322,9 +347,15 @@ class DuckDB(Dialect): exp.TimeToUnix: rename_func("EPOCH"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, + exp.TsOrDsDiff: lambda self, e: self.func( + "DATE_DIFF", + f"'{e.args.get('unit') or 'day'}'", + exp.cast(e.expression, "TIMESTAMP"), + exp.cast(e.this, "TIMESTAMP"), + ), exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"), exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", - exp.UnixToTime: rename_func("TO_TIMESTAMP"), + exp.UnixToTime: _unix_to_time_sql, exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", exp.VariancePop: rename_func("VAR_POP"), exp.WeekOfYear: rename_func("WEEKOFYEAR"), diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 3b1c8de..0723e37 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -4,10 +4,13 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( + DATE_ADD_OR_SUB, Dialect, + NormalizationStrategy, approx_count_distinct_sql, arg_max_or_min_no_count, create_with_partitions_sql, + datestrtodate_sql, format_time_lambda, if_sql, is_parse_json, @@ -76,7 +79,10 @@ def _create_sql(self, expression: exp.Create) -> str: return create_with_partitions_sql(self, expression) -def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str: + if isinstance(expression, exp.TsOrDsAdd) and not expression.unit: + return self.func("DATE_ADD", expression.this, expression.expression) + unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) @@ -95,7 +101,7 @@ def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) - return self.func(func, expression.this, modified_increment) -def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str: +def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff | exp.TsOrDsDiff) -> str: unit = expression.text("unit").upper() factor = TIME_DIFF_FACTOR.get(unit) @@ -111,25 +117,31 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str: multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})" - if months_between: - # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part - diff_sql = f"CAST({diff_sql} AS INT)" + if months_between or multiplier_sql: + # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part. + # For the same reason, we want to truncate if there's a divisor present. + diff_sql = f"CAST({diff_sql}{multiplier_sql} AS INT)" - return f"{diff_sql}{multiplier_sql}" + return diff_sql def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str: this = expression.this - if is_parse_json(this) and this.this.is_string: - # Since FROM_JSON requires a nested type, we always wrap the json string with - # an array to ensure that "naked" strings like "'a'" will be handled correctly - wrapped_json = exp.Literal.string(f"[{this.this.name}]") - from_json = self.func("FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json)) - to_json = self.func("TO_JSON", from_json) + if is_parse_json(this): + if this.this.is_string: + # Since FROM_JSON requires a nested type, we always wrap the json string with + # an array to ensure that "naked" strings like "'a'" will be handled correctly + wrapped_json = exp.Literal.string(f"[{this.this.name}]") + + from_json = self.func( + "FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json) + ) + to_json = self.func("TO_JSON", from_json) - # This strips the [, ] delimiters of the dummy array printed by TO_JSON - return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1") + # This strips the [, ] delimiters of the dummy array printed by TO_JSON + return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1") + return self.sql(this) return self.func("TO_JSON", this, expression.args.get("options")) @@ -175,6 +187,8 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): return f"TO_DATE({this}, {time_format})" + if isinstance(expression.this, exp.TsOrDsToDate): + return this return f"TO_DATE({this})" @@ -182,9 +196,10 @@ class Hive(Dialect): ALIAS_POST_TABLESAMPLE = True IDENTIFIERS_CAN_START_WITH_DIGIT = True SUPPORTS_USER_DEFINED_TYPES = False + SAFE_DIVISION = True # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE TIME_MAPPING = { "y": "%Y", @@ -241,10 +256,10 @@ class Hive(Dialect): "ADD JAR": TokenType.COMMAND, "ADD JARS": TokenType.COMMAND, "MSCK REPAIR": TokenType.COMMAND, - "REFRESH": TokenType.COMMAND, - "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, + "REFRESH": TokenType.REFRESH, "TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT, "VERSION AS OF": TokenType.VERSION_SNAPSHOT, + "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, } NUMERIC_LITERALS = { @@ -264,7 +279,7 @@ class Hive(Dialect): **parser.Parser.FUNCTIONS, "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, - "COLLECT_SET": exp.SetAgg.from_arg_list, + "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), @@ -411,7 +426,13 @@ class Hive(Dialect): INDEX_ON = "ON TABLE" EXTRACT_ALLOWS_QUOTES = False NVL2_SUPPORTED = False - SUPPORTS_NESTED_CTES = False + + EXPRESSIONS_WITHOUT_NESTED_CTES = { + exp.Insert, + exp.Select, + exp.Subquery, + exp.Union, + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -445,7 +466,7 @@ class Hive(Dialect): exp.With: no_recursive_cte_sql, exp.DateAdd: _add_date_sql, exp.DateDiff: _date_diff_sql, - exp.DateStrToDate: rename_func("TO_DATE"), + exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _add_date_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})", @@ -477,7 +498,7 @@ class Hive(Dialect): exp.Right: right_to_substring_sql, exp.SafeDivide: no_safe_divide_sql, exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), - exp.SetAgg: rename_func("COLLECT_SET"), + exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, @@ -491,7 +512,8 @@ class Hive(Dialect): exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToBase64: rename_func("BASE64"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", - exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.TsOrDsAdd: _add_date_sql, + exp.TsOrDsDiff: _date_diff_sql, exp.TsOrDsToDate: _to_date_sql, exp.TryCast: no_trycast_sql, exp.UnixToStr: lambda self, e: self.func( @@ -571,6 +593,8 @@ class Hive(Dialect): and not expression.expressions ): expression = exp.DataType.build("text") + elif expression.is_type(exp.DataType.Type.TEXT) and expression.expressions: + expression.set("this", exp.DataType.Type.VARCHAR) elif expression.this in exp.DataType.TEMPORAL_TYPES: expression = exp.DataType.build(expression.this) elif expression.is_type("float"): diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index c78aa9e..cfc6e83 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, arrow_json_extract_scalar_sql, date_add_interval_sql, datestrtodate_sql, @@ -150,10 +151,18 @@ class MySQL(Dialect): # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html IDENTIFIERS_CAN_START_WITH_DIGIT = True + # We default to treating all identifiers as case-sensitive, since it matches MySQL's + # behavior on Linux systems. For MacOS and Windows systems, one can override this + # setting by specifying `dialect="mysql, normalization_strategy = lowercase"`. + # + # See also https://dev.mysql.com/doc/refman/8.2/en/identifier-case-sensitivity.html + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_SENSITIVE + TIME_FORMAT = "'%Y-%m-%d %T'" DPIPE_IS_STRING_CONCAT = False SUPPORTS_USER_DEFINED_TYPES = False SUPPORTS_SEMI_ANTI_JOIN = False + SAFE_DIVISION = True # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions TIME_MAPPING = { @@ -264,11 +273,6 @@ class MySQL(Dialect): TokenType.DPIPE: exp.Or, } - # MySQL uses || as a synonym to the logical OR operator - # https://dev.mysql.com/doc/refman/8.0/en/logical-operators.html#operator_or - BITWISE = parser.Parser.BITWISE.copy() - BITWISE.pop(TokenType.DPIPE) - TABLE_ALIAS_TOKENS = ( parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS ) @@ -451,7 +455,7 @@ class MySQL(Dialect): self, kind: t.Optional[str] = None ) -> exp.IndexColumnConstraint: if kind: - self._match_texts({"INDEX", "KEY"}) + self._match_texts(("INDEX", "KEY")) this = self._parse_id_var(any_token=False) index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text @@ -514,7 +518,7 @@ class MySQL(Dialect): log = self._parse_string() if self._match_text_seq("IN") else None - if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}: + if this in ("BINLOG EVENTS", "RELAYLOG EVENTS"): position = self._parse_number() if self._match_text_seq("FROM") else None db = None else: @@ -671,6 +675,7 @@ class MySQL(Dialect): exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, exp.TsOrDsAdd: _date_add_sql("ADD"), + exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), exp.TsOrDsToDate: _ts_or_ds_to_date_sql, exp.Week: _remove_ts_or_ds_to_date(), exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")), @@ -763,7 +768,7 @@ class MySQL(Dialect): target = self.sql(expression, "target") target = f" {target}" if target else "" - if expression.name in {"COLUMNS", "INDEX"}: + if expression.name in ("COLUMNS", "INDEX"): target = f" FROM{target}" elif expression.name == "GRANTS": target = f" FOR{target}" @@ -796,6 +801,14 @@ class MySQL(Dialect): return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}" + def altercolumn_sql(self, expression: exp.AlterColumn) -> str: + dtype = self.sql(expression, "dtype") + if not dtype: + return super().altercolumn_sql(expression) + + this = self.sql(expression, "this") + return f"MODIFY COLUMN {this} {dtype}" + def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str: sql = self.sql(expression, arg) return f" {prefix} {sql}" if sql else "" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 6bdd8d6..51dbd53 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -3,7 +3,14 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql +from sqlglot.dialects.dialect import ( + Dialect, + NormalizationStrategy, + format_time_lambda, + no_ilike_sql, + rename_func, + trim_sql, +) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -30,12 +37,25 @@ def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable: return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref) +def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar: + this = seq_get(args, 0) + + if this and not this.type: + from sqlglot.optimizer.annotate_types import annotate_types + + annotate_types(this) + if this.is_type(*exp.DataType.TEMPORAL_TYPES): + return format_time_lambda(exp.TimeToStr, "oracle", default=True)(args) + + return exp.ToChar.from_arg_list(args) + + class Oracle(Dialect): ALIAS_POST_TABLESAMPLE = True LOCKING_READS_SUPPORTED = True # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm - RESOLVES_IDENTIFIERS_AS_UPPERCASE = True + NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes @@ -64,11 +84,13 @@ class Oracle(Dialect): } class Parser(parser.Parser): + ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP} FUNCTIONS = { **parser.Parser.FUNCTIONS, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), + "TO_CHAR": to_char, } FUNCTION_PARSERS: t.Dict[str, t.Callable] = { @@ -130,6 +152,7 @@ class Oracle(Dialect): TABLE_HINTS = False COLUMN_JOIN_MARKS_SUPPORTED = True DATA_TYPE_SPECIFIERS_ALLOWED = True + ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False LIMIT_FETCH = "FETCH" @@ -192,6 +215,12 @@ class Oracle(Dialect): ) return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}" + def add_column_sql(self, expression: exp.AlterTable) -> str: + actions = self.expressions(expression, key="actions", flat=True) + if len(expression.args.get("actions", [])) > 1: + return f"ADD ({actions})" + return f"ADD {actions}" + class Tokenizer(tokens.Tokenizer): VAR_SINGLE_TOKENS = {"@", "$", "#"} diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 27c6851..fefddee 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -4,6 +4,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( + DATE_ADD_OR_SUB, Dialect, any_value_to_max_sql, arrow_json_extract_scalar_sql, @@ -25,6 +26,7 @@ from sqlglot.dialects.dialect import ( timestamptrunc_sql, timestrtotime_sql, trim_sql, + ts_or_ds_add_cast, ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get @@ -41,8 +43,11 @@ DATE_DIFF_FACTOR = { } -def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, DATE_ADD_OR_SUB], str]: + def func(self: Postgres.Generator, expression: DATE_ADD_OR_SUB) -> str: + if isinstance(expression, exp.TsOrDsAdd): + expression = ts_or_ds_add_cast(expression) + this = self.sql(expression, "this") unit = expression.args.get("unit") @@ -60,8 +65,8 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() factor = DATE_DIFF_FACTOR.get(unit) - end = f"CAST({expression.this} AS TIMESTAMP)" - start = f"CAST({expression.expression} AS TIMESTAMP)" + end = f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" + start = f"CAST({self.sql(expression, 'expression')} AS TIMESTAMP)" if factor is not None: return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)" @@ -69,7 +74,7 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str: age = f"AGE({end}, {start})" if unit == "WEEK": - unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7" + unit = f"EXTRACT(days FROM ({end} - {start})) / 7" elif unit == "MONTH": unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" elif unit == "QUARTER": @@ -183,37 +188,43 @@ def _to_timestamp(args: t.List) -> exp.Expression: return format_time_lambda(exp.StrToTime, "postgres")(args) -def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression: - """Remove table refs from columns in when statements.""" - if isinstance(expression, exp.Merge): - alias = expression.this.args.get("alias") +def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str: + def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression: + """Remove table refs from columns in when statements.""" + if isinstance(expression, exp.Merge): + alias = expression.this.args.get("alias") - normalize = ( - lambda identifier: Postgres.normalize_identifier(identifier).name - if identifier - else None - ) + normalize = ( + lambda identifier: self.dialect.normalize_identifier(identifier).name + if identifier + else None + ) - targets = {normalize(expression.this.this)} + targets = {normalize(expression.this.this)} - if alias: - targets.add(normalize(alias.this)) + if alias: + targets.add(normalize(alias.this)) - for when in expression.expressions: - when.transform( - lambda node: exp.column(node.this) - if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets - else node, - copy=False, - ) + for when in expression.expressions: + when.transform( + lambda node: exp.column(node.this) + if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets + else node, + copy=False, + ) - return expression + return expression + + return transforms.preprocess([_remove_target_from_merge])(self, expression) class Postgres(Dialect): INDEX_OFFSET = 1 + TYPED_DIVISION = True + CONCAT_COALESCE = True NULL_ORDERING = "nulls_are_large" TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" + TIME_MAPPING = { "AM": "%p", "PM": "%p", @@ -263,6 +274,7 @@ class Postgres(Dialect): "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, "CHARACTER VARYING": TokenType.VARCHAR, + "CONSTRAINT TRIGGER": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, @@ -277,6 +289,7 @@ class Postgres(Dialect): "TEMP": TokenType.TEMPORARY, "CSTRING": TokenType.PSEUDO_TYPE, "OID": TokenType.OBJECT_IDENTIFIER, + "OPERATOR": TokenType.OPERATOR, "REGCLASS": TokenType.OBJECT_IDENTIFIER, "REGCOLLATION": TokenType.OBJECT_IDENTIFIER, "REGCONFIG": TokenType.OBJECT_IDENTIFIER, @@ -298,8 +311,6 @@ class Postgres(Dialect): VAR_SINGLE_TOKENS = {"$"} class Parser(parser.Parser): - CONCAT_NULL_OUTPUTS_STRING = True - FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE_TRUNC": parse_timestamp_trunc, @@ -326,12 +337,13 @@ class Postgres(Dialect): RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, + TokenType.AT_GT: binary_range_parser(exp.ArrayContains), TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), TokenType.DAT: lambda self, this: self.expression( exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this] ), - TokenType.AT_GT: binary_range_parser(exp.ArrayContains), TokenType.LT_AT: binary_range_parser(exp.ArrayContained), + TokenType.OPERATOR: lambda self, this: self._parse_operator(this), } STATEMENT_PARSERS = { @@ -339,11 +351,28 @@ class Postgres(Dialect): TokenType.END: lambda self: self._parse_commit_or_rollback(), } - def _parse_factor(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_exponent, self.FACTOR) + def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + while True: + if not self._match(TokenType.L_PAREN): + break + + op = "" + while self._curr and not self._match(TokenType.R_PAREN): + op += self._curr.text + self._advance() + + this = self.expression( + exp.Operator, + comments=self._prev_comments, + this=this, + operator=op, + expression=self._parse_bitwise(), + ) + + if not self._match(TokenType.OPERATOR): + break - def _parse_exponent(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_unary, self.EXPONENT) + return this def _parse_date_part(self) -> exp.Expression: part = self._parse_type() @@ -405,7 +434,7 @@ class Postgres(Dialect): exp.Max: max_or_greatest, exp.MapFromEntries: no_map_from_entries_sql, exp.Min: min_or_least, - exp.Merge: transforms.preprocess([_remove_target_from_merge]), + exp.Merge: _merge_sql, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.PercentileCont: transforms.preprocess( [transforms.add_within_group_for_percentiles] @@ -434,6 +463,8 @@ class Postgres(Dialect): exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.TryCast: no_trycast_sql, + exp.TsOrDsAdd: _date_add_sql("+"), + exp.TsOrDsDiff: _date_diff_sql, exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"), exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.VariancePop: rename_func("VAR_POP"), diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index ded3655..10a6074 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -5,9 +5,11 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, binary_from_function, bool_xor_sql, date_trunc_to_time, + datestrtodate_sql, encode_decode_sql, format_time_lambda, if_sql, @@ -22,6 +24,7 @@ from sqlglot.dialects.dialect import ( struct_extract_sql, timestamptrunc_sql, timestrtotime_sql, + ts_or_ds_add_cast, ) from sqlglot.dialects.mysql import MySQL from sqlglot.helper import apply_index_offset, seq_get @@ -95,17 +98,16 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: - this = expression.this + expression = ts_or_ds_add_cast(expression) + unit = exp.Literal.string(expression.text("unit") or "day") + return self.func("DATE_ADD", unit, expression.expression, expression.this) - if not isinstance(this, exp.CurrentDate): - this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE") - return self.func( - "DATE_ADD", - exp.Literal.string(expression.text("unit") or "day"), - expression.expression, - this, - ) +def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str: + this = exp.cast(expression.this, "TIMESTAMP") + expr = exp.cast(expression.expression, "TIMESTAMP") + unit = exp.Literal.string(expression.text("unit") or "day") + return self.func("DATE_DIFF", unit, expr, this) def _approx_percentile(args: t.List) -> exp.Expression: @@ -136,11 +138,11 @@ def _from_unixtime(args: t.List) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) -def _parse_element_at(args: t.List) -> exp.SafeBracket: +def _parse_element_at(args: t.List) -> exp.Bracket: this = seq_get(args, 0) index = seq_get(args, 1) assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression) - return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1)) + return exp.Bracket(this=this, expressions=[index], offset=1, safe=True) def _unnest_sequence(expression: exp.Expression) -> exp.Expression: @@ -168,6 +170,22 @@ def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> return rename_func("ARBITRARY")(self, expression) +def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str: + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale in (None, exp.UnixToTime.SECONDS): + return rename_func("FROM_UNIXTIME")(self, expression) + if scale == exp.UnixToTime.MILLIS: + return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)" + if scale == exp.UnixToTime.MICROS: + return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)" + if scale == exp.UnixToTime.NANOS: + return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)" + + self.unsupported(f"Unsupported scale for timestamp: {scale}.") + return "" + + class Presto(Dialect): INDEX_OFFSET = 1 NULL_ORDERING = "nulls_are_last" @@ -175,11 +193,12 @@ class Presto(Dialect): TIME_MAPPING = MySQL.TIME_MAPPING STRICT_STRING_CONCAT = True SUPPORTS_SEMI_ANTI_JOIN = False + TYPED_DIVISION = True # https://github.com/trinodb/trino/issues/17 # https://github.com/trinodb/trino/issues/12289 # https://github.com/prestodb/presto/issues/2863 - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE class Tokenizer(tokens.Tokenizer): KEYWORDS = { @@ -229,6 +248,7 @@ class Presto(Dialect): ), "ROW": exp.Struct.from_arg_list, "SEQUENCE": exp.GenerateSeries.from_arg_list, + "SET_AGG": exp.ArrayUniqueAgg.from_arg_list, "SPLIT_TO_MAP": exp.StrToMap.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) @@ -253,6 +273,7 @@ class Presto(Dialect): NVL2_SUPPORTED = False STRUCT_DELIMITER = ("(", ")") LIMIT_ONLY_LITERALS = True + SUPPORTS_SINGLE_ARG_CONCAT = False PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, @@ -284,6 +305,7 @@ class Presto(Dialect): exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayContains: rename_func("CONTAINS"), exp.ArraySize: rename_func("CARDINALITY"), + exp.ArrayUniqueAgg: rename_func("SET_AGG"), exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})", @@ -298,7 +320,7 @@ class Presto(Dialect): exp.DateDiff: lambda self, e: self.func( "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this ), - exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)", + exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", exp.DateSub: lambda self, e: self.func( "DATE_ADD", @@ -330,9 +352,6 @@ class Presto(Dialect): exp.Quantile: _quantile_sql, exp.RegexpExtract: regexp_extract_sql, exp.Right: right_to_substring_sql, - exp.SafeBracket: lambda self, e: self.func( - "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0) - ), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( @@ -361,10 +380,11 @@ class Presto(Dialect): exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, + exp.TsOrDsDiff: _ts_or_ds_diff_sql, exp.TsOrDsToDate: _ts_or_ds_to_date_sql, exp.Unhex: rename_func("FROM_HEX"), exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", - exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.UnixToTime: _unix_to_time_sql, exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", exp.VariancePop: rename_func("VAR_POP"), exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]), @@ -374,8 +394,24 @@ class Presto(Dialect): exp.Xor: bool_xor_sql, } + def bracket_sql(self, expression: exp.Bracket) -> str: + if expression.args.get("safe"): + return self.func( + "ELEMENT_AT", + expression.this, + seq_get( + apply_index_offset( + expression.this, + expression.expressions, + 1 - expression.args.get("offset", 0), + ), + 0, + ), + ) + return super().bracket_sql(expression) + def struct_sql(self, expression: exp.Struct) -> str: - if any(isinstance(arg, (exp.EQ, exp.Slice)) for arg in expression.expressions): + if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions): self.unsupported("Struct with key-value definitions is unsupported.") return self.function_fallback_sql(expression) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 6c7ba35..7382e7c 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -4,8 +4,10 @@ import typing as t from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( + NormalizationStrategy, concat_to_dpipe_sql, concat_ws_to_dpipe_sql, + date_delta_sql, generatedasidentitycolumnconstraint_sql, rename_func, ts_or_ds_to_date_sql, @@ -14,30 +16,28 @@ from sqlglot.dialects.postgres import Postgres from sqlglot.helper import seq_get from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from sqlglot._typing import E + def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str: return f'{self.sql(expression, "this")}."{expression.expression.name}"' -def _parse_date_add(args: t.List) -> exp.DateAdd: - return exp.DateAdd( - this=exp.TsOrDsToDate(this=seq_get(args, 2)), - expression=seq_get(args, 1), - unit=seq_get(args, 0), - ) +def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: + def _parse_delta(args: t.List) -> E: + expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) + if expr_type is exp.TsOrDsAdd: + expr.set("return_type", exp.DataType.build("TIMESTAMP")) + return expr -def _parse_datediff(args: t.List) -> exp.DateDiff: - return exp.DateDiff( - this=exp.TsOrDsToDate(this=seq_get(args, 2)), - expression=exp.TsOrDsToDate(this=seq_get(args, 1)), - unit=seq_get(args, 0), - ) + return _parse_delta class Redshift(Postgres): # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE SUPPORTS_USER_DEFINED_TYPES = False INDEX_OFFSET = 0 @@ -52,15 +52,16 @@ class Redshift(Postgres): class Parser(Postgres.Parser): FUNCTIONS = { **Postgres.Parser.FUNCTIONS, - "ADD_MONTHS": lambda args: exp.DateAdd( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), + "ADD_MONTHS": lambda args: exp.TsOrDsAdd( + this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.var("month"), + return_type=exp.DataType.build("TIMESTAMP"), ), - "DATEADD": _parse_date_add, - "DATE_ADD": _parse_date_add, - "DATEDIFF": _parse_datediff, - "DATE_DIFF": _parse_datediff, + "DATEADD": _parse_date_delta(exp.TsOrDsAdd), + "DATE_ADD": _parse_date_delta(exp.TsOrDsAdd), + "DATEDIFF": _parse_date_delta(exp.TsOrDsDiff), + "DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff), "LISTAGG": exp.GroupConcat.from_arg_list, "STRTOL": exp.FromBase.from_arg_list, } @@ -169,12 +170,8 @@ class Redshift(Postgres): exp.ConcatWs: concat_ws_to_dpipe_sql, exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})", exp.CurrentTimestamp: lambda self, e: "SYSDATE", - exp.DateAdd: lambda self, e: self.func( - "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this - ), - exp.DateDiff: lambda self, e: self.func( - "DATEDIFF", exp.var(e.text("unit") or "day"), e.expression, e.this - ), + exp.DateAdd: date_delta_sql("DATEADD"), + exp.DateDiff: date_delta_sql("DATEDIFF"), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), exp.FromBase: rename_func("STRTOL"), @@ -183,11 +180,12 @@ class Redshift(Postgres): exp.JSONExtractScalar: _json_sql, exp.GroupConcat: rename_func("LISTAGG"), exp.ParseJSON: rename_func("JSON_PARSE"), - exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] ), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", + exp.TsOrDsAdd: date_delta_sql("DATEADD"), + exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"), } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 01f7512..cdbc071 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -3,9 +3,12 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens, transforms +from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, binary_from_function, + date_delta_sql, date_trunc_to_time, datestrtodate_sql, format_time_lambda, @@ -21,7 +24,6 @@ from sqlglot.dialects.dialect import ( ) from sqlglot.expressions import Literal from sqlglot.helper import seq_get -from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType @@ -50,7 +52,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, elif second_arg.name == "3": timescale = exp.UnixToTime.MILLIS elif second_arg.name == "9": - timescale = exp.UnixToTime.MICROS + timescale = exp.UnixToTime.NANOS return exp.UnixToTime(this=first_arg, scale=timescale) @@ -95,14 +97,17 @@ def _parse_datediff(args: t.List) -> exp.DateDiff: def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") - if scale in [None, exp.UnixToTime.SECONDS]: + if scale in (None, exp.UnixToTime.SECONDS): return f"TO_TIMESTAMP({timestamp})" if scale == exp.UnixToTime.MILLIS: return f"TO_TIMESTAMP({timestamp}, 3)" if scale == exp.UnixToTime.MICROS: + return f"TO_TIMESTAMP({timestamp} / 1000, 3)" + if scale == exp.UnixToTime.NANOS: return f"TO_TIMESTAMP({timestamp}, 9)" - raise ValueError("Improper scale for timestamp") + self.unsupported(f"Unsupported scale for timestamp: {scale}.") + return "" # https://docs.snowflake.com/en/sql-reference/functions/date_part.html @@ -201,7 +206,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser] class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax - RESOLVES_IDENTIFIERS_AS_UPPERCASE = True + NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE NULL_ORDERING = "nulls_are_large" TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" SUPPORTS_USER_DEFINED_TYPES = False @@ -236,6 +241,18 @@ class Snowflake(Dialect): "ff6": "%f", } + def quote_identifier(self, expression: E, identify: bool = True) -> E: + # This disables quoting DUAL in SELECT ... FROM DUAL, because Snowflake treats an + # unquoted DUAL keyword in a special way and does not map it to a user-defined table + if ( + isinstance(expression, exp.Identifier) + and isinstance(expression.parent, exp.Table) + and expression.name.lower() == "dual" + ): + return t.cast(E, expression) + + return super().quote_identifier(expression, identify=identify) + class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True @@ -245,6 +262,9 @@ class Snowflake(Dialect): **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, + "ARRAY_CONTAINS": lambda args: exp.ArrayContains( + this=seq_get(args, 1), expression=seq_get(args, 0) + ), "ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries( # ARRAY_GENERATE_RANGE has an exlusive end; we normalize it to be inclusive start=seq_get(args, 0), @@ -296,8 +316,8 @@ class Snowflake(Dialect): RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, - TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny), - TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny), + TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), + TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), } ALTER_PARSERS = { @@ -317,6 +337,11 @@ class Snowflake(Dialect): TokenType.SHOW: lambda self: self._parse_show(), } + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, + "LOCATION": lambda self: self._parse_location(), + } + SHOW_PARSERS = { "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), @@ -349,7 +374,7 @@ class Snowflake(Dialect): table: t.Optional[exp.Expression] = None if self._match_text_seq("@"): table_name = "@" - while True: + while self._curr: self._advance() table_name += self._prev.text if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False): @@ -411,6 +436,20 @@ class Snowflake(Dialect): self._match_text_seq("WITH") return self.expression(exp.SwapTable, this=self._parse_table(schema=True)) + def _parse_location(self) -> exp.LocationProperty: + self._match(TokenType.EQ) + + parts = [self._parse_var(any_token=True)] + + while self._match(TokenType.SLASH): + if self._curr and self._prev.end + 1 == self._curr.start: + parts.append(self._parse_var(any_token=True)) + else: + parts.append(exp.Var(this="")) + return self.expression( + exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts)) + ) + class Tokenizer(tokens.Tokenizer): STRING_ESCAPES = ["\\", "'"] HEX_STRINGS = [("x'", "'"), ("X'", "'")] @@ -457,6 +496,7 @@ class Snowflake(Dialect): AGGREGATE_FILTER_SUPPORTED = False SUPPORTS_TABLE_COPY = False COLLATE_IS_FUNC = True + LIMIT_ONLY_LITERALS = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -464,15 +504,14 @@ class Snowflake(Dialect): exp.ArgMin: rename_func("MIN_BY"), exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), + exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this), exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), exp.AtTimeZone: lambda self, e: self.func( "CONVERT_TIMEZONE", e.args.get("zone"), e.this ), exp.BitwiseXor: rename_func("BITXOR"), - exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this), - exp.DateDiff: lambda self, e: self.func( - "DATEDIFF", e.text("unit"), e.expression, e.this - ), + exp.DateAdd: date_delta_sql("DATEADD"), + exp.DateDiff: date_delta_sql("DATEDIFF"), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.DayOfMonth: rename_func("DAYOFMONTH"), @@ -501,10 +540,11 @@ class Snowflake(Dialect): exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, - transforms.explode_to_unnest(0), + transforms.explode_to_unnest(), transforms.eliminate_semi_and_anti_joins, ] ), + exp.SHA: rename_func("SHA1"), exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StartsWith: rename_func("STARTSWITH"), exp.StrPosition: lambda self, e: self.func( @@ -524,6 +564,8 @@ class Snowflake(Dialect): exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), + exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), + exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -547,6 +589,20 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def trycast_sql(self, expression: exp.TryCast) -> str: + value = expression.this + + if value.type is None: + from sqlglot.optimizer.annotate_types import annotate_types + + value = annotate_types(value) + + if value.is_type(*exp.DataType.TEXT_TYPES, exp.DataType.Type.UNKNOWN): + return super().trycast_sql(expression) + + # TRY_CAST only works for string values in Snowflake + return self.cast_sql(expression) + def log_sql(self, expression: exp.Log) -> str: if not expression.expression: return self.func("LN", expression.this) @@ -554,24 +610,28 @@ class Snowflake(Dialect): return super().log_sql(expression) def unnest_sql(self, expression: exp.Unnest) -> str: - selects = ["value"] unnest_alias = expression.args.get("alias") - offset = expression.args.get("offset") - if offset: - if unnest_alias: - unnest_alias.append("columns", offset.pop()) - - selects.append("index") - subquery = exp.Subquery( - this=exp.select(*selects).from_( - f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))" - ), - ) + columns = [ + exp.to_identifier("seq"), + exp.to_identifier("key"), + exp.to_identifier("path"), + offset.pop() if isinstance(offset, exp.Expression) else exp.to_identifier("index"), + seq_get(unnest_alias.columns if unnest_alias else [], 0) + or exp.to_identifier("value"), + exp.to_identifier("this"), + ] + + if unnest_alias: + unnest_alias.set("columns", columns) + else: + unnest_alias = exp.TableAlias(this="_u", columns=columns) + + explode = f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))" alias = self.sql(unnest_alias) alias = f" AS {alias}" if alias else "" - return f"{self.sql(subquery)}{alias}" + return f"{explode}{alias}" def show_sql(self, expression: exp.Show) -> str: scope = self.sql(expression, "scope") @@ -632,3 +692,6 @@ class Snowflake(Dialect): def swaptable_sql(self, expression: exp.SwapTable) -> str: this = self.sql(expression, "this") return f"SWAP WITH {this}" + + def with_properties(self, properties: exp.Properties) -> str: + return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 1abfce6..ba73ac0 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -56,15 +56,17 @@ class Spark(Spark2): def _parse_generated_as_identity( self, - ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint: + ) -> ( + exp.GeneratedAsIdentityColumnConstraint + | exp.ComputedColumnConstraint + | exp.GeneratedAsRowColumnConstraint + ): this = super()._parse_generated_as_identity() if this.expression: return self.expression(exp.ComputedColumnConstraint, this=this.expression) return this class Generator(Spark2.Generator): - SUPPORTS_NESTED_CTES = True - TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, exp.DataType.Type.MONEY: "DECIMAL(15, 4)", diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index da84bd8..aa09f53 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -48,8 +48,11 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str return f"TIMESTAMP_MILLIS({timestamp})" if scale == exp.UnixToTime.MICROS: return f"TIMESTAMP_MICROS({timestamp})" + if scale == exp.UnixToTime.NANOS: + return f"TIMESTAMP_SECONDS({timestamp} / 1000000000)" - raise ValueError("Improper scale for timestamp") + self.unsupported(f"Unsupported scale for timestamp: {scale}.") + return "" def _unalias_pivot(expression: exp.Expression) -> exp.Expression: @@ -119,7 +122,11 @@ class Spark2(Hive): "DOUBLE": _parse_as_cast("double"), "FLOAT": _parse_as_cast("float"), "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone( - this=exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("timestamp")), + this=exp.cast_unless( + seq_get(args, 0) or exp.Var(this=""), + exp.DataType.build("timestamp"), + exp.DataType.build("timestamp"), + ), zone=seq_get(args, 1), ), "IIF": exp.If.from_arg_list, @@ -224,6 +231,19 @@ class Spark2(Hive): WRAP_DERIVED_VALUES = False CREATE_FUNCTION_RETURN_AS = False + def struct_sql(self, expression: exp.Struct) -> str: + args = [] + for arg in expression.expressions: + if isinstance(arg, self.KEY_VALUE_DEFINITONS): + if isinstance(arg, exp.Bracket): + args.append(exp.alias_(arg.this, arg.expressions[0].name)) + else: + args.append(exp.alias_(arg.expression, arg.this.name)) + else: + args.append(arg) + + return self.func("STRUCT", *args) + def temporary_storage_provider(self, expression: exp.Create) -> exp.Create: # spark2, spark, Databricks require a storage provider for temporary tables provider = exp.FileFormatProperty(this=exp.Literal.string("parquet")) diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 1fa730d..e55a3b8 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, any_value_to_max_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, @@ -63,8 +64,10 @@ def _transform_create(expression: exp.Expression) -> exp.Expression: class SQLite(Dialect): # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE SUPPORTS_SEMI_ANTI_JOIN = False + TYPED_DIVISION = True + SAFE_DIVISION = True class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] @@ -124,7 +127,6 @@ class SQLite(Dialect): exp.LogicalOr: rename_func("MAX"), exp.LogicalAnd: rename_func("MIN"), exp.Pivot: no_pivot_sql, - exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index e8162c2..141d9c0 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -9,6 +9,7 @@ from sqlglot.tokens import TokenType class Teradata(Dialect): SUPPORTS_SEMI_ANTI_JOIN = False + TYPED_DIVISION = True TIME_MAPPING = { "Y": "%Y", @@ -33,8 +34,10 @@ class Teradata(Dialect): class Tokenizer(tokens.Tokenizer): # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance + # https://docs.teradata.com/r/SQL-Functions-Operators-Expressions-and-Predicates/June-2017/Arithmetic-Trigonometric-Hyperbolic-Operators/Functions KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "**": TokenType.DSTAR, "^=": TokenType.NEQ, "BYTEINT": TokenType.SMALLINT, "COLLECT": TokenType.COMMAND, @@ -112,10 +115,16 @@ class Teradata(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, + # https://docs.teradata.com/r/SQL-Functions-Operators-Expressions-and-Predicates/June-2017/Data-Type-Conversions/TRYCAST + "TRYCAST": parser.Parser.FUNCTION_PARSERS["TRY_CAST"], "RANGE_N": lambda self: self._parse_rangen(), "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST), } + EXPONENT = { + TokenType.DSTAR: exp.Pow, + } + def _parse_translate(self, strict: bool) -> exp.Expression: this = self._parse_conjunction() @@ -177,6 +186,7 @@ class Teradata(Dialect): exp.ArgMin: rename_func("MIN_BY"), exp.Max: max_or_greatest, exp.Min: min_or_least, + exp.Pow: lambda self, e: self.binary(e, "**"), exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] ), @@ -192,6 +202,9 @@ class Teradata(Dialect): return super().cast_sql(expression, safe_prefix=safe_prefix) + def trycast_sql(self, expression: exp.TryCast) -> str: + return self.cast_sql(expression, safe_prefix="TRY") + def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " ) -> str: diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a281297..c3d4f0a 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -7,7 +7,9 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, any_value_to_max_sql, + date_delta_sql, generatedasidentitycolumnconstraint_sql, max_or_greatest, min_or_least, @@ -135,11 +137,7 @@ def _parse_hashbytes(args: t.List) -> exp.Expression: return exp.func("HASHBYTES", *args) -def generate_date_delta_with_unit_sql( - self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff -) -> str: - func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF" - return self.func(func, expression.text("unit"), expression.expression, expression.this) +DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"} def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: @@ -153,6 +151,11 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt ) ) ) + + # There is no format for "quarter" + if fmt.name.lower() in DATEPART_ONLY_FORMATS: + return self.func("DATEPART", fmt.name, expression.this) + return self.func("FORMAT", expression.this, fmt, expression.args.get("culture")) @@ -202,18 +205,50 @@ def _parse_date_delta( return inner_func +def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: + """Ensures all (unnamed) output columns are aliased for CTEs and Subqueries.""" + alias = expression.args.get("alias") + + if ( + isinstance(expression, (exp.CTE, exp.Subquery)) + and isinstance(alias, exp.TableAlias) + and not alias.columns + ): + from sqlglot.optimizer.qualify_columns import qualify_outputs + + # We keep track of the unaliased column projection indexes instead of the expressions + # themselves, because the latter are going to be replaced by new nodes when the aliases + # are added and hence we won't be able to reach these newly added Alias parents + subqueryable = expression.this + unaliased_column_indexes = ( + i + for i, c in enumerate(subqueryable.selects) + if isinstance(c, exp.Column) and not c.alias + ) + + qualify_outputs(subqueryable) + + # Preserve the quoting information of columns for newly added Alias nodes + subqueryable_selects = subqueryable.selects + for select_index in unaliased_column_indexes: + alias = subqueryable_selects[select_index] + column = alias.this + if isinstance(column.this, exp.Identifier): + alias.args["alias"].set("quoted", column.this.quoted) + + return expression + + class TSQL(Dialect): - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None - NULL_ORDERING = "nulls_are_small" + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" SUPPORTS_SEMI_ANTI_JOIN = False LOG_BASE_FIRST = False + TYPED_DIVISION = True + CONCAT_COALESCE = True TIME_MAPPING = { "year": "%Y", - "qq": "%q", - "q": "%q", - "quarter": "%q", "dayofyear": "%j", "day": "%d", "dy": "%d", @@ -320,6 +355,7 @@ class TSQL(Dialect): IDENTIFIERS = ['"', ("[", "]")] QUOTES = ["'", '"'] HEX_STRINGS = [("0x", ""), ("0X", "")] + VAR_SINGLE_TOKENS = {"@", "$", "#"} KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -403,9 +439,7 @@ class TSQL(Dialect): LOG_DEFAULTS_TO_LN = True - CONCAT_NULL_OUTPUTS_STRING = True - - ALTER_TABLE_ADD_COLUMN_KEYWORD = False + ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False def _parse_projections(self) -> t.List[exp.Expression]: """ @@ -433,7 +467,7 @@ class TSQL(Dialect): """ rollback = self._prev.token_type == TokenType.ROLLBACK - self._match_texts({"TRAN", "TRANSACTION"}) + self._match_texts(("TRAN", "TRANSACTION")) this = self._parse_id_var() if rollback: @@ -579,23 +613,35 @@ class TSQL(Dialect): return super()._parse_if() def _parse_unique(self) -> exp.UniqueColumnConstraint: - return self.expression( - exp.UniqueColumnConstraint, - this=None - if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"} - else self._parse_schema(self._parse_id_var(any_token=False)), - ) + if self._match_texts(("CLUSTERED", "NONCLUSTERED")): + this = self.CONSTRAINT_PARSERS[self._prev.text.upper()](self) + else: + this = self._parse_schema(self._parse_id_var(any_token=False)) + + return self.expression(exp.UniqueColumnConstraint, this=this) class Generator(generator.Generator): LIMIT_IS_TOP = True QUERY_HINTS = False RETURNING_END = False NVL2_SUPPORTED = False - ALTER_TABLE_ADD_COLUMN_KEYWORD = False + ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False LIMIT_FETCH = "FETCH" COMPUTED_COLUMN_WITH_TYPE = False - SUPPORTS_NESTED_CTES = False CTE_RECURSIVE_KEYWORD_REQUIRED = False + ENSURE_BOOLS = True + NULL_ORDERING_SUPPORTED = False + SUPPORTS_SINGLE_ARG_CONCAT = False + + EXPRESSIONS_WITHOUT_NESTED_CTES = { + exp.Delete, + exp.Insert, + exp.Merge, + exp.Select, + exp.Subquery, + exp.Union, + exp.Update, + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -614,14 +660,16 @@ class TSQL(Dialect): **generator.Generator.TRANSFORMS, exp.AnyValue: any_value_to_max_sql, exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY", - exp.DateAdd: generate_date_delta_with_unit_sql, - exp.DateDiff: generate_date_delta_with_unit_sql, + exp.DateAdd: date_delta_sql("DATEADD"), + exp.DateDiff: date_delta_sql("DATEDIFF"), + exp.CTE: transforms.preprocess([qualify_derived_table_outputs]), exp.CurrentDate: rename_func("GETDATE"), exp.CurrentTimestamp: rename_func("GETDATE"), exp.Extract: rename_func("DATEPART"), exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), + exp.Length: rename_func("LEN"), exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, @@ -633,15 +681,16 @@ class TSQL(Dialect): transforms.eliminate_qualify, ] ), + exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( - "HASHBYTES", - exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), - e.this, + "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this ), exp.TemporaryProperty: lambda self, e: "", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, + exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), + exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"), } @@ -690,8 +739,21 @@ class TSQL(Dialect): table = expression.find(exp.Table) + # Convert CTAS statement to SELECT .. INTO .. if kind == "TABLE" and expression.expression: - sql = f"SELECT * INTO {self.sql(table)} FROM ({self.sql(expression.expression)}) AS temp" + ctas_with = expression.expression.args.get("with") + if ctas_with: + ctas_with = ctas_with.pop() + + subquery = expression.expression + if isinstance(subquery, exp.Subqueryable): + subquery = subquery.subquery() + + select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True)) + select_into.set("into", exp.Into(this=table)) + select_into.set("with", ctas_with) + + sql = self.sql(select_into) if exists: identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else "")) diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index bf2941c..b79a551 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -139,10 +139,16 @@ def interval(this, unit): return datetime.timedelta(**{unit: float(this)}) +@null_if_any("this", "expression") +def arrayjoin(this, expression, null=None): + return expression.join(x for x in (x if x is not None else null for x in this) if x is not None) + + ENV = { "exp": exp, # aggs "ARRAYAGG": list, + "ARRAYUNIQUEAGG": filter_nulls(lambda acc: list(set(acc))), "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False), "MAX": filter_nulls(max), @@ -152,6 +158,7 @@ ENV = { "ABS": null_if_any(lambda this: abs(this)), "ADD": null_if_any(lambda e, this: e + this), "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)), + "ARRAYJOIN": arrayjoin, "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), "BITWISEAND": null_if_any(lambda this, e: this & e), "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), @@ -203,4 +210,9 @@ ENV = { "CURRENTDATE": datetime.date.today, "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)), "TRIM": null_if_any(lambda this, e=None: this.strip(e)), + "STRUCT": lambda *args: { + args[x]: args[x + 1] + for x in range(0, len(args), 2) + if (args[x + 1] is not None and args[x] is not None) + }, } diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index d2ae79d..e1e597d 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -397,6 +397,20 @@ def _lambda_sql(self, e: exp.Lambda) -> str: return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}" +def _div_sql(self: generator.Generator, e: exp.Div) -> str: + denominator = self.sql(e, "expression") + + if e.args.get("safe"): + denominator += " or None" + + sql = f"DIV({self.sql(e, 'this')}, {denominator})" + + if e.args.get("typed"): + sql = f"int({sql})" + + return sql + + class Python(Dialect): class Tokenizer(tokens.Tokenizer): STRING_ESCAPES = ["\\"] @@ -413,7 +427,11 @@ class Python(Dialect): exp.Boolean: lambda self, e: "True" if e.this else "False", exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", + exp.Concat: lambda self, e: self.func( + "SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions + ), exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})", + exp.Div: _div_sql, exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}", exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')", diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 7931535..87699f8 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -120,20 +120,22 @@ def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict depth = dict_depth(d) if depth > 1: return { - normalize_name(k, dialect=dialect, is_table=True): _ensure_tables(v, dialect=dialect) + normalize_name(k, dialect=dialect, is_table=True).name: _ensure_tables( + v, dialect=dialect + ) for k, v in d.items() } result = {} for table_name, table in d.items(): - table_name = normalize_name(table_name, dialect=dialect) + table_name = normalize_name(table_name, dialect=dialect).name if isinstance(table, Table): result[table_name] = table else: table = [ { - normalize_name(column_name, dialect=dialect): value + normalize_name(column_name, dialect=dialect).name: value for column_name, value in row.items() } for row in table diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 99ebfb3..99722be 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -53,6 +53,7 @@ class _Expression(type): SQLGLOT_META = "sqlglot.meta" +TABLE_PARTS = ("this", "db", "catalog") class Expression(metaclass=_Expression): @@ -134,7 +135,7 @@ class Expression(metaclass=_Expression): return self.args.get("expression") @property - def expressions(self): + def expressions(self) -> t.List[t.Any]: """ Retrieves the argument with key "expressions". """ @@ -238,6 +239,9 @@ class Expression(metaclass=_Expression): dtype = DataType.build(dtype) self._type = dtype # type: ignore + def is_type(self, *dtypes) -> bool: + return self.type is not None and self.type.is_type(*dtypes) + @property def meta(self) -> t.Dict[str, t.Any]: if self._meta is None: @@ -481,7 +485,7 @@ class Expression(metaclass=_Expression): def flatten(self, unnest=True): """ - Returns a generator which yields child nodes who's parents are the same class. + Returns a generator which yields child nodes whose parents are the same class. A AND B AND C -> [A, B, C] """ @@ -508,7 +512,7 @@ class Expression(metaclass=_Expression): """ from sqlglot.dialects import Dialect - return Dialect.get_or_raise(dialect)().generate(self, **opts) + return Dialect.get_or_raise(dialect).generate(self, **opts) def _to_s(self, hide_missing: bool = True, level: int = 0) -> str: indent = "" if not level else "\n" @@ -821,6 +825,12 @@ class Expression(metaclass=_Expression): def rlike(self, other: ExpOrStr) -> RegexpLike: return self._binop(RegexpLike, other) + def div(self, other: ExpOrStr, typed: bool = False, safe: bool = False) -> Div: + div = self._binop(Div, other) + div.args["typed"] = typed + div.args["safe"] = safe + return div + def __lt__(self, other: t.Any) -> LT: return self._binop(LT, other) @@ -1000,7 +1010,6 @@ class UDTF(DerivedTable, Unionable): class Cache(Expression): arg_types = { - "with": False, "this": True, "lazy": False, "options": False, @@ -1012,6 +1021,10 @@ class Uncache(Expression): arg_types = {"this": True, "exists": False} +class Refresh(Expression): + pass + + class DDL(Expression): @property def ctes(self): @@ -1033,6 +1046,43 @@ class DDL(Expression): return [] +class DML(Expression): + def returning( + self, + expression: ExpOrStr, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> DML: + """ + Set the RETURNING expression. Not supported by all dialects. + + Example: + >>> delete("tbl").returning("*", dialect="postgres").sql() + 'DELETE FROM tbl RETURNING *' + + Args: + expression: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="returning", + prefix="RETURNING", + dialect=dialect, + copy=copy, + into=Returning, + **opts, + ) + + class Create(DDL): arg_types = { "with": False, @@ -1133,8 +1183,10 @@ class WithinGroup(Expression): arg_types = {"this": True, "expression": False} +# clickhouse supports scalar ctes +# https://clickhouse.com/docs/en/sql-reference/statements/select/with class CTE(DerivedTable): - arg_types = {"this": True, "alias": True} + arg_types = {"this": True, "alias": True, "scalar": False} class TableAlias(Expression): @@ -1297,6 +1349,10 @@ class AutoIncrementColumnConstraint(ColumnConstraintKind): pass +class PeriodForSystemTimeConstraint(ColumnConstraintKind): + arg_types = {"this": True, "expression": True} + + class CaseSpecificColumnConstraint(ColumnConstraintKind): arg_types = {"not_": True} @@ -1351,6 +1407,10 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): } +class GeneratedAsRowColumnConstraint(ColumnConstraintKind): + arg_types = {"start": True, "hidden": False} + + # https://dev.mysql.com/doc/refman/8.0/en/create-table.html class IndexColumnConstraint(ColumnConstraintKind): arg_types = { @@ -1383,6 +1443,11 @@ class OnUpdateColumnConstraint(ColumnConstraintKind): pass +# https://docs.snowflake.com/en/sql-reference/sql/create-external-table#optional-parameters +class TransformColumnConstraint(ColumnConstraintKind): + pass + + class PrimaryKeyColumnConstraint(ColumnConstraintKind): arg_types = {"desc": False} @@ -1413,7 +1478,7 @@ class Constraint(Expression): arg_types = {"this": True, "expressions": True} -class Delete(Expression): +class Delete(DML): arg_types = { "with": False, "this": False, @@ -1496,41 +1561,6 @@ class Delete(Expression): **opts, ) - def returning( - self, - expression: ExpOrStr, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Delete: - """ - Set the RETURNING expression. Not supported by all dialects. - - Example: - >>> delete("tbl").returning("*", dialect="postgres").sql() - 'DELETE FROM tbl RETURNING *' - - Args: - expression: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - Delete: the modified expression. - """ - return _apply_builder( - expression=expression, - instance=self, - arg="returning", - prefix="RETURNING", - dialect=dialect, - copy=copy, - into=Returning, - **opts, - ) - class Drop(Expression): arg_types = { @@ -1648,7 +1678,7 @@ class Index(Expression): } -class Insert(DDL): +class Insert(DDL, DML): arg_types = { "with": False, "this": True, @@ -2259,6 +2289,11 @@ class WithJournalTableProperty(Property): arg_types = {"this": True} +class WithSystemVersioningProperty(Property): + # this -> history table name, expression -> data consistency check + arg_types = {"this": False, "expression": False} + + class Properties(Expression): arg_types = {"expressions": True} @@ -3663,6 +3698,7 @@ class DataType(Expression): Type.BIGINT, Type.INT128, Type.INT256, + Type.BIT, } FLOAT_TYPES = { @@ -3692,7 +3728,7 @@ class DataType(Expression): @classmethod def build( cls, - dtype: str | DataType | DataType.Type, + dtype: DATA_TYPE, dialect: DialectType = None, udt: bool = False, **kwargs, @@ -3733,7 +3769,7 @@ class DataType(Expression): return DataType(**{**data_type_exp.args, **kwargs}) - def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool: + def is_type(self, *dtypes: DATA_TYPE) -> bool: """ Checks whether this DataType matches one of the provided data types. Nested types or precision will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>. @@ -3761,6 +3797,9 @@ class DataType(Expression): return False +DATA_TYPE = t.Union[str, DataType, DataType.Type] + + # https://www.postgresql.org/docs/15/datatype-pseudo.html class PseudoType(DataType): arg_types = {"this": True} @@ -3868,7 +3907,7 @@ class BitwiseXor(Binary): class Div(Binary): - pass + arg_types = {"this": True, "expression": True, "typed": False, "safe": False} class Overlaps(Binary): @@ -3892,13 +3931,25 @@ class Dot(Binary): return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions)) + @property + def parts(self) -> t.List[Expression]: + """Return the parts of a table / column in order catalog, db, table.""" + this, *parts = self.flatten() -class DPipe(Binary): - pass + parts.reverse() + for arg in ("this", "table", "db", "catalog"): + part = this.args.get(arg) -class SafeDPipe(DPipe): - pass + if isinstance(part, Expression): + parts.append(part) + + parts.reverse() + return parts + + +class DPipe(Binary): + arg_types = {"this": True, "expression": True, "safe": False} class EQ(Binary, Predicate): @@ -3913,6 +3964,11 @@ class NullSafeNEQ(Binary, Predicate): pass +# Represents e.g. := in DuckDB which is mostly used for setting parameters +class PropertyEQ(Binary): + pass + + class Distance(Binary): pass @@ -3981,6 +4037,11 @@ class NEQ(Binary, Predicate): pass +# https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH +class Operator(Binary): + arg_types = {"this": True, "operator": True, "expression": True} + + class SimilarTo(Binary, Predicate): pass @@ -4048,7 +4109,8 @@ class Between(Predicate): class Bracket(Condition): - arg_types = {"this": True, "expressions": True} + # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator + arg_types = {"this": True, "expressions": True, "offset": False, "safe": False} @property def output_name(self) -> str: @@ -4058,10 +4120,6 @@ class Bracket(Condition): return super().output_name -class SafeBracket(Bracket): - """Represents array lookup where OOB index yields NULL instead of causing a failure.""" - - class Distinct(Expression): arg_types = {"expressions": False, "on": False} @@ -4077,6 +4135,11 @@ class In(Predicate): } +# https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#for-in +class ForIn(Expression): + arg_types = {"this": True, "expression": True} + + class TimeUnit(Expression): """Automatically converts unit arg into a var.""" @@ -4248,8 +4311,9 @@ class Array(Func): # https://docs.snowflake.com/en/sql-reference/functions/to_char +# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html class ToChar(Func): - arg_types = {"this": True, "format": False} + arg_types = {"this": True, "format": False, "nlsparam": False} class GenerateSeries(Func): @@ -4260,6 +4324,10 @@ class ArrayAgg(AggFunc): pass +class ArrayUniqueAgg(AggFunc): + pass + + class ArrayAll(Func): arg_types = {"this": True, "expression": True} @@ -4358,7 +4426,7 @@ class Cast(Func): def output_name(self) -> str: return self.name - def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool: + def is_type(self, *dtypes: DATA_TYPE) -> bool: """ Checks whether this Cast's DataType matches one of the provided data types. Nested types like arrays or structs will be compared using "structural equivalence" semantics, so e.g. @@ -4403,14 +4471,10 @@ class Chr(Func): class Concat(Func): - arg_types = {"expressions": True} + arg_types = {"expressions": True, "safe": False, "coalesce": False} is_var_len_args = True -class SafeConcat(Concat): - pass - - class ConcatWs(Concat): _sql_names = ["CONCAT_WS"] @@ -4643,6 +4707,10 @@ class If(Func): arg_types = {"this": True, "true": True, "false": False} +class Nullif(Func): + arg_types = {"this": True, "expression": True} + + class Initcap(Func): arg_types = {"this": True, "expression": False} @@ -4651,6 +4719,10 @@ class IsNan(Func): _sql_names = ["IS_NAN", "ISNAN"] +class IsInf(Func): + _sql_names = ["IS_INF", "ISINF"] + + class FormatJson(Expression): pass @@ -4970,10 +5042,6 @@ class SafeDivide(Func): arg_types = {"this": True, "expression": True} -class SetAgg(AggFunc): - pass - - class SHA(Func): _sql_names = ["SHA", "SHA1"] @@ -5118,6 +5186,15 @@ class Trim(Func): class TsOrDsAdd(Func, TimeUnit): + # return_type is used to correctly cast the arguments of this expression when transpiling it + arg_types = {"this": True, "expression": True, "unit": False, "return_type": False} + + @property + def return_type(self) -> DataType: + return DataType.build(self.args.get("return_type") or DataType.Type.DATE) + + +class TsOrDsDiff(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} @@ -5149,6 +5226,7 @@ class UnixToTime(Func): SECONDS = Literal.string("seconds") MILLIS = Literal.string("millis") MICROS = Literal.string("micros") + NANOS = Literal.string("nanos") class UnixToTimeStr(Func): @@ -5202,6 +5280,7 @@ def _norm_arg(arg): ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) +FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()} # Helpers @@ -5693,7 +5772,9 @@ def delete( if where: delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts) if returning: - delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts) + delete_expr = t.cast( + Delete, delete_expr.returning(returning, dialect=dialect, copy=False, **opts) + ) return delete_expr @@ -5702,6 +5783,7 @@ def insert( into: ExpOrStr, columns: t.Optional[t.Sequence[ExpOrStr]] = None, overwrite: t.Optional[bool] = None, + returning: t.Optional[ExpOrStr] = None, dialect: DialectType = None, copy: bool = True, **opts, @@ -5718,6 +5800,7 @@ def insert( into: the tbl to insert data to. columns: optionally the table's column names. overwrite: whether to INSERT OVERWRITE or not. + returning: sql conditional parsed into a RETURNING statement dialect: the dialect used to parse the input expressions. copy: whether or not to copy the expression. **opts: other options to use to parse the input expressions. @@ -5739,7 +5822,12 @@ def insert( **opts, ) - return Insert(this=this, expression=expr, overwrite=overwrite) + insert = Insert(this=this, expression=expr, overwrite=overwrite) + + if returning: + insert = t.cast(Insert, insert.returning(returning, dialect=dialect, copy=False, **opts)) + + return insert def condition( @@ -5913,7 +6001,7 @@ def to_identifier(name, quoted=None, copy=True): return identifier -def parse_identifier(name: str, dialect: DialectType = None) -> Identifier: +def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Identifier: """ Parses a given string into an identifier. @@ -5965,7 +6053,7 @@ def to_table(sql_path: None, **kwargs) -> None: def to_table( - sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs + sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs ) -> t.Optional[Table]: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. @@ -5974,13 +6062,14 @@ def to_table( Args: sql_path: a `[catalog].[schema].[table]` string. dialect: the source dialect according to which the table name will be parsed. + copy: Whether or not to copy a table if it is passed in. kwargs: the kwargs to instantiate the resulting `Table` expression with. Returns: A table expression. """ if sql_path is None or isinstance(sql_path, Table): - return sql_path + return maybe_copy(sql_path, copy=copy) if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") @@ -6123,7 +6212,7 @@ def column( ) -def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast: +def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast: """Cast an expression to a data type. Example: @@ -6335,12 +6424,15 @@ def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]: } -def table_name(table: Table | str, dialect: DialectType = None) -> str: +def table_name(table: Table | str, dialect: DialectType = None, identify: bool = False) -> str: """Get the full name of a table as a string. Args: table: Table expression node or string. dialect: The dialect to generate the table name for. + identify: Determines when an identifier should be quoted. Possible values are: + False (default): Never quote, except in cases where it's mandatory by the dialect. + True: Always quote. Examples: >>> from sqlglot import exp, parse_one @@ -6358,37 +6450,68 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str: return ".".join( part.sql(dialect=dialect, identify=True) - if not SAFE_IDENTIFIER_RE.match(part.name) + if identify or not SAFE_IDENTIFIER_RE.match(part.name) else part.name for part in table.parts ) -def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E: +def normalize_table_name(table: str | Table, dialect: DialectType = None, copy: bool = True) -> str: + """Returns a case normalized table name without quotes. + + Args: + table: the table to normalize + dialect: the dialect to use for normalization rules + copy: whether or not to copy the expression. + + Examples: + >>> normalize_table_name("`A-B`.c", dialect="bigquery") + 'A-B.c' + """ + from sqlglot.optimizer.normalize_identifiers import normalize_identifiers + + return ".".join( + p.name + for p in normalize_identifiers( + to_table(table, dialect=dialect, copy=copy), dialect=dialect + ).parts + ) + + +def replace_tables( + expression: E, mapping: t.Dict[str, str], dialect: DialectType = None, copy: bool = True +) -> E: """Replace all tables in expression according to the mapping. Args: expression: expression node to be transformed and replaced. mapping: mapping of table names. + dialect: the dialect of the mapping table copy: whether or not to copy the expression. Examples: >>> from sqlglot import exp, parse_one >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql() - 'SELECT * FROM c' + 'SELECT * FROM c /* a.b */' Returns: The mapped expression. """ + mapping = {normalize_table_name(k, dialect=dialect): v for k, v in mapping.items()} + def _replace_tables(node: Expression) -> Expression: if isinstance(node, Table): - new_name = mapping.get(table_name(node)) + original = normalize_table_name(node, dialect=dialect) + new_name = mapping.get(original) + if new_name: - return to_table( + table = to_table( new_name, - **{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")}, + **{k: v for k, v in node.args.items() if k not in TABLE_PARTS}, ) + table.add_comments([original]) + return table return node return expression.transform(_replace_tables, copy=copy) @@ -6431,7 +6554,10 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: def expand( - expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True + expression: Expression, + sources: t.Dict[str, Subqueryable], + dialect: DialectType = None, + copy: bool = True, ) -> Expression: """Transforms an expression by expanding all referenced sources into subqueries. @@ -6446,15 +6572,17 @@ def expand( Args: expression: The expression to expand. sources: A dictionary of name to Subqueryables. + dialect: The dialect of the sources dict. copy: Whether or not to copy the expression during transformation. Defaults to True. Returns: The transformed expression. """ + sources = {normalize_table_name(k, dialect=dialect): v for k, v in sources.items()} def _expand(node: Expression): if isinstance(node, Table): - name = table_name(node) + name = normalize_table_name(node, dialect=dialect) source = sources.get(name) if source: subquery = source.subquery(node.alias or name) @@ -6465,7 +6593,7 @@ def expand( return expression.transform(_expand, copy=copy) -def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: +def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwargs) -> Func: """ Returns a Func expression. @@ -6479,6 +6607,7 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: Args: name: the name of the function to build. args: the args used to instantiate the function of interest. + copy: whether or not to copy the argument expressions. dialect: the source dialect. kwargs: the kwargs used to instantiate the function of interest. @@ -6494,14 +6623,29 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: from sqlglot.dialects.dialect import Dialect - converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args] - kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()} + dialect = Dialect.get_or_raise(dialect) - parser = Dialect.get_or_raise(dialect)().parser() - from_args_list = parser.FUNCTIONS.get(name.upper()) + converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args] + kwargs = {key: maybe_parse(value, dialect=dialect, copy=copy) for key, value in kwargs.items()} - if from_args_list: - function = from_args_list(converted) if converted else from_args_list.__self__(**kwargs) # type: ignore + constructor = dialect.parser_class.FUNCTIONS.get(name.upper()) + if constructor: + if converted: + if "dialect" in constructor.__code__.co_varnames: + function = constructor(converted, dialect=dialect) + else: + function = constructor(converted) + elif constructor.__name__ == "from_arg_list": + function = constructor.__self__(**kwargs) # type: ignore + else: + constructor = FUNCTION_BY_NAME.get(name.upper()) + if constructor: + function = constructor(**kwargs) + else: + raise ValueError( + f"Unable to convert '{name}' into a Func. Either manually construct " + "the Func expression of interest or parse the function call." + ) else: kwargs = kwargs or {"expressions": converted} function = Anonymous(this=name, **kwargs) @@ -6512,6 +6656,48 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: return function +def case( + expression: t.Optional[ExpOrStr] = None, + **opts, +) -> Case: + """ + Initialize a CASE statement. + + Example: + case().when("a = 1", "foo").else_("bar") + + Args: + expression: Optionally, the input expression (not all dialects support this) + **opts: Extra keyword arguments for parsing `expression` + """ + if expression is not None: + this = maybe_parse(expression, **opts) + else: + this = None + return Case(this=this, ifs=[]) + + +def cast_unless( + expression: ExpOrStr, + to: DATA_TYPE, + *types: DATA_TYPE, + **opts: t.Any, +) -> Expression | Cast: + """ + Cast an expression to a data type unless it is a specified type. + + Args: + expression: The expression to cast. + to: The data type to cast to. + **types: The types to exclude from casting. + **opts: Extra keyword arguments for parsing `expression` + """ + expr = maybe_parse(expression, **opts) + if expr.is_type(*types): + return expr + return cast(expr, to, **opts) + + def true() -> Boolean: """ Returns a true Boolean expression. diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 4916cf8..f3f9060 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -9,10 +9,11 @@ from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages from sqlglot.helper import apply_index_offset, csv, seq_get from sqlglot.time import format_time -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.tokens import TokenType if t.TYPE_CHECKING: from sqlglot._typing import E + from sqlglot.dialects.dialect import DialectType logger = logging.getLogger("sqlglot") @@ -58,9 +59,6 @@ class Generator: exp.DateAdd: lambda self, e: self.func( "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) ), - exp.TsOrDsAdd: lambda self, e: self.func( - "TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) - ), exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", @@ -108,9 +106,6 @@ class Generator: exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", } - # Whether the base comes first - LOG_BASE_FIRST = True - # Whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True @@ -201,7 +196,7 @@ class Generator: VALUES_AS_TABLE = True # Whether or not the word COLUMN is included when adding a column with ALTER TABLE - ALTER_TABLE_ADD_COLUMN_KEYWORD = True + ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = True # UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery) UNNEST_WITH_ORDINALITY = True @@ -212,9 +207,6 @@ class Generator: # Whether or not JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds SEMI_ANTI_JOIN_WITH_SIDE = True - # Whether or not session variables / parameters are supported, e.g. @x in T-SQL - SUPPORTS_PARAMETERS = True - # Whether or not to include the type of a computed column in the CREATE DDL COMPUTED_COLUMN_WITH_TYPE = True @@ -230,12 +222,15 @@ class Generator: # Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle) DATA_TYPE_SPECIFIERS_ALLOWED = False - # Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed - SUPPORTS_NESTED_CTES = True + # Whether or not conditions require booleans WHERE x = 0 vs WHERE x + ENSURE_BOOLS = False # Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs CTE_RECURSIVE_KEYWORD_REQUIRED = True + # Whether or not CONCAT requires >1 arguments + SUPPORTS_SINGLE_ARG_CONCAT = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -335,6 +330,7 @@ class Generator: exp.VolatileProperty: exp.Properties.Location.POST_CREATE, exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, + exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA, } # Keywords that can't be used as unquoted identifier names @@ -368,36 +364,12 @@ class Generator: exp.Paren, ) - SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" + # Expressions that need to have all CTEs under them bubbled up to them + EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set() + + KEY_VALUE_DEFINITONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice) - # Autofilled - INVERSE_TIME_MAPPING: t.Dict[str, str] = {} - INVERSE_TIME_TRIE: t.Dict = {} - INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} - INDEX_OFFSET = 0 - UNNEST_COLUMN_ONLY = False - ALIAS_POST_TABLESAMPLE = False - IDENTIFIERS_CAN_START_WITH_DIGIT = False - STRICT_STRING_CONCAT = False - NORMALIZE_FUNCTIONS: bool | str = "upper" - NULL_ORDERING = "nulls_are_small" - - can_identify: t.Callable[[str, str | bool], bool] - - # Delimiters for quotes, identifiers and the corresponding escape characters - QUOTE_START = "'" - QUOTE_END = "'" - IDENTIFIER_START = '"' - IDENTIFIER_END = '"' - TOKENIZER_CLASS = Tokenizer - - # Delimiters for bit, hex, byte and raw literals - BIT_START: t.Optional[str] = None - BIT_END: t.Optional[str] = None - HEX_START: t.Optional[str] = None - HEX_END: t.Optional[str] = None - BYTE_START: t.Optional[str] = None - BYTE_END: t.Optional[str] = None + SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" __slots__ = ( "pretty", @@ -411,6 +383,7 @@ class Generator: "leading_comma", "max_text_width", "comments", + "dialect", "unsupported_messages", "_escaped_quote_end", "_escaped_identifier_end", @@ -429,8 +402,10 @@ class Generator: leading_comma: bool = False, max_text_width: int = 80, comments: bool = True, + dialect: DialectType = None, ): import sqlglot + from sqlglot.dialects import Dialect self.pretty = pretty if pretty is not None else sqlglot.pretty self.identify = identify @@ -442,16 +417,19 @@ class Generator: self.leading_comma = leading_comma self.max_text_width = max_text_width self.comments = comments + self.dialect = Dialect.get_or_raise(dialect) # This is both a Dialect property and a Generator argument, so we prioritize the latter self.normalize_functions = ( - self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions + self.dialect.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions ) self.unsupported_messages: t.List[str] = [] - self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END + self._escaped_quote_end: str = ( + self.dialect.tokenizer_class.STRING_ESCAPES[0] + self.dialect.QUOTE_END + ) self._escaped_identifier_end: str = ( - self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END + self.dialect.tokenizer_class.IDENTIFIER_ESCAPES[0] + self.dialect.IDENTIFIER_END ) def generate(self, expression: exp.Expression, copy: bool = True) -> str: @@ -469,23 +447,14 @@ class Generator: if copy: expression = expression.copy() - # Some dialects only support CTEs at the top level expression, so we need to bubble up nested - # CTEs to that level in order to produce a syntactically valid expression. This transformation - # happens here to minimize code duplication, since many expressions support CTEs. - if ( - not self.SUPPORTS_NESTED_CTES - and isinstance(expression, exp.Expression) - and not expression.parent - and "with" in expression.arg_types - and any(node.parent is not expression for node in expression.find_all(exp.With)) - ): - from sqlglot.transforms import move_ctes_to_top_level - - expression = move_ctes_to_top_level(expression) + expression = self.preprocess(expression) self.unsupported_messages = [] sql = self.sql(expression).strip() + if self.pretty: + sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n") + if self.unsupported_level == ErrorLevel.IGNORE: return sql @@ -495,10 +464,26 @@ class Generator: elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported)) - if self.pretty: - sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n") return sql + def preprocess(self, expression: exp.Expression) -> exp.Expression: + """Apply generic preprocessing transformations to a given expression.""" + if ( + not expression.parent + and type(expression) in self.EXPRESSIONS_WITHOUT_NESTED_CTES + and any(node.parent is not expression for node in expression.find_all(exp.With)) + ): + from sqlglot.transforms import move_ctes_to_top_level + + expression = move_ctes_to_top_level(expression) + + if self.ENSURE_BOOLS: + from sqlglot.transforms import ensure_bools + + expression = ensure_bools(expression) + + return expression + def unsupported(self, message: str) -> None: if self.unsupported_level == ErrorLevel.IMMEDIATE: raise UnsupportedError(message) @@ -752,9 +737,24 @@ class Generator: return f"GENERATED{this} AS {expr}{sequence_opts}" + def generatedasrowcolumnconstraint_sql( + self, expression: exp.GeneratedAsRowColumnConstraint + ) -> str: + start = "START" if expression.args["start"] else "END" + hidden = " HIDDEN" if expression.args.get("hidden") else "" + return f"GENERATED ALWAYS AS ROW {start}{hidden}" + + def periodforsystemtimeconstraint_sql( + self, expression: exp.PeriodForSystemTimeConstraint + ) -> str: + return f"PERIOD FOR SYSTEM_TIME ({self.sql(expression, 'this')}, {self.sql(expression, 'expression')})" + def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" + def transformcolumnconstraint_sql(self, expression: exp.TransformColumnConstraint) -> str: + return f"AS {self.sql(expression, 'this')}" + def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str: desc = expression.args.get("desc") if desc is not None: @@ -900,32 +900,32 @@ class Generator: columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" - if not alias and not self.UNNEST_COLUMN_ONLY: + if not alias and not self.dialect.UNNEST_COLUMN_ONLY: alias = "_t" return f"{alias}{columns}" def bitstring_sql(self, expression: exp.BitString) -> str: this = self.sql(expression, "this") - if self.BIT_START: - return f"{self.BIT_START}{this}{self.BIT_END}" + if self.dialect.BIT_START: + return f"{self.dialect.BIT_START}{this}{self.dialect.BIT_END}" return f"{int(this, 2)}" def hexstring_sql(self, expression: exp.HexString) -> str: this = self.sql(expression, "this") - if self.HEX_START: - return f"{self.HEX_START}{this}{self.HEX_END}" + if self.dialect.HEX_START: + return f"{self.dialect.HEX_START}{this}{self.dialect.HEX_END}" return f"{int(this, 16)}" def bytestring_sql(self, expression: exp.ByteString) -> str: this = self.sql(expression, "this") - if self.BYTE_START: - return f"{self.BYTE_START}{this}{self.BYTE_END}" + if self.dialect.BYTE_START: + return f"{self.dialect.BYTE_START}{this}{self.dialect.BYTE_END}" return this def rawstring_sql(self, expression: exp.RawString) -> str: string = self.escape_str(expression.this.replace("\\", "\\\\")) - return f"{self.QUOTE_START}{string}{self.QUOTE_END}" + return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}" def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str: this = self.sql(expression, "this") @@ -1065,14 +1065,14 @@ class Generator: text = expression.name lower = text.lower() text = lower if self.normalize and not expression.quoted else text - text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end) + text = text.replace(self.dialect.IDENTIFIER_END, self._escaped_identifier_end) if ( expression.quoted - or self.can_identify(text, self.identify) + or self.dialect.can_identify(text, self.identify) or lower in self.RESERVED_KEYWORDS - or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit()) + or (not self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit()) ): - text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}" + text = f"{self.dialect.IDENTIFIER_START}{text}{self.dialect.IDENTIFIER_END}" return text def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str: @@ -1121,7 +1121,7 @@ class Generator: expressions = self.expressions(properties, sep=sep, indent=False) if expressions: expressions = self.wrap(expressions) if wrapped else expressions - return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}" + return f"{prefix}{' ' if prefix.strip() else ''}{expressions}{suffix}" return "" def with_properties(self, properties: exp.Properties) -> str: @@ -1286,6 +1286,21 @@ class Generator: statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS" return f"{data_sql}{statistics_sql}" + def withsystemversioningproperty_sql(self, expression: exp.WithSystemVersioningProperty) -> str: + sql = "WITH(SYSTEM_VERSIONING=ON" + + if expression.this: + history_table = self.sql(expression, "this") + sql = f"{sql}(HISTORY_TABLE={history_table}" + + if expression.expression: + data_consistency_check = self.sql(expression, "expression") + sql = f"{sql}, DATA_CONSISTENCY_CHECK={data_consistency_check}" + + sql = f"{sql})" + + return f"{sql})" + def insert_sql(self, expression: exp.Insert) -> str: overwrite = expression.args.get("overwrite") @@ -1387,13 +1402,13 @@ class Generator: def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: table = ".".join( - part - for part in [ - self.sql(expression, "catalog"), - self.sql(expression, "db"), - self.sql(expression, "this"), - ] - if part + self.sql(part) + for part in ( + expression.args.get("catalog"), + expression.args.get("db"), + expression.args.get("this"), + ) + if part is not None ) version = self.sql(expression, "version") @@ -1426,7 +1441,7 @@ class Generator: def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " ) -> str: - if self.ALIAS_POST_TABLESAMPLE and expression.this.alias: + if self.dialect.ALIAS_POST_TABLESAMPLE and expression.this and expression.this.alias: table = expression.this.copy() table.set("alias", None) this = self.sql(table) @@ -1676,12 +1691,16 @@ class Generator: def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: this = self.sql(expression, "this") - args = ", ".join( - self.sql(self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e) + + args = [ + self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e for e in (expression.args.get(k) for k in ("offset", "expression")) if e - ) - return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}" + ] + + args_sql = ", ".join(self.sql(e) for e in args) + args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql + return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}" def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") @@ -1732,13 +1751,13 @@ class Generator: def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: - text = f"{self.QUOTE_START}{self.escape_str(text)}{self.QUOTE_END}" + text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}" return text def escape_str(self, text: str) -> str: - text = text.replace(self.QUOTE_END, self._escaped_quote_end) - if self.INVERSE_ESCAPE_SEQUENCES: - text = "".join(self.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text) + text = text.replace(self.dialect.QUOTE_END, self._escaped_quote_end) + if self.dialect.INVERSE_ESCAPE_SEQUENCES: + text = "".join(self.dialect.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text) elif self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) return text @@ -1782,9 +1801,11 @@ class Generator: nulls_first = expression.args.get("nulls_first") nulls_last = not nulls_first - nulls_are_large = self.NULL_ORDERING == "nulls_are_large" - nulls_are_small = self.NULL_ORDERING == "nulls_are_small" - nulls_are_last = self.NULL_ORDERING == "nulls_are_last" + nulls_are_large = self.dialect.NULL_ORDERING == "nulls_are_large" + nulls_are_small = self.dialect.NULL_ORDERING == "nulls_are_small" + nulls_are_last = self.dialect.NULL_ORDERING == "nulls_are_last" + + this = self.sql(expression, "this") sort_order = " DESC" if desc else (" ASC" if desc is False else "") nulls_sort_change = "" @@ -1799,13 +1820,13 @@ class Generator: ): nulls_sort_change = " NULLS LAST" + # If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: - self.unsupported( - "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect" - ) + null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else "" + this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}" nulls_sort_change = "" - return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" + return f"{this}{sort_order}{nulls_sort_change}" def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str: partition = self.partition_by_sql(expression) @@ -1933,10 +1954,13 @@ class Generator: ) kind = "" + # We use LIMIT_IS_TOP as a proxy for whether DISTINCT should go first because tsql and Teradata + # are the only dialects that use LIMIT_IS_TOP and both place DISTINCT first. + top_distinct = f"{distinct}{hint}{top}" if self.LIMIT_IS_TOP else f"{top}{hint}{distinct}" expressions = f"{self.sep()}{expressions}" if expressions else expressions sql = self.query_modifiers( expression, - f"SELECT{top}{hint}{distinct}{kind}{expressions}", + f"SELECT{top_distinct}{kind}{expressions}", self.sql(expression, "into", comment=False), self.sql(expression, "from", comment=False), ) @@ -1961,7 +1985,7 @@ class Generator: def parameter_sql(self, expression: exp.Parameter) -> str: this = self.sql(expression, "this") - return f"{self.PARAMETER_TOKEN}{this}" if self.SUPPORTS_PARAMETERS else this + return f"{self.PARAMETER_TOKEN}{this}" def sessionparameter_sql(self, expression: exp.SessionParameter) -> str: this = self.sql(expression, "this") @@ -2009,7 +2033,7 @@ class Generator: if alias and isinstance(offset, exp.Expression): alias.append("columns", offset) - if alias and self.UNNEST_COLUMN_ONLY: + if alias and self.dialect.UNNEST_COLUMN_ONLY: columns = alias.columns alias = self.sql(columns[0]) if columns else "" else: @@ -2080,14 +2104,14 @@ class Generator: return f"{this} BETWEEN {low} AND {high}" def bracket_sql(self, expression: exp.Bracket) -> str: - expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET) + expressions = apply_index_offset( + expression.this, + expression.expressions, + self.dialect.INDEX_OFFSET - expression.args.get("offset", 0), + ) expressions_sql = ", ".join(self.sql(e) for e in expressions) - return f"{self.sql(expression, 'this')}[{expressions_sql}]" - def safebracket_sql(self, expression: exp.SafeBracket) -> str: - return self.bracket_sql(expression) - def all_sql(self, expression: exp.All) -> str: return f"ALL {self.wrap(expression)}" @@ -2145,12 +2169,33 @@ class Generator: else: return self.func("TRIM", expression.this, expression.expression) - def safeconcat_sql(self, expression: exp.SafeConcat) -> str: - expressions = expression.expressions - if self.STRICT_STRING_CONCAT: - expressions = (exp.cast(e, "text") for e in expressions) + def convert_concat_args(self, expression: exp.Concat | exp.ConcatWs) -> t.List[exp.Expression]: + args = expression.expressions + if isinstance(expression, exp.ConcatWs): + args = args[1:] # Skip the delimiter + + if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"): + args = [exp.cast(e, "text") for e in args] + + if not self.dialect.CONCAT_COALESCE and expression.args.get("coalesce"): + args = [exp.func("coalesce", e, exp.Literal.string("")) for e in args] + + return args + + def concat_sql(self, expression: exp.Concat) -> str: + expressions = self.convert_concat_args(expression) + + # Some dialects don't allow a single-argument CONCAT call + if not self.SUPPORTS_SINGLE_ARG_CONCAT and len(expressions) == 1: + return self.sql(expressions[0]) + return self.func("CONCAT", *expressions) + def concatws_sql(self, expression: exp.ConcatWs) -> str: + return self.func( + "CONCAT_WS", seq_get(expression.expressions, 0), *self.convert_concat_args(expression) + ) + def check_sql(self, expression: exp.Check) -> str: this = self.sql(expression, key="this") return f"CHECK ({this})" @@ -2493,14 +2538,7 @@ class Generator: actions = expression.args["actions"] if isinstance(actions[0], exp.ColumnDef): - if self.ALTER_TABLE_ADD_COLUMN_KEYWORD: - actions = self.expressions( - expression, - key="actions", - prefix="ADD COLUMN ", - ) - else: - actions = f"ADD {self.expressions(expression, key='actions')}" + actions = self.add_column_sql(expression) elif isinstance(actions[0], exp.Schema): actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ") elif isinstance(actions[0], exp.Delete): @@ -2512,6 +2550,15 @@ class Generator: only = " ONLY" if expression.args.get("only") else "" return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}" + def add_column_sql(self, expression: exp.AlterTable) -> str: + if self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD: + return self.expressions( + expression, + key="actions", + prefix="ADD COLUMN ", + ) + return f"ADD {self.expressions(expression, key='actions', flat=True)}" + def droppartition_sql(self, expression: exp.DropPartition) -> str: expressions = self.expressions(expression) exists = " IF EXISTS " if expression.args.get("exists") else " " @@ -2551,14 +2598,31 @@ class Generator: ) def dpipe_sql(self, expression: exp.DPipe) -> str: - return self.binary(expression, "||") - - def safedpipe_sql(self, expression: exp.SafeDPipe) -> str: - if self.STRICT_STRING_CONCAT: + if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"): return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten())) - return self.dpipe_sql(expression) + return self.binary(expression, "||") def div_sql(self, expression: exp.Div) -> str: + l, r = expression.left, expression.right + + if not self.dialect.SAFE_DIVISION and expression.args.get("safe"): + r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0))) + + if self.dialect.TYPED_DIVISION and not expression.args.get("typed"): + if not l.is_type(*exp.DataType.FLOAT_TYPES) and not r.is_type( + *exp.DataType.FLOAT_TYPES + ): + l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE)) + + elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"): + if l.is_type(*exp.DataType.INTEGER_TYPES) and r.is_type(*exp.DataType.INTEGER_TYPES): + return self.sql( + exp.cast( + l / r, + to=exp.DataType.Type.BIGINT, + ) + ) + return self.binary(expression, "/") def overlaps_sql(self, expression: exp.Overlaps) -> str: @@ -2573,6 +2637,9 @@ class Generator: def eq_sql(self, expression: exp.EQ) -> str: return self.binary(expression, "=") + def propertyeq_sql(self, expression: exp.PropertyEQ) -> str: + return self.binary(expression, ":=") + def escape_sql(self, expression: exp.Escape) -> str: return self.binary(expression, "ESCAPE") @@ -2641,10 +2708,13 @@ class Generator: return self.cast_sql(expression, safe_prefix="TRY_") def log_sql(self, expression: exp.Log) -> str: - args = list(expression.args.values()) - if not self.LOG_BASE_FIRST: - args.reverse() - return self.func("LOG", *args) + this = expression.this + expr = expression.expression + + if not self.dialect.LOG_BASE_FIRST: + this, expr = expr, this + + return self.func("LOG", this, expr) def use_sql(self, expression: exp.Use) -> str: kind = self.sql(expression, "kind") @@ -2696,7 +2766,9 @@ class Generator: def format_time(self, expression: exp.Expression) -> t.Optional[str]: return format_time( - self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE + self.sql(expression, "format"), + self.dialect.INVERSE_TIME_MAPPING, + self.dialect.INVERSE_TIME_TRIE, ) def expressions( @@ -2963,6 +3035,19 @@ class Generator: parameters = self.sql(expression, "params_struct") return self.func("PREDICT", model, table, parameters or None) + def forin_sql(self, expression: exp.ForIn) -> str: + this = self.sql(expression, "this") + expression_sql = self.sql(expression, "expression") + return f"FOR {this} DO {expression_sql}" + + def refresh_sql(self, expression: exp.Refresh) -> str: + this = self.sql(expression, "this") + table = "" if isinstance(expression.this, exp.Literal) else "TABLE " + return f"REFRESH {table}{this}" + + def operator_sql(self, expression: exp.Operator) -> str: + return self.binary(expression, f"OPERATOR({self.sql(expression, 'operator')})") + def _simplify_unless_literal(self, expression: E) -> E: if not isinstance(expression, exp.Literal): from sqlglot.optimizer.simplify import simplify @@ -2970,3 +3055,10 @@ class Generator: expression = simplify(expression) return expression + + def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]: + return [ + exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string("")) + for value in values + if value + ] diff --git a/sqlglot/helper.py b/sqlglot/helper.py index ee41557..349c8c8 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import inspect import logging import re @@ -283,7 +284,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any: file = open_file(read_csv.name) delimiter = "," - args = iter(arg.name for arg in args) + args = iter(arg.name for arg in args) # type: ignore for k, v in zip(args, args): if k == "delimiter": delimiter = v @@ -463,3 +464,27 @@ def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]: merged.append((start, end)) return merged + + +def is_iso_date(text: str) -> bool: + try: + datetime.date.fromisoformat(text) + return True + except ValueError: + return False + + +def is_iso_datetime(text: str) -> bool: + try: + datetime.datetime.fromisoformat(text) + return True + except ValueError: + return False + + +# Interval units that operate on date components +DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} + + +def is_date_unit(expression: t.Optional[exp.Expression]) -> bool: + return expression is not None and expression.name.lower() in DATE_UNITS diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 011a6b8..abcc10f 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from sqlglot import Schema, exp, maybe_parse from sqlglot.errors import SqlglotError -from sqlglot.optimizer import Scope, build_scope, qualify +from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType @@ -29,8 +29,38 @@ class Node: else: yield d - def to_html(self, **opts) -> LineageHTML: - return LineageHTML(self, **opts) + def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: + nodes = {} + edges = [] + + for node in self.walk(): + if isinstance(node.expression, exp.Table): + label = f"FROM {node.expression.this}" + title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>" + group = 1 + else: + label = node.expression.sql(pretty=True, dialect=dialect) + source = node.source.transform( + lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>") + if n is node.expression + else n, + copy=False, + ).sql(pretty=True, dialect=dialect) + title = f"<pre>{source}</pre>" + group = 0 + + node_id = id(node) + + nodes[node_id] = { + "id": node_id, + "label": label, + "title": title, + "group": group, + } + + for d in node.downstream: + edges.append({"from": node_id, "to": id(d)}) + return GraphHTML(nodes, edges, **opts) def lineage( @@ -64,6 +94,7 @@ def lineage( k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect)) for k, v in sources.items() }, + dialect=dialect, ) qualified = qualify.qualify( @@ -129,17 +160,6 @@ def lineage( return upstream - subquery = select.unalias() - - if isinstance(subquery, exp.Subquery): - upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select) - scope = t.cast(Scope, build_scope(subquery.unnest())) - - for select in subquery.named_selects: - to_node(select, scope=scope, upstream=upstream) - - return upstream - if isinstance(scope.expression, exp.Select): # For better ergonomics in our node labels, replace the full select with # a version that has only the column we care about. @@ -156,16 +176,28 @@ def lineage( expression=select, alias=alias or "", ) + if upstream: upstream.downstream.append(node) + subquery_scopes = { + id(subquery_scope.expression): subquery_scope + for subquery_scope in scope.subquery_scopes + } + + for subquery in find_all_in_scope(select, exp.Subqueryable): + subquery_scope = subquery_scopes[id(subquery)] + + for name in subquery.named_selects: + to_node(name, scope=subquery_scope, upstream=node) + # if the select is a star add all scope sources as downstreams if select.is_star: for source in scope.sources.values(): node.downstream.append(Node(name=select.sql(), source=source, expression=source)) # Find all columns that went into creating this one to list their lineage nodes. - source_columns = set(select.find_all(exp.Column)) + source_columns = set(find_all_in_scope(select, exp.Column)) # If the source is a UDTF find columns used in the UTDF to generate the table if isinstance(source, exp.UDTF): @@ -192,20 +224,15 @@ def lineage( return to_node(column if isinstance(column, str) else column.name, scope) -class LineageHTML: +class GraphHTML: """Node to HTML generator using vis.js. https://visjs.github.io/vis-network/docs/network/ """ def __init__( - self, - node: Node, - dialect: DialectType = None, - imports: bool = True, - **opts: t.Any, + self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None ): - self.node = node self.imports = imports self.options = { @@ -235,39 +262,11 @@ class LineageHTML: "maximum": 300, }, }, - **opts, + **(options or {}), } - self.nodes = {} - self.edges = [] - - for node in node.walk(): - if isinstance(node.expression, exp.Table): - label = f"FROM {node.expression.this}" - title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>" - group = 1 - else: - label = node.expression.sql(pretty=True, dialect=dialect) - source = node.source.transform( - lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>") - if n is node.expression - else n, - copy=False, - ).sql(pretty=True, dialect=dialect) - title = f"<pre>{source}</pre>" - group = 0 - - node_id = id(node) - - self.nodes[node_id] = { - "id": node_id, - "label": label, - "title": title, - "group": group, - } - - for d in node.downstream: - self.edges.append({"from": node_id, "to": id(d)}) + self.nodes = nodes + self.edges = edges def __str__(self): nodes = json.dumps(list(self.nodes.values())) diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 69d4567..7b990f1 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,12 +1,18 @@ from __future__ import annotations -import datetime import functools import typing as t from sqlglot import exp from sqlglot._typing import E -from sqlglot.helper import ensure_list, seq_get, subclasses +from sqlglot.helper import ( + ensure_list, + is_date_unit, + is_iso_date, + is_iso_datetime, + seq_get, + subclasses, +) from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema, ensure_schema @@ -20,10 +26,6 @@ if t.TYPE_CHECKING: ] -# Interval units that operate on date components -DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} - - def annotate_types( expression: E, schema: t.Optional[t.Dict | Schema] = None, @@ -60,43 +62,22 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type return lambda self, e: self._annotate_with_type(e, data_type) -def _is_iso_date(text: str) -> bool: - try: - datetime.date.fromisoformat(text) - return True - except ValueError: - return False - - -def _is_iso_datetime(text: str) -> bool: - try: - datetime.datetime.fromisoformat(text) - return True - except ValueError: - return False - - -def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: +def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: date_text = l.name - unit = r.text("unit").lower() - - is_iso_date = _is_iso_date(date_text) + is_iso_date_ = is_iso_date(date_text) - if is_iso_date and unit in DATE_UNITS: - l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE)) + if is_iso_date_ and is_date_unit(unit): return exp.DataType.Type.DATE # An ISO date is also an ISO datetime, but not vice versa - if is_iso_date or _is_iso_datetime(date_text): - l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME)) + if is_iso_date_ or is_iso_datetime(date_text): return exp.DataType.Type.DATETIME return exp.DataType.Type.UNKNOWN -def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: - unit = r.text("unit").lower() - if unit not in DATE_UNITS: +def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: + if not is_date_unit(unit): return exp.DataType.Type.DATETIME return l.type.this if l.type else exp.DataType.Type.UNKNOWN @@ -171,7 +152,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Date, exp.DateFromParts, exp.DateStrToDate, - exp.DateTrunc, exp.DiToDate, exp.StrToDate, exp.TimeStrToDate, @@ -185,6 +165,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DataType.Type.DOUBLE: { exp.ApproxQuantile, exp.Avg, + exp.Div, exp.Exp, exp.Ln, exp.Log, @@ -203,8 +184,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): }, exp.DataType.Type.INT: { exp.Ceil, - exp.DateDiff, exp.DatetimeDiff, + exp.DateDiff, exp.Extract, exp.TimestampDiff, exp.TimeDiff, @@ -240,8 +221,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.GroupConcat, exp.Initcap, exp.Lower, - exp.SafeConcat, - exp.SafeDPipe, exp.Substring, exp.TimeToStr, exp.TimeToTimeStr, @@ -267,6 +246,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): for data_type, expressions in TYPE_TO_EXPRESSIONS.items() for expr_type in expressions }, + exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), @@ -276,9 +256,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), - exp.DateAdd: lambda self, e: self._annotate_dateadd(e), - exp.DateSub: lambda self, e: self._annotate_dateadd(e), + exp.DateAdd: lambda self, e: self._annotate_timeunit(e), + exp.DateSub: lambda self, e: self._annotate_timeunit(e), + exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), + exp.Div: lambda self, e: self._annotate_div(e), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), @@ -288,6 +270,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), + exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), @@ -306,13 +289,27 @@ class TypeAnnotator(metaclass=_TypeAnnotator): BINARY_COERCIONS: BinaryCoercions = { **swap_all( { - (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval + (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal( + l, r.args.get("unit") + ) for t in exp.DataType.TEXT_TYPES } ), **swap_all( { - (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval, + # text + numeric will yield the numeric type to match most dialects' semantics + (text, numeric): lambda l, r: t.cast( + exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type + ) + for text in exp.DataType.TEXT_TYPES + for numeric in exp.DataType.NUMERIC_TYPES + } + ), + **swap_all( + { + (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date( + l, r.args.get("unit") + ), } ), } @@ -511,18 +508,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return expression - def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp: + def _annotate_timeunit( + self, expression: exp.TimeUnit | exp.DateTrunc + ) -> exp.TimeUnit | exp.DateTrunc: self._annotate_args(expression) if expression.this.type.this in exp.DataType.TEXT_TYPES: - datatype = _coerce_literal_and_interval(expression.this, expression.interval()) - elif ( - expression.this.type.is_type(exp.DataType.Type.DATE) - and expression.text("unit").lower() not in DATE_UNITS - ): - datatype = exp.DataType.Type.DATETIME + datatype = _coerce_date_literal(expression.this, expression.unit) + elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES: + datatype = _coerce_date(expression.this, expression.unit) else: - datatype = expression.this.type + datatype = exp.DataType.Type.UNKNOWN self._set_type(expression, datatype) return expression @@ -547,3 +543,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, exp.DataType.Type.UNKNOWN) return expression + + def _annotate_div(self, expression: exp.Div) -> exp.Div: + self._annotate_args(expression) + + left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore + + if ( + expression.args.get("typed") + and left_type in exp.DataType.INTEGER_TYPES + and right_type in exp.DataType.INTEGER_TYPES + ): + self._set_type(expression, exp.DataType.Type.BIGINT) + else: + self._set_type(expression, self._maybe_coerce(left_type, right_type)) + + return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index fc5c348..faf18c6 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -1,8 +1,10 @@ from __future__ import annotations import itertools +import typing as t from sqlglot import exp +from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime def canonicalize(expression: exp.Expression) -> exp.Expression: @@ -20,7 +22,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression = replace_date_funcs(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) - expression = ensure_bool_predicates(expression) + expression = ensure_bools(expression, _replace_int_predicate) expression = remove_ascending_order(expression) return expression @@ -40,8 +42,22 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression: return node +COERCIBLE_DATE_OPS = ( + exp.Add, + exp.Sub, + exp.EQ, + exp.NEQ, + exp.GT, + exp.GTE, + exp.LT, + exp.LTE, + exp.NullSafeEQ, + exp.NullSafeNEQ, +) + + def coerce_type(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Binary): + if isinstance(node, COERCIBLE_DATE_OPS): _coerce_date(node.left, node.right) elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) @@ -49,6 +65,10 @@ def coerce_type(node: exp.Expression) -> exp.Expression: *exp.DataType.TEMPORAL_TYPES ): _replace_cast(node.expression, exp.DataType.Type.DATETIME) + elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): + _coerce_timeunit_arg(node.this, node.unit) + elif isinstance(node, exp.DateDiff): + _coerce_datediff_args(node) return node @@ -64,17 +84,21 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: return expression -def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: +def ensure_bools( + expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] +) -> exp.Expression: if isinstance(expression, exp.Connector): - _replace_int_predicate(expression.left) - _replace_int_predicate(expression.right) - - elif isinstance(expression, (exp.Where, exp.Having)) or ( + replace_func(expression.left) + replace_func(expression.right) + elif isinstance(expression, exp.Not): + replace_func(expression.this) # We can't replace num in CASE x WHEN num ..., because it's not the full predicate - isinstance(expression, exp.If) - and not (isinstance(expression.parent, exp.Case) and expression.parent.this) + elif isinstance(expression, exp.If) and not ( + isinstance(expression.parent, exp.Case) and expression.parent.this ): - _replace_int_predicate(expression.this) + replace_func(expression.this) + elif isinstance(expression, (exp.Where, exp.Having)): + replace_func(expression.this) return expression @@ -89,22 +113,59 @@ def remove_ascending_order(expression: exp.Expression) -> exp.Expression: def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): + if isinstance(b, exp.Interval): + a = _coerce_timeunit_arg(a, b.unit) if ( a.type and a.type.this == exp.DataType.Type.DATE and b.type - and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL) + and b.type.this + not in ( + exp.DataType.Type.DATE, + exp.DataType.Type.INTERVAL, + ) ): _replace_cast(b, exp.DataType.Type.DATE) +def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: + if not arg.type: + return arg + + if arg.type.this in exp.DataType.TEXT_TYPES: + date_text = arg.name + is_iso_date_ = is_iso_date(date_text) + + if is_iso_date_ and is_date_unit(unit): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) + + # An ISO date is also an ISO datetime, but not vice versa + if is_iso_date_ or is_iso_datetime(date_text): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) + + elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) + + return arg + + +def _coerce_datediff_args(node: exp.DateDiff) -> None: + for e in (node.this, node.expression): + if e.type.this not in exp.DataType.TEMPORAL_TYPES: + e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) + + def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None: node.replace(exp.cast(node.copy(), to=to)) +# this was originally designed for presto, there is a similar transform for tsql +# this is different in that it only operates on int types, this is because +# presto has a boolean type whereas tsql doesn't (people use bits) +# with y as (select true as x) select x = 0 FROM y -- illegal presto query def _replace_int_predicate(expression: exp.Expression) -> None: if isinstance(expression, exp.Coalesce): for _, child in expression.iter_expressions(): _replace_int_predicate(child) elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: - expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0))) + expression.replace(expression.neq(0)) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index b0b2b3d..a74bea7 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -186,13 +186,13 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): and not ( isinstance(from_or_join, exp.Join) and inner_select.args.get("where") - and from_or_join.side in {"FULL", "LEFT", "RIGHT"} + and from_or_join.side in ("FULL", "LEFT", "RIGHT") ) and not ( isinstance(from_or_join, exp.From) and inner_select.args.get("where") and any( - j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) + j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", []) ) ) and not _outer_select_joins_on_inner_select_join() diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index 154256e..3361a33 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -13,7 +13,7 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: @t.overload -def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression: +def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ... @@ -48,11 +48,11 @@ def normalize_identifiers(expression, dialect=None): Returns: The transformed expression. """ + dialect = Dialect.get_or_raise(dialect) + if isinstance(expression, str): expression = exp.parse_identifier(expression, dialect=dialect) - dialect = Dialect.get_or_raise(dialect) - def _normalize(node: E) -> E: if not node.meta.get("case_sensitive"): exp.replace_children(node, _normalize) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index abac63b..1c96e95 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -42,8 +42,8 @@ RULES = ( def optimize( expression: str | exp.Expression, schema: t.Optional[dict | Schema] = None, - db: t.Optional[str] = None, - catalog: t.Optional[str] = None, + db: t.Optional[str | exp.Identifier] = None, + catalog: t.Optional[str | exp.Identifier] = None, dialect: DialectType = None, rules: t.Sequence[t.Callable] = RULES, **kwargs, diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index b06ea1d..742cdf5 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -8,7 +8,7 @@ from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.errors import OptimizeError from sqlglot.helper import seq_get -from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope +from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope from sqlglot.optimizer.simplify import simplify_parens from sqlglot.schema import Schema, ensure_schema @@ -58,7 +58,7 @@ def qualify_columns( if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver, using_column_tables, pseudocolumns) - _qualify_outputs(scope) + qualify_outputs(scope) _expand_group_by(scope) _expand_order_by(scope, resolver) @@ -237,7 +237,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None: ordereds = order.expressions for ordered, new_expression in zip( ordereds, - _expand_positional_references(scope, (o.this for o in ordereds)), + _expand_positional_references(scope, (o.this for o in ordereds), alias=True), ): for agg in ordered.find_all(exp.AggFunc): for col in agg.find_all(exp.Column): @@ -259,17 +259,23 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None: ) -def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]: - new_nodes = [] +def _expand_positional_references( + scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False +) -> t.List[exp.Expression]: + new_nodes: t.List[exp.Expression] = [] for node in expressions: if node.is_int: - select = _select_by_pos(scope, t.cast(exp.Literal, node)).this + select = _select_by_pos(scope, t.cast(exp.Literal, node)) - if isinstance(select, exp.Literal): - new_nodes.append(node) + if alias: + new_nodes.append(exp.column(select.args["alias"].copy())) else: - new_nodes.append(select.copy()) - scope.clear_cache() + select = select.this + + if isinstance(select, exp.Literal): + new_nodes.append(node) + else: + new_nodes.append(select.copy()) else: new_nodes.append(node) @@ -307,7 +313,9 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None: if column_table: column.set("table", column_table) elif column_table not in scope.sources and ( - not scope.parent or column_table not in scope.parent.sources + not scope.parent + or column_table not in scope.parent.sources + or not scope.is_correlated_subquery ): # structs are used like tables (e.g. "struct"."field"), so they need to be qualified # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) @@ -381,15 +389,18 @@ def _expand_stars( columns = [name for name in columns if name.upper() not in pseudocolumns] if columns and "*" not in columns: + table_id = id(table) + columns_to_exclude = except_columns.get(table_id) or set() + if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: implicit_columns = [col for col in columns if col not in pivot_columns] new_selections.extend( exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) for name in implicit_columns + pivot_output_columns + if name not in columns_to_exclude ) continue - table_id = id(table) for name in columns: if name in using_column_tables and table in using_column_tables[name]: if name in coalesced_columns: @@ -406,7 +417,7 @@ def _expand_stars( copy=False, ) ) - elif name not in except_columns.get(table_id, set()): + elif name not in columns_to_exclude: alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table=table) new_selections.append( @@ -448,10 +459,16 @@ def _add_replace_columns( replace_columns[id(table)] = columns -def _qualify_outputs(scope: Scope) -> None: +def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: """Ensure all output columns are aliased""" - new_selections = [] + if isinstance(scope_or_expression, exp.Expression): + scope = build_scope(scope_or_expression) + if not isinstance(scope, Scope): + return + else: + scope = scope_or_expression + new_selections = [] for i, (selection, aliased_column) in enumerate( itertools.zip_longest(scope.expression.selects, scope.outer_column_list) ): diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 3a43e8f..57ecabe 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import itertools import typing as t from sqlglot import alias, exp from sqlglot._typing import E +from sqlglot.dialects.dialect import DialectType from sqlglot.helper import csv_reader, name_sequence from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema @@ -10,9 +13,10 @@ from sqlglot.schema import Schema def qualify_tables( expression: E, - db: t.Optional[str] = None, - catalog: t.Optional[str] = None, + db: t.Optional[str | exp.Identifier] = None, + catalog: t.Optional[str | exp.Identifier] = None, schema: t.Optional[Schema] = None, + dialect: DialectType = None, ) -> E: """ Rewrite sqlglot AST to have fully qualified tables. Join constructs such as @@ -33,11 +37,14 @@ def qualify_tables( db: Database name catalog: Catalog name schema: A schema to populate + dialect: The dialect to parse catalog and schema into. Returns: The qualified expression. """ next_alias_name = name_sequence("_q_") + db = exp.parse_identifier(db, dialect=dialect) if db else None + catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): @@ -61,9 +68,9 @@ def qualify_tables( if isinstance(source, exp.Table): if isinstance(source.this, exp.Identifier): if not source.args.get("db"): - source.set("db", exp.to_identifier(db)) + source.set("db", db) if not source.args.get("catalog") and source.args.get("db"): - source.set("catalog", exp.to_identifier(catalog)) + source.set("catalog", catalog) if not source.alias: # Mutates the source by attaching an alias to it diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 4af5b49..b7e527e 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import logging import typing as t diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index af03332..d4e2e60 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -507,6 +507,9 @@ def simplify_literals(expression, root=True): return exp.Literal.number(value[1:]) return exp.Literal.number(f"-{value}") + if type(expression) in INVERSE_DATE_OPS: + return _simplify_binary(expression, expression.this, expression.interval()) or expression + return expression @@ -530,22 +533,24 @@ def _simplify_binary(expression, a, b): return exp.null() if a.is_number and b.is_number: - a = int(a.name) if a.is_int else Decimal(a.name) - b = int(b.name) if b.is_int else Decimal(b.name) + num_a = int(a.name) if a.is_int else Decimal(a.name) + num_b = int(b.name) if b.is_int else Decimal(b.name) if isinstance(expression, exp.Add): - return exp.Literal.number(a + b) - if isinstance(expression, exp.Sub): - return exp.Literal.number(a - b) + return exp.Literal.number(num_a + num_b) if isinstance(expression, exp.Mul): - return exp.Literal.number(a * b) + return exp.Literal.number(num_a * num_b) + + # We only simplify Sub, Div if a and b have the same parent because they're not associative + if isinstance(expression, exp.Sub): + return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None if isinstance(expression, exp.Div): # engines have differing int div behavior so intdiv is not safe - if isinstance(a, int) and isinstance(b, int): + if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: return None - return exp.Literal.number(a / b) + return exp.Literal.number(num_a / num_b) - boolean = eval_boolean(expression, a, b) + boolean = eval_boolean(expression, num_a, num_b) if boolean: return boolean @@ -557,15 +562,21 @@ def _simplify_binary(expression, a, b): elif _is_date_literal(a) and isinstance(b, exp.Interval): a, b = extract_date(a), extract_interval(b) if a and b: - if isinstance(expression, exp.Add): + if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): return date_literal(a + b) - if isinstance(expression, exp.Sub): + if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): return date_literal(a - b) elif isinstance(a, exp.Interval) and _is_date_literal(b): a, b = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval if a and b and isinstance(expression, exp.Add): return date_literal(a + b) + elif _is_date_literal(a) and _is_date_literal(b): + if isinstance(expression, exp.Predicate): + a, b = extract_date(a), extract_date(b) + boolean = eval_boolean(expression, a, b) + if boolean: + return boolean return None @@ -590,6 +601,11 @@ def simplify_parens(expression): return expression +NONNULL_CONSTANTS = ( + exp.Literal, + exp.Boolean, +) + CONSTANTS = ( exp.Literal, exp.Boolean, @@ -597,11 +613,19 @@ CONSTANTS = ( ) +def _is_nonnull_constant(expression: exp.Expression) -> bool: + return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression) + + +def _is_constant(expression: exp.Expression) -> bool: + return isinstance(expression, CONSTANTS) or _is_date_literal(expression) + + def simplify_coalesce(expression): # COALESCE(x) -> x if ( isinstance(expression, exp.Coalesce) - and not expression.expressions + and (not expression.expressions or _is_nonnull_constant(expression.this)) # COALESCE is also used as a Spark partitioning hint and not isinstance(expression.parent, exp.Hint) ): @@ -621,12 +645,12 @@ def simplify_coalesce(expression): # This transformation is valid for non-constants, # but it really only does anything if they are both constants. - if not isinstance(other, CONSTANTS): + if not _is_constant(other): return expression # Find the first constant arg for arg_index, arg in enumerate(coalesce.expressions): - if isinstance(arg, CONSTANTS): + if _is_constant(other): break else: return expression @@ -656,7 +680,6 @@ def simplify_coalesce(expression): CONCATS = (exp.Concat, exp.DPipe) -SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe) def simplify_concat(expression): @@ -672,10 +695,15 @@ def simplify_concat(expression): sep_expr, *expressions = expression.expressions sep = sep_expr.name concat_type = exp.ConcatWs + args = {} else: expressions = expression.expressions sep = "" - concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat + concat_type = exp.Concat + args = { + "safe": expression.args.get("safe"), + "coalesce": expression.args.get("coalesce"), + } new_args = [] for is_string_group, group in itertools.groupby( @@ -692,7 +720,7 @@ def simplify_concat(expression): if concat_type is exp.ConcatWs: new_args = [sep_expr] + new_args - return concat_type(expressions=new_args) + return concat_type(expressions=new_args, **args) def simplify_conditionals(expression): @@ -947,7 +975,7 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: if isinstance(cast, exp.Cast): to = cast.to - elif isinstance(cast, exp.TsOrDsToDate): + elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): to = exp.DataType.build(exp.DataType.Type.DATE) else: return None @@ -966,12 +994,11 @@ def _is_date_literal(expression: exp.Expression) -> bool: def extract_interval(expression): - n = int(expression.name) - unit = expression.text("unit").lower() - try: + n = int(expression.name) + unit = expression.text("unit").lower() return interval(unit, n) - except (UnsupportedUnit, ModuleNotFoundError): + except (UnsupportedUnit, ModuleNotFoundError, ValueError): return None @@ -1099,8 +1126,6 @@ GEN_MAP = { exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}", exp.Div: lambda e: _binary(e, "/"), exp.Dot: lambda e: _binary(e, "."), - exp.DPipe: lambda e: _binary(e, "||"), - exp.SafeDPipe: lambda e: _binary(e, "||"), exp.EQ: lambda e: _binary(e, "="), exp.GT: lambda e: _binary(e, ">"), exp.GTE: lambda e: _binary(e, ">="), diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 1dab600..c7e27a3 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -13,6 +13,7 @@ from sqlglot.trie import TrieResult, in_trie, new_trie if t.TYPE_CHECKING: from sqlglot._typing import E + from sqlglot.dialects.dialect import Dialect, DialectType logger = logging.getLogger("sqlglot") @@ -46,6 +47,19 @@ def binary_range_parser( ) +def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func: + # Default argument order is base, expression + this = seq_get(args, 0) + expression = seq_get(args, 1) + + if expression: + if not dialect.LOG_BASE_FIRST: + this, expression = expression, this + return exp.Log(this=this, expression=expression) + + return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this) + + class _Parser(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) @@ -72,13 +86,24 @@ class Parser(metaclass=_Parser): """ FUNCTIONS: t.Dict[str, t.Callable] = { - **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, + **{name: func.from_arg_list for name, func in exp.FUNCTION_BY_NAME.items()}, + "CONCAT": lambda args, dialect: exp.Concat( + expressions=args, + safe=not dialect.STRICT_STRING_CONCAT, + coalesce=dialect.CONCAT_COALESCE, + ), + "CONCAT_WS": lambda args, dialect: exp.ConcatWs( + expressions=args, + safe=not dialect.STRICT_STRING_CONCAT, + coalesce=dialect.CONCAT_COALESCE, + ), "DATE_TO_DATE_STR": lambda args: exp.Cast( this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)), "LIKE": parse_like, + "LOG": parse_logarithm, "TIME_TO_TIME_STR": lambda args: exp.Cast( this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), @@ -229,7 +254,7 @@ class Parser(metaclass=_Parser): TokenType.SOME: exp.Any, } - RESERVED_KEYWORDS = { + RESERVED_TOKENS = { *Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT, } @@ -245,9 +270,11 @@ class Parser(metaclass=_Parser): CREATABLES = { TokenType.COLUMN, + TokenType.CONSTRAINT, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE, + TokenType.FOREIGN_KEY, *DB_CREATABLES, } @@ -291,6 +318,7 @@ class Parser(metaclass=_Parser): TokenType.NATURAL, TokenType.NEXT, TokenType.OFFSET, + TokenType.OPERATOR, TokenType.ORDINALITY, TokenType.OVERLAPS, TokenType.OVERWRITE, @@ -299,7 +327,10 @@ class Parser(metaclass=_Parser): TokenType.PIVOT, TokenType.PRAGMA, TokenType.RANGE, + TokenType.RECURSIVE, TokenType.REFERENCES, + TokenType.REFRESH, + TokenType.REPLACE, TokenType.RIGHT, TokenType.ROW, TokenType.ROWS, @@ -390,6 +421,7 @@ class Parser(metaclass=_Parser): } EQUALITY = { + TokenType.COLON_EQ: exp.PropertyEQ, TokenType.EQ: exp.EQ, TokenType.NEQ: exp.NEQ, TokenType.NULLSAFE_EQ: exp.NullSafeEQ, @@ -406,7 +438,6 @@ class Parser(metaclass=_Parser): TokenType.AMP: exp.BitwiseAnd, TokenType.CARET: exp.BitwiseXor, TokenType.PIPE: exp.BitwiseOr, - TokenType.DPIPE: exp.DPipe, } TERM = { @@ -423,6 +454,8 @@ class Parser(metaclass=_Parser): TokenType.STAR: exp.Mul, } + EXPONENT: t.Dict[TokenType, t.Type[exp.Expression]] = {} + TIMES = { TokenType.TIME, TokenType.TIMETZ, @@ -558,6 +591,7 @@ class Parser(metaclass=_Parser): TokenType.MERGE: lambda self: self._parse_merge(), TokenType.PIVOT: lambda self: self._parse_simplified_pivot(), TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()), + TokenType.REFRESH: lambda self: self._parse_refresh(), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.SET: lambda self: self._parse_set(), TokenType.UNCACHE: lambda self: self._parse_uncache(), @@ -697,6 +731,7 @@ class Parser(metaclass=_Parser): exp.StabilityProperty, this=exp.Literal.string("STABLE") ), "STORED": lambda self: self._parse_stored(), + "SYSTEM_VERSIONING": lambda self: self._parse_system_versioning_property(), "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), "TEMP": lambda self: self.expression(exp.TemporaryProperty), "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty), @@ -754,6 +789,7 @@ class Parser(metaclass=_Parser): ) or self.expression(exp.OnProperty, this=self._parse_id_var()), "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()), + "PERIOD": lambda self: self._parse_period_for_system_time(), "PRIMARY KEY": lambda self: self._parse_primary_key(), "REFERENCES": lambda self: self._parse_references(match=False), "TITLE": lambda self: self.expression( @@ -775,7 +811,7 @@ class Parser(metaclass=_Parser): "RENAME": lambda self: self._parse_alter_table_rename(), } - SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"} + SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE", "PERIOD"} NO_PAREN_FUNCTION_PARSERS = { "ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()), @@ -794,14 +830,11 @@ class Parser(metaclass=_Parser): FUNCTION_PARSERS = { "ANY_VALUE": lambda self: self._parse_any_value(), "CAST": lambda self: self._parse_cast(self.STRICT_CAST), - "CONCAT": lambda self: self._parse_concat(), - "CONCAT_WS": lambda self: self._parse_concat_ws(), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), "DECODE": lambda self: self._parse_decode(), "EXTRACT": lambda self: self._parse_extract(), "JSON_OBJECT": lambda self: self._parse_json_object(), "JSON_TABLE": lambda self: self._parse_json_table(), - "LOG": lambda self: self._parse_logarithm(), "MATCH": lambda self: self._parse_match_against(), "OPENJSON": lambda self: self._parse_open_json(), "POSITION": lambda self: self._parse_position(), @@ -877,6 +910,7 @@ class Parser(metaclass=_Parser): CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"} OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"} + OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN} TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE} @@ -896,17 +930,13 @@ class Parser(metaclass=_Parser): STRICT_CAST = True - # A NULL arg in CONCAT yields NULL by default - CONCAT_NULL_OUTPUTS_STRING = False - PREFIXED_PIVOT_COLUMNS = False IDENTIFY_PIVOT_STRINGS = False - LOG_BASE_FIRST = True LOG_DEFAULTS_TO_LN = False # Whether or not ADD is present for each column added by ALTER TABLE - ALTER_TABLE_ADD_COLUMN_KEYWORD = True + ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = True # Whether or not the table sample clause expects CSV syntax TABLESAMPLE_CSV = False @@ -921,6 +951,7 @@ class Parser(metaclass=_Parser): "error_level", "error_message_context", "max_errors", + "dialect", "sql", "errors", "_tokens", @@ -929,35 +960,25 @@ class Parser(metaclass=_Parser): "_next", "_prev", "_prev_comments", - "_tokenizer", ) # Autofilled - TOKENIZER_CLASS: t.Type[Tokenizer] = Tokenizer - INDEX_OFFSET: int = 0 - UNNEST_COLUMN_ONLY: bool = False - ALIAS_POST_TABLESAMPLE: bool = False - STRICT_STRING_CONCAT = False - SUPPORTS_USER_DEFINED_TYPES = True - NORMALIZE_FUNCTIONS = "upper" - NULL_ORDERING: str = "nulls_are_small" SHOW_TRIE: t.Dict = {} SET_TRIE: t.Dict = {} - FORMAT_MAPPING: t.Dict[str, str] = {} - FORMAT_TRIE: t.Dict = {} - TIME_MAPPING: t.Dict[str, str] = {} - TIME_TRIE: t.Dict = {} def __init__( self, error_level: t.Optional[ErrorLevel] = None, error_message_context: int = 100, max_errors: int = 3, + dialect: DialectType = None, ): + from sqlglot.dialects import Dialect + self.error_level = error_level or ErrorLevel.IMMEDIATE self.error_message_context = error_message_context self.max_errors = max_errors - self._tokenizer = self.TOKENIZER_CLASS() + self.dialect = Dialect.get_or_raise(dialect) self.reset() def reset(self): @@ -1384,7 +1405,7 @@ class Parser(metaclass=_Parser): if self._match_texts(self.CLONE_KEYWORDS): copy = self._prev.text.lower() == "copy" clone = self._parse_table(schema=True) - when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper() + when = self._match_texts(("AT", "BEFORE")) and self._prev.text.upper() clone_kind = ( self._match(TokenType.L_PAREN) and self._match_texts(self.CLONE_KINDS) @@ -1524,6 +1545,22 @@ class Parser(metaclass=_Parser): return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE")) + def _parse_system_versioning_property(self) -> exp.WithSystemVersioningProperty: + self._match_pair(TokenType.EQ, TokenType.ON) + + prop = self.expression(exp.WithSystemVersioningProperty) + if self._match(TokenType.L_PAREN): + self._match_text_seq("HISTORY_TABLE", "=") + prop.set("this", self._parse_table_parts()) + + if self._match(TokenType.COMMA): + self._match_text_seq("DATA_CONSISTENCY_CHECK", "=") + prop.set("expression", self._advance_any() and self._prev.text.upper()) + + self._match_r_paren() + + return prop + def _parse_with_property( self, ) -> t.Optional[exp.Expression] | t.List[exp.Expression]: @@ -2140,7 +2177,11 @@ class Parser(metaclass=_Parser): return self._parse_expressions() def _parse_select( - self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True + self, + nested: bool = False, + table: bool = False, + parse_subquery_alias: bool = True, + parse_set_operation: bool = True, ) -> t.Optional[exp.Expression]: cte = self._parse_with() @@ -2216,7 +2257,11 @@ class Parser(metaclass=_Parser): t.cast(exp.From, self._parse_from(skip_from_token=True)) ) else: - this = self._parse_table() if table else self._parse_select(nested=True) + this = ( + self._parse_table() + if table + else self._parse_select(nested=True, parse_set_operation=False) + ) this = self._parse_set_operations(self._parse_query_modifiers(this)) self._match_r_paren() @@ -2235,7 +2280,9 @@ class Parser(metaclass=_Parser): else: this = None - return self._parse_set_operations(this) + if parse_set_operation: + return self._parse_set_operations(this) + return this def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]: if not skip_with_token and not self._match(TokenType.WITH): @@ -2563,9 +2610,8 @@ class Parser(metaclass=_Parser): if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False): return this - opclass = self._parse_var(any_token=True) - if opclass: - return self.expression(exp.Opclass, this=this, expression=opclass) + if not self._match_set(self.OPTYPE_FOLLOW_TOKENS, advance=False): + return self.expression(exp.Opclass, this=this, expression=self._parse_table_parts()) return this @@ -2630,7 +2676,7 @@ class Parser(metaclass=_Parser): while self._match_set(self.TABLE_INDEX_HINT_TOKENS): hint = exp.IndexTableHint(this=self._prev.text.upper()) - self._match_texts({"INDEX", "KEY"}) + self._match_texts(("INDEX", "KEY")) if self._match(TokenType.FOR): hint.set("target", self._advance_any() and self._prev.text.upper()) @@ -2650,7 +2696,7 @@ class Parser(metaclass=_Parser): def _parse_table_parts(self, schema: bool = False) -> exp.Table: catalog = None db = None - table = self._parse_table_part(schema=schema) + table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema) while self._match(TokenType.DOT): if catalog: @@ -2661,7 +2707,7 @@ class Parser(metaclass=_Parser): else: catalog = db db = table - table = self._parse_table_part(schema=schema) + table = self._parse_table_part(schema=schema) or "" if not table: self.raise_error(f"Expected table name but got {self._curr}") @@ -2709,7 +2755,7 @@ class Parser(metaclass=_Parser): if version: this.set("version", version) - if self.ALIAS_POST_TABLESAMPLE: + if self.dialect.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) @@ -2724,7 +2770,7 @@ class Parser(metaclass=_Parser): if not this.args.get("pivots"): this.set("pivots", self._parse_pivots()) - if not self.ALIAS_POST_TABLESAMPLE: + if not self.dialect.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() if table_sample: @@ -2776,13 +2822,13 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.UNNEST): return None - expressions = self._parse_wrapped_csv(self._parse_type) + expressions = self._parse_wrapped_csv(self._parse_equality) offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) alias = self._parse_table_alias() if with_alias else None if alias: - if self.UNNEST_COLUMN_ONLY: + if self.dialect.UNNEST_COLUMN_ONLY: if alias.args.get("columns"): self.raise_error("Unexpected extra column alias in unnest.") @@ -2845,7 +2891,7 @@ class Parser(metaclass=_Parser): num = ( self._parse_factor() if self._match(TokenType.NUMBER, advance=False) - else self._parse_primary() + else self._parse_primary() or self._parse_placeholder() ) if self._match_text_seq("BUCKET"): @@ -3108,10 +3154,10 @@ class Parser(metaclass=_Parser): if ( not explicitly_null_ordered and ( - (not desc and self.NULL_ORDERING == "nulls_are_small") - or (desc and self.NULL_ORDERING != "nulls_are_small") + (not desc and self.dialect.NULL_ORDERING == "nulls_are_small") + or (desc and self.dialect.NULL_ORDERING != "nulls_are_small") ) - and self.NULL_ORDERING != "nulls_are_last" + and self.dialect.NULL_ORDERING != "nulls_are_last" ): nulls_first = True @@ -3124,7 +3170,7 @@ class Parser(metaclass=_Parser): comments = self._prev_comments if top: limit_paren = self._match(TokenType.L_PAREN) - expression = self._parse_number() + expression = self._parse_term() if limit_paren else self._parse_number() if limit_paren: self._match_r_paren() @@ -3225,7 +3271,9 @@ class Parser(metaclass=_Parser): this=this, distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), by_name=self._match_text_seq("BY", "NAME"), - expression=self._parse_set_operations(self._parse_select(nested=True)), + expression=self._parse_set_operations( + self._parse_select(nested=True, parse_set_operation=False) + ), ) def _parse_expression(self) -> t.Optional[exp.Expression]: @@ -3287,7 +3335,8 @@ class Parser(metaclass=_Parser): unnest = self._parse_unnest(with_alias=False) if unnest: this = self.expression(exp.In, this=this, unnest=unnest) - elif self._match(TokenType.L_PAREN): + elif self._match_set((TokenType.L_PAREN, TokenType.L_BRACKET)): + matched_l_paren = self._prev.token_type == TokenType.L_PAREN expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias)) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): @@ -3295,13 +3344,16 @@ class Parser(metaclass=_Parser): else: this = self.expression(exp.In, this=this, expressions=expressions) - self._match_r_paren(this) + if matched_l_paren: + self._match_r_paren(this) + elif not self._match(TokenType.R_BRACKET, expression=this): + self.raise_error("Expecting ]") else: this = self.expression(exp.In, this=this, field=self._parse_field()) return this - def _parse_between(self, this: exp.Expression) -> exp.Between: + def _parse_between(self, this: t.Optional[exp.Expression]) -> exp.Between: low = self._parse_bitwise() self._match(TokenType.AND) high = self._parse_bitwise() @@ -3357,6 +3409,13 @@ class Parser(metaclass=_Parser): this=this, expression=self._parse_term(), ) + elif self.dialect.DPIPE_IS_STRING_CONCAT and self._match(TokenType.DPIPE): + this = self.expression( + exp.DPipe, + this=this, + expression=self._parse_term(), + safe=not self.dialect.STRICT_STRING_CONCAT, + ) elif self._match(TokenType.DQMARK): this = self.expression(exp.Coalesce, this=this, expressions=self._parse_term()) elif self._match_pair(TokenType.LT, TokenType.LT): @@ -3376,7 +3435,17 @@ class Parser(metaclass=_Parser): return self._parse_tokens(self._parse_factor, self.TERM) def _parse_factor(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_unary, self.FACTOR) + if self.EXPONENT: + factor = self._parse_tokens(self._parse_exponent, self.FACTOR) + else: + factor = self._parse_tokens(self._parse_unary, self.FACTOR) + if isinstance(factor, exp.Div): + factor.args["typed"] = self.dialect.TYPED_DIVISION + factor.args["safe"] = self.dialect.SAFE_DIVISION + return factor + + def _parse_exponent(self) -> t.Optional[exp.Expression]: + return self._parse_tokens(self._parse_unary, self.EXPONENT) def _parse_unary(self) -> t.Optional[exp.Expression]: if self._match_set(self.UNARY_PARSERS): @@ -3427,14 +3496,14 @@ class Parser(metaclass=_Parser): ) if identifier: - tokens = self._tokenizer.tokenize(identifier.name) + tokens = self.dialect.tokenize(identifier.name) if len(tokens) != 1: self.raise_error("Unexpected identifier", self._prev) if tokens[0].token_type in self.TYPE_TOKENS: self._prev = tokens[0] - elif self.SUPPORTS_USER_DEFINED_TYPES: + elif self.dialect.SUPPORTS_USER_DEFINED_TYPES: type_name = identifier.name while self._match(TokenType.DOT): @@ -3713,6 +3782,7 @@ class Parser(metaclass=_Parser): if not self._curr: return None + comments = self._curr.comments token_type = self._curr.token_type this = self._curr.text upper = this.upper() @@ -3754,13 +3824,22 @@ class Parser(metaclass=_Parser): args = self._parse_csv(lambda: self._parse_lambda(alias=alias)) if function and not anonymous: - func = self.validate_expression(function(args), args) - if not self.NORMALIZE_FUNCTIONS: + if "dialect" in function.__code__.co_varnames: + func = function(args, dialect=self.dialect) + else: + func = function(args) + + func = self.validate_expression(func, args) + if not self.dialect.NORMALIZE_FUNCTIONS: func.meta["name"] = this + this = func else: this = self.expression(exp.Anonymous, this=this, expressions=args) + if isinstance(this, exp.Expression): + this.add_comments(comments) + self._match_r_paren(this) return self._parse_window(this) @@ -3875,6 +3954,11 @@ class Parser(metaclass=_Parser): not_null=self._match_pair(TokenType.NOT, TokenType.NULL), ) ) + elif kind and self._match_pair(TokenType.ALIAS, TokenType.L_PAREN, advance=False): + self._match(TokenType.ALIAS) + constraints.append( + self.expression(exp.TransformColumnConstraint, this=self._parse_field()) + ) while True: constraint = self._parse_column_constraint() @@ -3917,7 +4001,11 @@ class Parser(metaclass=_Parser): def _parse_generated_as_identity( self, - ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint: + ) -> ( + exp.GeneratedAsIdentityColumnConstraint + | exp.ComputedColumnConstraint + | exp.GeneratedAsRowColumnConstraint + ): if self._match_text_seq("BY", "DEFAULT"): on_null = self._match_pair(TokenType.ON, TokenType.NULL) this = self.expression( @@ -3928,6 +4016,14 @@ class Parser(metaclass=_Parser): this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) self._match(TokenType.ALIAS) + + if self._match_text_seq("ROW"): + start = self._match_text_seq("START") + if not start: + self._match(TokenType.END) + hidden = self._match_text_seq("HIDDEN") + return self.expression(exp.GeneratedAsRowColumnConstraint, start=start, hidden=hidden) + identity = self._match_text_seq("IDENTITY") if self._match(TokenType.L_PAREN): @@ -4100,6 +4196,16 @@ class Parser(metaclass=_Parser): def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: return self._parse_field() + def _parse_period_for_system_time(self) -> exp.PeriodForSystemTimeConstraint: + self._match(TokenType.TIMESTAMP_SNAPSHOT) + + id_vars = self._parse_wrapped_id_vars() + return self.expression( + exp.PeriodForSystemTimeConstraint, + this=seq_get(id_vars, 0), + expression=seq_get(id_vars, 1), + ) + def _parse_primary_key( self, wrapped_optional: bool = False, in_props: bool = False ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey: @@ -4145,7 +4251,7 @@ class Parser(metaclass=_Parser): elif not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) else: - expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET) + expressions = apply_index_offset(this, expressions, -self.dialect.INDEX_OFFSET) this = self.expression(exp.Bracket, this=this, expressions=expressions) self._add_comments(this) @@ -4259,8 +4365,8 @@ class Parser(metaclass=_Parser): format=exp.Literal.string( format_time( fmt_string.this if fmt_string else "", - self.FORMAT_MAPPING or self.TIME_MAPPING, - self.FORMAT_TRIE or self.TIME_TRIE, + self.dialect.FORMAT_MAPPING or self.dialect.TIME_MAPPING, + self.dialect.FORMAT_TRIE or self.dialect.TIME_TRIE, ) ), ) @@ -4280,30 +4386,6 @@ class Parser(metaclass=_Parser): exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe ) - def _parse_concat(self) -> t.Optional[exp.Expression]: - args = self._parse_csv(self._parse_conjunction) - if self.CONCAT_NULL_OUTPUTS_STRING: - args = self._ensure_string_if_null(args) - - # Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when - # we find such a call we replace it with its argument. - if len(args) == 1: - return args[0] - - return self.expression( - exp.Concat if self.STRICT_STRING_CONCAT else exp.SafeConcat, expressions=args - ) - - def _parse_concat_ws(self) -> t.Optional[exp.Expression]: - args = self._parse_csv(self._parse_conjunction) - if len(args) < 2: - return self.expression(exp.ConcatWs, expressions=args) - delim, *values = args - if self.CONCAT_NULL_OUTPUTS_STRING: - values = self._ensure_string_if_null(values) - - return self.expression(exp.ConcatWs, expressions=[delim] + values) - def _parse_string_agg(self) -> exp.Expression: if self._match(TokenType.DISTINCT): args: t.List[t.Optional[exp.Expression]] = [ @@ -4495,19 +4577,6 @@ class Parser(metaclass=_Parser): empty_handling=empty_handling, ) - def _parse_logarithm(self) -> exp.Func: - # Default argument order is base, expression - args = self._parse_csv(self._parse_range) - - if len(args) > 1: - if not self.LOG_BASE_FIRST: - args.reverse() - return exp.Log.from_arg_list(args) - - return self.expression( - exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0) - ) - def _parse_match_against(self) -> exp.MatchAgainst: expressions = self._parse_csv(self._parse_column) @@ -4755,6 +4824,7 @@ class Parser(metaclass=_Parser): self, this: t.Optional[exp.Expression], explicit: bool = False ) -> t.Optional[exp.Expression]: any_token = self._match(TokenType.ALIAS) + comments = self._prev_comments if explicit and not any_token: return this @@ -4762,6 +4832,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.L_PAREN): aliases = self.expression( exp.Aliases, + comments=comments, this=this, expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), ) @@ -4771,7 +4842,7 @@ class Parser(metaclass=_Parser): alias = self._parse_id_var(any_token) if alias: - return self.expression(exp.Alias, this=this, alias=alias) + return self.expression(exp.Alias, comments=comments, this=this, alias=alias) return this @@ -4792,8 +4863,8 @@ class Parser(metaclass=_Parser): return None def _parse_string(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.STRING): - return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) + if self._match_set((TokenType.STRING, TokenType.RAW_STRING)): + return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) return self._parse_placeholder() def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]: @@ -4821,7 +4892,7 @@ class Parser(metaclass=_Parser): return self._parse_placeholder() def _advance_any(self) -> t.Optional[Token]: - if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: + if self._curr and self._curr.token_type not in self.RESERVED_TOKENS: self._advance() return self._prev return None @@ -4951,7 +5022,7 @@ class Parser(metaclass=_Parser): if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text - self._match_texts({"TRANSACTION", "WORK"}) + self._match_texts(("TRANSACTION", "WORK")) modes = [] while True: @@ -4971,7 +5042,7 @@ class Parser(metaclass=_Parser): savepoint = None is_rollback = self._prev.token_type == TokenType.ROLLBACK - self._match_texts({"TRANSACTION", "WORK"}) + self._match_texts(("TRANSACTION", "WORK")) if self._match_text_seq("TO"): self._match_text_seq("SAVEPOINT") @@ -4986,6 +5057,10 @@ class Parser(metaclass=_Parser): return self.expression(exp.Commit, chain=chain) + def _parse_refresh(self) -> exp.Refresh: + self._match(TokenType.TABLE) + return self.expression(exp.Refresh, this=self._parse_string() or self._parse_table()) + def _parse_add_column(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("ADD"): return None @@ -5050,10 +5125,9 @@ class Parser(metaclass=_Parser): return self._parse_csv(self._parse_add_constraint) self._retreat(index) - if not self.ALTER_TABLE_ADD_COLUMN_KEYWORD and self._match_text_seq("ADD"): - return self._parse_csv(self._parse_field_def) - - return self._parse_csv(self._parse_add_column) + if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"): + return self._parse_wrapped_csv(self._parse_field_def, optional=True) + return self._parse_wrapped_csv(self._parse_add_column, optional=True) def _parse_alter_table_alter(self) -> exp.AlterColumn: self._match(TokenType.COLUMN) @@ -5198,7 +5272,7 @@ class Parser(metaclass=_Parser): ) -> t.Optional[exp.Expression]: index = self._index - if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"): + if kind in ("GLOBAL", "SESSION") and self._match_text_seq("TRANSACTION"): return self._parse_set_transaction(global_=kind == "GLOBAL") left = self._parse_primary() or self._parse_id_var() @@ -5292,7 +5366,9 @@ class Parser(metaclass=_Parser): self._match_r_paren() return self.expression(exp.DictRange, this=this, min=min, max=max) - def _parse_comprehension(self, this: exp.Expression) -> t.Optional[exp.Comprehension]: + def _parse_comprehension( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Comprehension]: index = self._index expression = self._parse_column() if not self._match(TokenType.IN): @@ -5441,10 +5517,3 @@ class Parser(metaclass=_Parser): else: column.replace(dot_or_id) return node - - def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]: - return [ - exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string("")) - for value in values - if value - ] diff --git a/sqlglot/schema.py b/sqlglot/schema.py index acf9bc4..54c08dd 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -15,8 +15,6 @@ if t.TYPE_CHECKING: ColumnMapping = t.Union[t.Dict, str, StructType, t.List] -TABLE_ARGS = ("this", "db", "catalog") - class Schema(abc.ABC): """Abstract base class for database schemas""" @@ -147,7 +145,7 @@ class AbstractMappingSchema: if not depth: # None self._supported_table_args = tuple() elif 1 <= depth <= 3: - self._supported_table_args = TABLE_ARGS[:depth] + self._supported_table_args = exp.TABLE_PARTS[:depth] else: raise SchemaError(f"Invalid mapping shape. Depth: {depth}") @@ -156,7 +154,7 @@ class AbstractMappingSchema: def table_parts(self, table: exp.Table) -> t.List[str]: if isinstance(table.this, exp.ReadCSV): return [table.this.name] - return [table.text(part) for part in TABLE_ARGS if table.text(part)] + return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] def find( self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True @@ -365,13 +363,11 @@ class MappingSchema(AbstractMappingSchema, Schema): f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}." ) - normalized_keys = [ - self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys - ] + normalized_keys = [self._normalize_name(key, is_table=True) for key in keys] for column_name, column_type in columns.items(): nested_set( normalized_mapping, - normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)], + normalized_keys + [self._normalize_name(column_name)], column_type, ) @@ -383,21 +379,19 @@ class MappingSchema(AbstractMappingSchema, Schema): dialect: DialectType = None, normalize: t.Optional[bool] = None, ) -> exp.Table: - normalized_table = exp.maybe_parse( - table, into=exp.Table, dialect=dialect or self.dialect, copy=True - ) + dialect = dialect or self.dialect + normalize = self.normalize if normalize is None else normalize - for arg in TABLE_ARGS: - value = normalized_table.args.get(arg) - if isinstance(value, (str, exp.Identifier)): - normalized_table.set( - arg, - exp.to_identifier( - self._normalize_name( - value, dialect=dialect, is_table=True, normalize=normalize - ) - ), - ) + normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize) + + if normalize: + for arg in exp.TABLE_PARTS: + value = normalized_table.args.get(arg) + if isinstance(value, exp.Identifier): + normalized_table.set( + arg, + normalize_name(value, dialect=dialect, is_table=True, normalize=normalize), + ) return normalized_table @@ -413,7 +407,7 @@ class MappingSchema(AbstractMappingSchema, Schema): dialect=dialect or self.dialect, is_table=is_table, normalize=self.normalize if normalize is None else normalize, - ) + ).name def depth(self) -> int: if not self.empty and not self._depth: @@ -451,16 +445,16 @@ def normalize_name( dialect: DialectType = None, is_table: bool = False, normalize: t.Optional[bool] = True, -) -> str: +) -> exp.Identifier: if isinstance(identifier, str): identifier = exp.parse_identifier(identifier, dialect=dialect) if not normalize: - return identifier.name + return identifier - # This can be useful for normalize_identifier + # this is used for normalize_identifier, bigquery has special rules pertaining tables identifier.meta["is_table"] = is_table - return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name + return Dialect.get_or_raise(dialect).normalize_identifier(identifier) def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: diff --git a/sqlglot/time.py b/sqlglot/time.py index c286ec1..50ec2ec 100644 --- a/sqlglot/time.py +++ b/sqlglot/time.py @@ -42,6 +42,10 @@ def format_time( end -= 1 chars = sym sym = None + else: + chars = chars[0] + end = start + 1 + start += len(chars) chunks.append(chars) current = trie diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 9784c63..e4c3204 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -7,6 +7,9 @@ from sqlglot.errors import TokenError from sqlglot.helper import AutoName from sqlglot.trie import TrieResult, in_trie, new_trie +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + class TokenType(AutoName): L_PAREN = auto() @@ -34,6 +37,7 @@ class TokenType(AutoName): EQ = auto() NEQ = auto() NULLSAFE_EQ = auto() + COLON_EQ = auto() AND = auto() OR = auto() AMP = auto() @@ -56,6 +60,7 @@ class TokenType(AutoName): SESSION_PARAMETER = auto() DAMP = auto() XOR = auto() + DSTAR = auto() BLOCK_START = auto() BLOCK_END = auto() @@ -274,6 +279,7 @@ class TokenType(AutoName): OBJECT_IDENTIFIER = auto() OFFSET = auto() ON = auto() + OPERATOR = auto() ORDER_BY = auto() ORDERED = auto() ORDINALITY = auto() @@ -295,6 +301,7 @@ class TokenType(AutoName): QUOTE = auto() RANGE = auto() RECURSIVE = auto() + REFRESH = auto() REPLACE = auto() RETURNING = auto() REFERENCES = auto() @@ -371,7 +378,7 @@ class Token: col: int = 1, start: int = 0, end: int = 0, - comments: t.List[str] = [], + comments: t.Optional[t.List[str]] = None, ) -> None: """Token initializer. @@ -390,7 +397,7 @@ class Token: self.col = col self.start = start self.end = end - self.comments = comments + self.comments = [] if comments is None else comments def __repr__(self) -> str: attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) @@ -497,11 +504,8 @@ class Tokenizer(metaclass=_Tokenizer): QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] STRING_ESCAPES = ["'"] VAR_SINGLE_TOKENS: t.Set[str] = set() - ESCAPE_SEQUENCES: t.Dict[str, str] = {} # Autofilled - IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False - _COMMENTS: t.Dict[str, str] = {} _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {} _IDENTIFIERS: t.Dict[str, str] = {} @@ -523,6 +527,7 @@ class Tokenizer(metaclass=_Tokenizer): "<=": TokenType.LTE, "<>": TokenType.NEQ, "!=": TokenType.NEQ, + ":=": TokenType.COLON_EQ, "<=>": TokenType.NULLSAFE_EQ, "->": TokenType.ARROW, "->>": TokenType.DARROW, @@ -689,17 +694,22 @@ class Tokenizer(metaclass=_Tokenizer): "BOOLEAN": TokenType.BOOLEAN, "BYTE": TokenType.TINYINT, "MEDIUMINT": TokenType.MEDIUMINT, + "INT1": TokenType.TINYINT, "TINYINT": TokenType.TINYINT, + "INT16": TokenType.SMALLINT, "SHORT": TokenType.SMALLINT, "SMALLINT": TokenType.SMALLINT, "INT128": TokenType.INT128, + "HUGEINT": TokenType.INT128, "INT2": TokenType.SMALLINT, "INTEGER": TokenType.INT, "INT": TokenType.INT, "INT4": TokenType.INT, + "INT32": TokenType.INT, + "INT64": TokenType.BIGINT, "LONG": TokenType.BIGINT, "BIGINT": TokenType.BIGINT, - "INT8": TokenType.BIGINT, + "INT8": TokenType.TINYINT, "DEC": TokenType.DECIMAL, "DECIMAL": TokenType.DECIMAL, "BIGDECIMAL": TokenType.BIGDECIMAL, @@ -781,7 +791,6 @@ class Tokenizer(metaclass=_Tokenizer): "\t": TokenType.SPACE, "\n": TokenType.BREAK, "\r": TokenType.BREAK, - "\r\n": TokenType.BREAK, } COMMANDS = { @@ -803,6 +812,7 @@ class Tokenizer(metaclass=_Tokenizer): "sql", "size", "tokens", + "dialect", "_start", "_current", "_line", @@ -814,7 +824,10 @@ class Tokenizer(metaclass=_Tokenizer): "_prev_token_line", ) - def __init__(self) -> None: + def __init__(self, dialect: DialectType = None) -> None: + from sqlglot.dialects import Dialect + + self.dialect = Dialect.get_or_raise(dialect) self.reset() def reset(self) -> None: @@ -850,13 +863,26 @@ class Tokenizer(metaclass=_Tokenizer): def _scan(self, until: t.Optional[t.Callable] = None) -> None: while self.size and not self._end: - self._start = self._current - self._advance() + current = self._current + + # skip spaces inline rather than iteratively call advance() + # for performance reasons + while current < self.size: + char = self.sql[current] + + if char.isspace() and (char == " " or char == "\t"): + current += 1 + else: + break + + n = current - self._current + self._start = current + self._advance(n if n > 1 else 1) if self._char is None: break - if self._char not in self.WHITE_SPACE: + if not self._char.isspace(): if self._char.isdigit(): self._scan_number() elif self._char in self._IDENTIFIERS: @@ -881,6 +907,10 @@ class Tokenizer(metaclass=_Tokenizer): def _advance(self, i: int = 1, alnum: bool = False) -> None: if self.WHITE_SPACE.get(self._char) is TokenType.BREAK: + # Ensures we don't count an extra line if we get a \r\n line break sequence + if self._char == "\r" and self._peek == "\n": + i = 2 + self._col = 1 self._line += 1 else: @@ -982,7 +1012,7 @@ class Tokenizer(metaclass=_Tokenizer): if end < self.size: char = self.sql[end] single_token = single_token or char in self.SINGLE_TOKENS - is_space = char in self.WHITE_SPACE + is_space = char.isspace() if not is_space or not prev_space: if is_space: @@ -994,7 +1024,7 @@ class Tokenizer(metaclass=_Tokenizer): skip = True else: char = "" - chars = " " + break if word: if self._scan_string(word): @@ -1086,7 +1116,7 @@ class Tokenizer(metaclass=_Tokenizer): self._add(TokenType.NUMBER, number_text) self._add(TokenType.DCOLON, "::") return self._add(token_type, literal) - elif self.IDENTIFIERS_CAN_START_WITH_DIGIT: + elif self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT: return self._add(TokenType.VAR) self._advance(-len(literal)) @@ -1208,8 +1238,12 @@ class Tokenizer(metaclass=_Tokenizer): if self._end: raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}") - if self.ESCAPE_SEQUENCES and self._peek and self._char in self.STRING_ESCAPES: - escaped_sequence = self.ESCAPE_SEQUENCES.get(self._char + self._peek) + if ( + self.dialect.ESCAPE_SEQUENCES + and self._peek + and self._char in self.STRING_ESCAPES + ): + escaped_sequence = self.dialect.ESCAPE_SEQUENCES.get(self._char + self._peek) if escaped_sequence: self._advance(2) text += escaped_sequence diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 445fda6..03acc2b 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -141,7 +141,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr def unnest_to_explode(expression: exp.Expression) -> exp.Expression: - """Convert cross join unnest into lateral view explode (used in presto -> hive).""" + """Convert cross join unnest into lateral view explode.""" if isinstance(expression, exp.Select): for join in expression.args.get("joins") or []: unnest = join.this @@ -166,7 +166,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: - """Convert explode/posexplode into unnest (used in hive -> presto).""" + """Convert explode/posexplode into unnest.""" def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: if isinstance(expression, exp.Select): @@ -199,11 +199,11 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp explode_alias = "" if isinstance(select, exp.Alias): - explode_alias = select.alias + explode_alias = select.args["alias"] alias = select elif isinstance(select, exp.Aliases): - pos_alias = select.aliases[0].name - explode_alias = select.aliases[1].name + pos_alias = select.aliases[0] + explode_alias = select.aliases[1] alias = select.replace(exp.alias_(select.this, "", copy=False)) else: alias = select.replace(exp.alias_(select, "")) @@ -230,9 +230,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp alias.set("alias", exp.to_identifier(explode_alias)) + series_table_alias = series.args["alias"].this column = exp.If( - this=exp.column(series_alias).eq(exp.column(pos_alias)), - true=exp.column(explode_alias), + this=exp.column(series_alias, table=series_table_alias).eq( + exp.column(pos_alias, table=unnest_source_alias) + ), + true=exp.column(explode_alias, table=unnest_source_alias), ) explode.replace(column) @@ -242,8 +245,10 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp expressions.insert( expressions.index(alias) + 1, exp.If( - this=exp.column(series_alias).eq(exp.column(pos_alias)), - true=exp.column(pos_alias), + this=exp.column(series_alias, table=series_table_alias).eq( + exp.column(pos_alias, table=unnest_source_alias) + ), + true=exp.column(pos_alias, table=unnest_source_alias), ).as_(pos_alias), ) expression.set("expressions", expressions) @@ -276,10 +281,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp size = size - 1 expression.where( - exp.column(series_alias) - .eq(exp.column(pos_alias)) + exp.column(series_alias, table=series_table_alias) + .eq(exp.column(pos_alias, table=unnest_source_alias)) .or_( - (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size)) + (exp.column(series_alias, table=series_table_alias) > size).and_( + exp.column(pos_alias, table=unnest_source_alias).eq(size) + ) ), copy=False, ) @@ -386,14 +393,16 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: full_outer_joins = [ (index, join) for index, join in enumerate(expression.args.get("joins") or []) - if join.side == "FULL" and join.kind == "OUTER" + if join.side == "FULL" ] if len(full_outer_joins) == 1: expression_copy = expression.copy() + expression.set("limit", None) index, full_outer_join = full_outer_joins[0] full_outer_join.set("side", "left") expression_copy.args["joins"][index].set("side", "right") + expression_copy.args.pop("with", None) # remove CTEs from RIGHT side return exp.union(expression, expression_copy, copy=False) @@ -430,6 +439,33 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: return expression +def ensure_bools(expression: exp.Expression) -> exp.Expression: + """Converts numeric values used in conditions into explicit boolean expressions.""" + from sqlglot.optimizer.canonicalize import ensure_bools + + def _ensure_bool(node: exp.Expression) -> None: + if ( + node.is_number + or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) + or (isinstance(node, exp.Column) and not node.type) + ): + node.replace(node.neq(0)) + + for node, *_ in expression.walk(): + ensure_bools(node, _ensure_bool) + + return expression + + +def unqualify_columns(expression: exp.Expression) -> exp.Expression: + for column in expression.find_all(exp.Column): + # We only wanna pop off the table, db, catalog args + for part in column.parts[:-1]: + part.pop() + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: |