diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/__init__.py | 61 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 8 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 116 |
6 files changed, 183 insertions, 16 deletions
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 2084681..34cf613 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -1,3 +1,64 @@ +""" +## Dialects + +One of the core abstractions in SQLGlot is the concept of a "dialect". The `Dialect` class essentially implements a +"SQLGlot dialect", which aims to be as generic and ANSI-compliant as possible. It relies on the base `Tokenizer`, +`Parser` and `Generator` classes to achieve this goal, so these need to be very lenient when it comes to consuming +SQL code. + +However, there are cases where the syntax of different SQL dialects varies wildly, even for common tasks. One such +example is the date/time functions, which can be hard to deal with. For this reason, it's sometimes necessary to +override the base dialect in order to specialize its behavior. This can be easily done in SQLGlot: supporting new +dialects is as simple as subclassing from `Dialect` and overriding its various components (e.g. the `Parser` class), +in order to implement the target behavior. + + +### Implementing a custom Dialect + +Consider the following example: + +```python +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator +from sqlglot.tokens import Tokenizer, TokenType + + +class Custom(Dialect): + class Tokenizer(Tokenizer): + QUOTES = ["'", '"'] + IDENTIFIERS = ["`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "INT64": TokenType.BIGINT, + "FLOAT64": TokenType.DOUBLE, + } + + class Generator(Generator): + TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"} + + TYPE_MAPPING = { + exp.DataType.Type.TINYINT: "INT64", + exp.DataType.Type.SMALLINT: "INT64", + exp.DataType.Type.INT: "INT64", + exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.DECIMAL: "NUMERIC", + exp.DataType.Type.FLOAT: "FLOAT64", + exp.DataType.Type.DOUBLE: "FLOAT64", + exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.TEXT: "STRING", + } +``` + +This is a typical example of adding a new dialect implementation in SQLGlot: we specify its identifier and string +delimiters, as well as what tokens it uses for its types and how they're associated with SQLGlot types. Since +the `Expression` classes are common for each dialect supported in SQLGlot, we may also need to override the generation +logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping. + +---- +""" + from sqlglot.dialects.bigquery import BigQuery from sqlglot.dialects.clickhouse import ClickHouse from sqlglot.dialects.databricks import Databricks diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 9ddfbea..e7d30ec 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -124,7 +124,6 @@ class BigQuery(Dialect): "FLOAT64": TokenType.DOUBLE, "INT64": TokenType.BIGINT, "NOT DETERMINISTIC": TokenType.VOLATILE, - "QUALIFY": TokenType.QUALIFY, "UNKNOWN": TokenType.NULL, } KEYWORDS.pop("DIV") diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 1c173a4..9e8c691 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -73,13 +73,8 @@ class ClickHouse(Dialect): return this - def _parse_position(self) -> exp.Expression: - this = super()._parse_position() - # clickhouse position args are swapped - substr = this.this - this.args["this"] = this.args.get("substr") - this.args["substr"] = substr - return this + def _parse_position(self, haystack_first: bool = False) -> exp.Expression: + return super()._parse_position(haystack_first=True) # https://clickhouse.com/docs/en/sql-reference/statements/select/with/ def _parse_cte(self) -> exp.Expression: diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 1bddfe1..2a0a917 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -124,6 +124,8 @@ class MySQL(Dialect): **tokens.Tokenizer.KEYWORDS, "MEDIUMTEXT": TokenType.MEDIUMTEXT, "LONGTEXT": TokenType.LONGTEXT, + "MEDIUMBLOB": TokenType.MEDIUMBLOB, + "LONGBLOB": TokenType.LONGBLOB, "START": TokenType.BEGIN, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, @@ -459,6 +461,8 @@ class MySQL(Dialect): TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) + TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB) + TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB) WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index c44950a..6225a53 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -194,7 +194,8 @@ class Snowflake(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "QUALIFY": TokenType.QUALIFY, + "EXCLUDE": TokenType.EXCEPT, + "RENAME": TokenType.REPLACE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, @@ -232,6 +233,11 @@ class Snowflake(Dialect): exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } + STAR_MAPPING = { + "except": "EXCLUDE", + "replace": "RENAME", + } + ROOT_PROPERTIES = { exp.PartitionedByProperty, exp.ReturnsProperty, diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 9342e6b..9f9099e 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func @@ -251,6 +252,7 @@ class TSQL(Dialect): "NTEXT": TokenType.TEXT, "NVARCHAR(MAX)": TokenType.TEXT, "PRINT": TokenType.COMMAND, + "PROC": TokenType.PROCEDURE, "REAL": TokenType.FLOAT, "ROWVERSION": TokenType.ROWVERSION, "SMALLDATETIME": TokenType.DATETIME, @@ -263,6 +265,11 @@ class TSQL(Dialect): "XML": TokenType.XML, } + # TSQL allows @, # to appear as a variable/identifier prefix + SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy() + SINGLE_TOKENS.pop("@") + SINGLE_TOKENS.pop("#") + class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore @@ -293,26 +300,82 @@ class TSQL(Dialect): DataType.Type.NCHAR, } - # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table - TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER} + RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore + TokenType.TABLE, + *parser.Parser.TYPE_TOKENS, # type: ignore + } + + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, # type: ignore + TokenType.END: lambda self: self._parse_command(), + } + + def _parse_system_time(self) -> t.Optional[exp.Expression]: + if not self._match_text_seq("FOR", "SYSTEM_TIME"): + return None - def _parse_convert(self, strict): + if self._match_text_seq("AS", "OF"): + system_time = self.expression( + exp.SystemTime, this=self._parse_bitwise(), kind="AS OF" + ) + elif self._match_set((TokenType.FROM, TokenType.BETWEEN)): + kind = self._prev.text + this = self._parse_bitwise() + self._match_texts(("TO", "AND")) + expression = self._parse_bitwise() + system_time = self.expression( + exp.SystemTime, this=this, expression=expression, kind=kind + ) + elif self._match_text_seq("CONTAINED", "IN"): + args = self._parse_wrapped_csv(self._parse_bitwise) + system_time = self.expression( + exp.SystemTime, + this=seq_get(args, 0), + expression=seq_get(args, 1), + kind="CONTAINED IN", + ) + elif self._match(TokenType.ALL): + system_time = self.expression(exp.SystemTime, kind="ALL") + else: + system_time = None + self.raise_error("Unable to parse FOR SYSTEM_TIME clause") + + return system_time + + def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + table = super()._parse_table_parts(schema=schema) + table.set("system_time", self._parse_system_time()) + return table + + def _parse_returns(self) -> exp.Expression: + table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS) + returns = super()._parse_returns() + returns.set("table", table) + return returns + + def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: to = self._parse_types() self._match(TokenType.COMMA) this = self._parse_conjunction() + if not to or not this: + return None + # Retrieve length of datatype and override to default if not specified if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) # Check whether a conversion with format is applicable if self._match(TokenType.COMMA): - format_val = self._parse_number().name - if format_val not in TSQL.convert_format_mapping: + format_val = self._parse_number() + format_val_name = format_val.name if format_val else "" + + if format_val_name not in TSQL.convert_format_mapping: raise ValueError( - f"CONVERT function at T-SQL does not support format style {format_val}" + f"CONVERT function at T-SQL does not support format style {format_val_name}" ) - format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val]) + + format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name]) # Check whether the convert entails a string to date format if to.this == DataType.Type.DATE: @@ -333,6 +396,21 @@ class TSQL(Dialect): # Entails a simple cast without any format requirement return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_user_defined_function( + self, kind: t.Optional[TokenType] = None + ) -> t.Optional[exp.Expression]: + this = super()._parse_user_defined_function(kind=kind) + + if ( + kind == TokenType.FUNCTION + or isinstance(this, exp.UserDefinedFunction) + or self._match(TokenType.ALIAS, advance=False) + ): + return this + + expressions = self._parse_csv(self._parse_udf_kwarg) + return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -354,3 +432,27 @@ class TSQL(Dialect): exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, } + + TRANSFORMS.pop(exp.ReturnsProperty) + + def systemtime_sql(self, expression: exp.SystemTime) -> str: + kind = expression.args["kind"] + if kind == "ALL": + return "FOR SYSTEM_TIME ALL" + + start = self.sql(expression, "this") + if kind == "AS OF": + return f"FOR SYSTEM_TIME AS OF {start}" + + end = self.sql(expression, "expression") + if kind == "FROM": + return f"FOR SYSTEM_TIME FROM {start} TO {end}" + if kind == "BETWEEN": + return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}" + + return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})" + + def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: + table = expression.args.get("table") + table = f"{table} " if table else "" + return f"RETURNS {table}{self.sql(expression, 'this')}" |