diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 28 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 79 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 6 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 7 | ||||
-rw-r--r-- | sqlglot/dialects/spark2.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 101 |
12 files changed, 230 insertions, 23 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index fd9965c..df9065f 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -211,6 +211,10 @@ class BigQuery(Dialect): "TZH": "%z", } + # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement + # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table + PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"} + @classmethod def normalize_identifier(cls, expression: E) -> E: # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 8f60df2..ce1a486 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -380,7 +380,7 @@ class ClickHouse(Dialect): ] def parameterizedagg_sql(self, expression: exp.Anonymous) -> str: - params = self.expressions(expression, "params", flat=True) + params = self.expressions(expression, key="params", flat=True) return self.func(expression.name, *expression.expressions) + f"({params})" def placeholder_sql(self, expression: exp.Placeholder) -> str: diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 8c84639..05e81ce 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -5,6 +5,7 @@ from enum import Enum from sqlglot import exp from sqlglot._typing import E +from sqlglot.errors import ParseError from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser @@ -168,6 +169,10 @@ class Dialect(metaclass=_Dialect): # special syntax cast(x as date format 'yyyy') defaults to time_mapping FORMAT_MAPPING: t.Dict[str, str] = {} + # Columns that are auto-generated by the engine corresponding to this dialect + # Such columns may be excluded from SELECT * queries, for example + PSEUDOCOLUMNS: t.Set[str] = set() + # Autofilled tokenizer_class = Tokenizer parser_class = Parser @@ -497,6 +502,10 @@ def parse_date_delta_with_interval( return None interval = args[1] + + if not isinstance(interval, exp.Interval): + raise ParseError(f"INTERVAL expression expected but got '{interval}'") + expression = interval.this if expression and expression.is_string: expression = exp.Literal.number(expression.this) @@ -555,11 +564,11 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: - return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" + return self.sql(exp.cast(expression.this, "timestamp")) def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: - return f"CAST({self.sql(expression, 'this')} AS DATE)" + return self.sql(exp.cast(expression.this, "date")) def min_or_least(self: Generator, expression: exp.Min) -> str: @@ -608,8 +617,9 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: _dialect = Dialect.get_or_raise(dialect) time_format = self.format_time(expression) if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): - return f"CAST({str_to_time_sql(self, expression)} AS DATE)" - return f"CAST({self.sql(expression, 'this')} AS DATE)" + return self.sql(exp.cast(str_to_time_sql(self, expression), "date")) + + return self.sql(exp.cast(self.sql(expression, "this"), "date")) return _ts_or_ds_to_date_sql @@ -664,5 +674,15 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp return names +def simplify_literal(expression: E, copy: bool = True) -> E: + if not isinstance(expression.expression, exp.Literal): + from sqlglot.optimizer.simplify import simplify + + expression = exp.maybe_copy(expression, copy) + simplify(expression.expression) + + return expression + + def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index e131434..4e84085 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -359,14 +359,16 @@ class Hive(Dialect): TABLE_HINTS = False QUERY_HINTS = False INDEX_ON = "ON TABLE" + EXTRACT_ALLOWS_QUOTES = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.BIT: "BOOLEAN", exp.DataType.Type.DATETIME: "TIMESTAMP", - exp.DataType.Type.VARBINARY: "BINARY", + exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TIME: "TIMESTAMP", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - exp.DataType.Type.BIT: "BOOLEAN", + exp.DataType.Type.VARBINARY: "BINARY", } TRANSFORMS = { @@ -396,6 +398,7 @@ class Hive(Dialect): exp.FromBase64: rename_func("UNBASE64"), exp.If: if_sql, exp.ILike: no_ilike_sql, + exp.IsNan: rename_func("ISNAN"), exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.JSONFormat: _json_format_sql, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 5d65f77..a54f076 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, parse_date_delta_with_interval, rename_func, + simplify_literal, strposition_to_locate_sql, ) from sqlglot.helper import seq_get @@ -303,6 +304,22 @@ class MySQL(Dialect): "NAMES": lambda self: self._parse_set_item_names(), } + CONSTRAINT_PARSERS = { + **parser.Parser.CONSTRAINT_PARSERS, + "FULLTEXT": lambda self: self._parse_index_constraint(kind="FULLTEXT"), + "INDEX": lambda self: self._parse_index_constraint(), + "KEY": lambda self: self._parse_index_constraint(), + "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"), + } + + SCHEMA_UNNAMED_CONSTRAINTS = { + *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS, + "FULLTEXT", + "INDEX", + "KEY", + "SPATIAL", + } + PROFILE_TYPES = { "ALL", "BLOCK IO", @@ -327,6 +344,57 @@ class MySQL(Dialect): LOG_DEFAULTS_TO_LN = True + def _parse_index_constraint( + self, kind: t.Optional[str] = None + ) -> exp.IndexColumnConstraint: + if kind: + self._match_texts({"INDEX", "KEY"}) + + this = self._parse_id_var(any_token=False) + type_ = self._match(TokenType.USING) and self._advance_any() and self._prev.text + schema = self._parse_schema() + + options = [] + while True: + if self._match_text_seq("KEY_BLOCK_SIZE"): + self._match(TokenType.EQ) + opt = exp.IndexConstraintOption(key_block_size=self._parse_number()) + elif self._match(TokenType.USING): + opt = exp.IndexConstraintOption(using=self._advance_any() and self._prev.text) + elif self._match_text_seq("WITH", "PARSER"): + opt = exp.IndexConstraintOption(parser=self._parse_var(any_token=True)) + elif self._match(TokenType.COMMENT): + opt = exp.IndexConstraintOption(comment=self._parse_string()) + elif self._match_text_seq("VISIBLE"): + opt = exp.IndexConstraintOption(visible=True) + elif self._match_text_seq("INVISIBLE"): + opt = exp.IndexConstraintOption(visible=False) + elif self._match_text_seq("ENGINE_ATTRIBUTE"): + self._match(TokenType.EQ) + opt = exp.IndexConstraintOption(engine_attr=self._parse_string()) + elif self._match_text_seq("ENGINE_ATTRIBUTE"): + self._match(TokenType.EQ) + opt = exp.IndexConstraintOption(engine_attr=self._parse_string()) + elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"): + self._match(TokenType.EQ) + opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string()) + else: + opt = None + + if not opt: + break + + options.append(opt) + + return self.expression( + exp.IndexColumnConstraint, + this=this, + schema=schema, + kind=kind, + type=type_, + options=options, + ) + def _parse_show_mysql( self, this: str, @@ -454,6 +522,7 @@ class MySQL(Dialect): exp.StrToTime: _str_to_date_sql, exp.TableSample: no_tablesample_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime")), exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, @@ -485,6 +554,16 @@ class MySQL(Dialect): exp.DataType.Type.VARCHAR: "CHAR", } + def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: + # MySQL requires simple literal values for its LIMIT clause. + expression = simplify_literal(expression) + return super().limit_sql(expression, top=top) + + def offset_sql(self, expression: exp.Offset) -> str: + # MySQL requires simple literal values for its OFFSET clause. + expression = simplify_literal(expression) + return super().offset_sql(expression) + def xor_sql(self, expression: exp.Xor) -> str: if expression.expressions: return self.expressions(expression, sep=" XOR ") diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 69da133..1f63e9f 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -30,6 +30,9 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable: class Oracle(Dialect): ALIAS_POST_TABLESAMPLE = True + # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm + RESOLVES_IDENTIFIERS_AS_UPPERCASE = True + # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes TIME_MAPPING = { diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index d11cbd7..ef100b1 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, no_trycast_sql, rename_func, + simplify_literal, str_position_sql, timestamptrunc_sql, timestrtotime_sql, @@ -39,16 +40,13 @@ DATE_DIFF_FACTOR = { def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: - from sqlglot.optimizer.simplify import simplify - this = self.sql(expression, "this") unit = expression.args.get("unit") - expression = simplify(expression.args["expression"]) + expression = simplify_literal(expression.copy(), copy=False).expression if not isinstance(expression, exp.Literal): self.unsupported("Cannot add non literal") - expression = expression.copy() expression.args["is_string"] = True return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}" diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 265c6e5..14ec3dd 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -192,6 +192,8 @@ class Presto(Dialect): "START": TokenType.BEGIN, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "ROW": TokenType.STRUCT, + "IPADDRESS": TokenType.IPADDRESS, + "IPPREFIX": TokenType.IPPREFIX, } class Parser(parser.Parser): diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 73f4370..b9aaa66 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import exp +from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.spark2 import Spark2 from sqlglot.helper import seq_get @@ -47,7 +48,11 @@ class Spark(Spark2): exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)", exp.DataType.Type.UNIQUEIDENTIFIER: "STRING", } - TRANSFORMS = Spark2.Generator.TRANSFORMS.copy() + + TRANSFORMS = { + **Spark2.Generator.TRANSFORMS, + exp.StartsWith: rename_func("STARTSWITH"), + } TRANSFORMS.pop(exp.DateDiff) TRANSFORMS.pop(exp.Group) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index dcaa524..ceb48f8 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -19,9 +19,13 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str: kind = e.args["kind"] properties = e.args.get("properties") - if kind.upper() == "TABLE" and any( - isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) + if ( + kind.upper() == "TABLE" + and e.expression + and any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ) ): return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" return create_with_partitions_sql(self, e) diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 4e8ffb4..3fac4f5 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -33,8 +33,10 @@ class Teradata(Dialect): **tokens.Tokenizer.KEYWORDS, "^=": TokenType.NEQ, "BYTEINT": TokenType.SMALLINT, + "COLLECT": TokenType.COMMAND, "GE": TokenType.GTE, "GT": TokenType.GT, + "HELP": TokenType.COMMAND, "INS": TokenType.INSERT, "LE": TokenType.LTE, "LT": TokenType.LT, diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 01d5001..0eb0906 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import re import typing as t @@ -10,6 +11,7 @@ from sqlglot.dialects.dialect import ( min_or_least, parse_date_delta, rename_func, + timestrtotime_sql, ) from sqlglot.expressions import DataType from sqlglot.helper import seq_get @@ -52,6 +54,8 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{ # N = Numeric, C=Currency TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} +DEFAULT_START_DATE = datetime.date(1900, 1, 1) + def _format_time_lambda( exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None @@ -166,6 +170,34 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s return f"STRING_AGG({self.format_args(this, separator)}){order}" +def _parse_date_delta( + exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None +) -> t.Callable[[t.List], E]: + def inner_func(args: t.List) -> E: + unit = seq_get(args, 0) + if unit and unit_mapping: + unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) + + start_date = seq_get(args, 1) + if start_date and start_date.is_number: + # Numeric types are valid DATETIME values + if start_date.is_int: + adds = DEFAULT_START_DATE + datetime.timedelta(days=int(start_date.this)) + start_date = exp.Literal.string(adds.strftime("%F")) + else: + # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs. + # This is not a problem when generating T-SQL code, it is when transpiling to other dialects. + return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit) + + return exp_class( + this=exp.TimeStrToTime(this=seq_get(args, 2)), + expression=exp.TimeStrToTime(this=start_date), + unit=unit, + ) + + return inner_func + + class TSQL(Dialect): RESOLVES_IDENTIFIERS_AS_UPPERCASE = None NULL_ORDERING = "nulls_are_small" @@ -298,7 +330,6 @@ class TSQL(Dialect): "SMALLDATETIME": TokenType.DATETIME, "SMALLMONEY": TokenType.SMALLMONEY, "SQL_VARIANT": TokenType.VARIANT, - "TIME": TokenType.TIMESTAMP, "TOP": TokenType.TOP, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "VARCHAR(MAX)": TokenType.TEXT, @@ -307,10 +338,6 @@ class TSQL(Dialect): "SYSTEM_USER": TokenType.CURRENT_USER, } - # TSQL allows @, # to appear as a variable/identifier prefix - SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy() - SINGLE_TOKENS.pop("#") - class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -320,7 +347,7 @@ class TSQL(Dialect): position=seq_get(args, 2), ), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), - "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), + "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": _format_time_lambda(exp.TimeToStr), "EOMONTH": _parse_eomonth, @@ -518,6 +545,36 @@ class TSQL(Dialect): expressions = self._parse_csv(self._parse_function_parameter) return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + def _parse_id_var( + self, + any_token: bool = True, + tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> t.Optional[exp.Expression]: + is_temporary = self._match(TokenType.HASH) + is_global = is_temporary and self._match(TokenType.HASH) + + this = super()._parse_id_var(any_token=any_token, tokens=tokens) + if this: + if is_global: + this.set("global", True) + elif is_temporary: + this.set("temporary", True) + + return this + + def _parse_create(self) -> exp.Create | exp.Command: + create = super()._parse_create() + + if isinstance(create, exp.Create): + table = create.this.this if isinstance(create.this, exp.Schema) else create.this + if isinstance(table, exp.Table) and table.this.args.get("temporary"): + if not create.args.get("properties"): + create.set("properties", exp.Properties(expressions=[])) + + create.args["properties"].append("expressions", exp.TemporaryProperty()) + + return create + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True LIMIT_IS_TOP = True @@ -526,9 +583,11 @@ class TSQL(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.TIMESTAMP: "DATETIME2", + exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET", exp.DataType.Type.VARIANT: "SQL_VARIANT", } @@ -552,6 +611,8 @@ class TSQL(Dialect): exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this, ), + exp.TemporaryProperty: lambda self, e: "", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, } @@ -564,6 +625,22 @@ class TSQL(Dialect): LIMIT_FETCH = "FETCH" + def createable_sql( + self, + expression: exp.Create, + locations: dict[exp.Properties.Location, list[exp.Property]], + ) -> str: + sql = self.sql(expression, "this") + properties = expression.args.get("properties") + + if sql[:1] != "#" and any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ): + sql = f"#{sql}" + + return sql + def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" @@ -616,3 +693,13 @@ class TSQL(Dialect): this = self.sql(expression, "this") this = f" {this}" if this else "" return f"ROLLBACK TRANSACTION{this}" + + def identifier_sql(self, expression: exp.Identifier) -> str: + identifier = super().identifier_sql(expression) + + if expression.args.get("global"): + identifier = f"##{identifier}" + elif expression.args.get("temporary"): + identifier = f"#{identifier}" + + return identifier |