Diffstat (limited to 'sqlglot/dialects')
 sqlglot/dialects/__init__.py   |  61
 sqlglot/dialects/bigquery.py   |   1
 sqlglot/dialects/clickhouse.py |   9
 sqlglot/dialects/mysql.py      |   4
 sqlglot/dialects/snowflake.py  |   8
 sqlglot/dialects/tsql.py       | 116
 6 files changed, 183 insertions(+), 16 deletions(-)
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 2084681..34cf613 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -1,3 +1,64 @@
+"""
+## Dialects
+
+One of the core abstractions in SQLGlot is the concept of a "dialect". The `Dialect` class essentially implements a
+"SQLGlot dialect", which aims to be as generic and ANSI-compliant as possible. It relies on the base `Tokenizer`,
+`Parser` and `Generator` classes to achieve this goal, so these need to be very lenient when it comes to consuming
+SQL code.
+
+However, there are cases where the syntax of different SQL dialects varies wildly, even for common tasks. One such
+example is the date/time functions, which can be hard to deal with. For this reason, it's sometimes necessary to
+override the base dialect in order to specialize its behavior. This can be easily done in SQLGlot: supporting new
+dialects is as simple as subclassing from `Dialect` and overriding its various components (e.g. the `Parser` class),
+in order to implement the target behavior.
+
+
+### Implementing a custom Dialect
+
+Consider the following example:
+
+```python
+from sqlglot import exp
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.generator import Generator
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+class Custom(Dialect):
+ class Tokenizer(Tokenizer):
+ QUOTES = ["'", '"']
+ IDENTIFIERS = ["`"]
+
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "INT64": TokenType.BIGINT,
+ "FLOAT64": TokenType.DOUBLE,
+ }
+
+ class Generator(Generator):
+ TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"}
+
+ TYPE_MAPPING = {
+ exp.DataType.Type.TINYINT: "INT64",
+ exp.DataType.Type.SMALLINT: "INT64",
+ exp.DataType.Type.INT: "INT64",
+ exp.DataType.Type.BIGINT: "INT64",
+ exp.DataType.Type.DECIMAL: "NUMERIC",
+ exp.DataType.Type.FLOAT: "FLOAT64",
+ exp.DataType.Type.DOUBLE: "FLOAT64",
+ exp.DataType.Type.BOOLEAN: "BOOL",
+ exp.DataType.Type.TEXT: "STRING",
+ }
+```
+
+This is a typical example of adding a new dialect implementation to SQLGlot: we specify its identifier and string
+delimiters, as well as the tokens it uses for its types and how they're associated with SQLGlot types. Since
+the `Expression` classes are common to all dialects supported in SQLGlot, we may also need to override the generation
+logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping.
+
+----
+"""
+
from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
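A minimal usage sketch (not part of the patch) for the `Custom` dialect defined in the docstring above: `Dialect` subclasses register themselves under their lowercased class name, so once `Custom` is defined it becomes a valid transpilation target.

```python
import sqlglot

# Assumes the Custom dialect from the docstring above has been defined in the
# current session; Dialect subclasses self-register under their lowercased
# class name, so "custom" is usable as a write target.
sql = sqlglot.transpile(
    "SELECT CAST(a AS BIGINT), ARRAY(1, 2) FROM t",
    write="custom",
)[0]
print(sql)
# Expected along the lines of: SELECT CAST(a AS INT64), [1, 2] FROM t
```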
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 9ddfbea..e7d30ec 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -124,7 +124,6 @@ class BigQuery(Dialect):
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
"NOT DETERMINISTIC": TokenType.VOLATILE,
- "QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
}
KEYWORDS.pop("DIV")
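The `QUALIFY` keyword is dropped from BigQuery's tokenizer here (and from Snowflake's below), which suggests it now lives in the base tokenizer. A quick hedged check, assuming that is the case:

```python
import sqlglot

# If QUALIFY is now a base-tokenizer keyword, the clause should parse even
# without naming a dialect that historically declared it.
expr = sqlglot.parse_one(
    "SELECT x FROM t QUALIFY ROW_NUMBER() OVER (PARTITION BY y ORDER BY x) = 1"
)
print(expr.args.get("qualify"))
```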
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 1c173a4..9e8c691 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -73,13 +73,8 @@ class ClickHouse(Dialect):
return this
- def _parse_position(self) -> exp.Expression:
- this = super()._parse_position()
- # clickhouse position args are swapped
- substr = this.this
- this.args["this"] = this.args.get("substr")
- this.args["substr"] = substr
- return this
+ def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
+ return super()._parse_position(haystack_first=True)
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/
def _parse_cte(self) -> exp.Expression:
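The refactor above replaces manual argument swapping with a `haystack_first` flag on the base `_parse_position`. A small sketch of the behavior it preserves (column names are made up):

```python
from sqlglot import exp, parse_one

# ClickHouse's position() takes (haystack, needle), the reverse of ANSI
# POSITION(needle IN haystack); haystack_first=True keeps the arguments
# mapped correctly, as the removed swap code did by hand.
expr = parse_one("SELECT position(haystack, needle)", read="clickhouse")
pos = expr.find(exp.StrPosition)
print(pos.this, pos.args.get("substr"))  # haystack needle
```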
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 1bddfe1..2a0a917 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -124,6 +124,8 @@ class MySQL(Dialect):
**tokens.Tokenizer.KEYWORDS,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"LONGTEXT": TokenType.LONGTEXT,
+ "MEDIUMBLOB": TokenType.MEDIUMBLOB,
+ "LONGBLOB": TokenType.LONGBLOB,
"START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
@@ -459,6 +461,8 @@ class MySQL(Dialect):
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
+ TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
+ TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
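With these tokens added, and popped from the generator's base `TYPE_MAPPING` copy so MySQL keeps the native spellings, blob DDL should round-trip. A hedged sketch:

```python
import sqlglot

# MEDIUMBLOB/LONGBLOB now tokenize as dedicated types; popping them from the
# TYPE_MAPPING copy means MySQL emits them verbatim, while other dialects
# presumably fall back to whatever the base mapping chooses for them.
ddl = "CREATE TABLE t (a MEDIUMBLOB, b LONGBLOB)"
print(sqlglot.transpile(ddl, read="mysql", write="mysql")[0])
# Expected: CREATE TABLE t (a MEDIUMBLOB, b LONGBLOB)
```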
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index c44950a..6225a53 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -194,7 +194,8 @@ class Snowflake(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
- "QUALIFY": TokenType.QUALIFY,
+ "EXCLUDE": TokenType.EXCEPT,
+ "RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@@ -232,6 +233,11 @@ class Snowflake(Dialect):
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
+ STAR_MAPPING = {
+ "except": "EXCLUDE",
+ "replace": "RENAME",
+ }
+
ROOT_PROPERTIES = {
exp.PartitionedByProperty,
exp.ReturnsProperty,
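Snowflake spells the star modifiers `EXCLUDE`/`RENAME` where BigQuery uses `EXCEPT`/`REPLACE`; the keyword aliases map the Snowflake spellings to the shared tokens on the way in, and `STAR_MAPPING` restores them on the way out. A sketch, with the exact output spelling hedged:

```python
import sqlglot

# EXCLUDE -> EXCEPT and RENAME -> REPLACE at tokenization time, then
# STAR_MAPPING picks the Snowflake spellings during generation.
sql = sqlglot.transpile(
    "SELECT * EXCEPT (a) REPLACE (b AS c) FROM t",
    read="bigquery",
    write="snowflake",
)[0]
print(sql)
# Expected roughly: SELECT * EXCLUDE (a) RENAME (b AS c) FROM t
```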
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 9342e6b..9f9099e 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import re
+import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
@@ -251,6 +252,7 @@ class TSQL(Dialect):
"NTEXT": TokenType.TEXT,
"NVARCHAR(MAX)": TokenType.TEXT,
"PRINT": TokenType.COMMAND,
+ "PROC": TokenType.PROCEDURE,
"REAL": TokenType.FLOAT,
"ROWVERSION": TokenType.ROWVERSION,
"SMALLDATETIME": TokenType.DATETIME,
@@ -263,6 +265,11 @@ class TSQL(Dialect):
"XML": TokenType.XML,
}
+ # TSQL allows @, # to appear as a variable/identifier prefix
+ SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
+ SINGLE_TOKENS.pop("@")
+ SINGLE_TOKENS.pop("#")
+
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
@@ -293,26 +300,82 @@ class TSQL(Dialect):
DataType.Type.NCHAR,
}
- # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
- TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER}
+ RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore
+ TokenType.TABLE,
+ *parser.Parser.TYPE_TOKENS, # type: ignore
+ }
+
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ TokenType.END: lambda self: self._parse_command(),
+ }
+
+ def _parse_system_time(self) -> t.Optional[exp.Expression]:
+ if not self._match_text_seq("FOR", "SYSTEM_TIME"):
+ return None
- def _parse_convert(self, strict):
+ if self._match_text_seq("AS", "OF"):
+ system_time = self.expression(
+ exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
+ )
+ elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
+ kind = self._prev.text
+ this = self._parse_bitwise()
+ self._match_texts(("TO", "AND"))
+ expression = self._parse_bitwise()
+ system_time = self.expression(
+ exp.SystemTime, this=this, expression=expression, kind=kind
+ )
+ elif self._match_text_seq("CONTAINED", "IN"):
+ args = self._parse_wrapped_csv(self._parse_bitwise)
+ system_time = self.expression(
+ exp.SystemTime,
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ kind="CONTAINED IN",
+ )
+ elif self._match(TokenType.ALL):
+ system_time = self.expression(exp.SystemTime, kind="ALL")
+ else:
+ system_time = None
+ self.raise_error("Unable to parse FOR SYSTEM_TIME clause")
+
+ return system_time
+
+ def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+ table = super()._parse_table_parts(schema=schema)
+ table.set("system_time", self._parse_system_time())
+ return table
+
+ def _parse_returns(self) -> exp.Expression:
+ table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
+ returns = super()._parse_returns()
+ returns.set("table", table)
+ return returns
+
+ def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_conjunction()
+ if not to or not this:
+ return None
+
# Retrieve length of datatype and override to default if not specified
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
# Check whether a conversion with format is applicable
if self._match(TokenType.COMMA):
- format_val = self._parse_number().name
- if format_val not in TSQL.convert_format_mapping:
+ format_val = self._parse_number()
+ format_val_name = format_val.name if format_val else ""
+
+ if format_val_name not in TSQL.convert_format_mapping:
raise ValueError(
- f"CONVERT function at T-SQL does not support format style {format_val}"
+ f"CONVERT function at T-SQL does not support format style {format_val_name}"
)
- format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
+
+ format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name])
# Check whether the convert entails a string to date format
if to.this == DataType.Type.DATE:
@@ -333,6 +396,21 @@ class TSQL(Dialect):
# Entails a simple cast without any format requirement
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_user_defined_function(
+ self, kind: t.Optional[TokenType] = None
+ ) -> t.Optional[exp.Expression]:
+ this = super()._parse_user_defined_function(kind=kind)
+
+ if (
+ kind == TokenType.FUNCTION
+ or isinstance(this, exp.UserDefinedFunction)
+ or self._match(TokenType.ALIAS, advance=False)
+ ):
+ return this
+
+ expressions = self._parse_csv(self._parse_udf_kwarg)
+ return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -354,3 +432,27 @@ class TSQL(Dialect):
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
}
+
+ TRANSFORMS.pop(exp.ReturnsProperty)
+
+ def systemtime_sql(self, expression: exp.SystemTime) -> str:
+ kind = expression.args["kind"]
+ if kind == "ALL":
+ return "FOR SYSTEM_TIME ALL"
+
+ start = self.sql(expression, "this")
+ if kind == "AS OF":
+ return f"FOR SYSTEM_TIME AS OF {start}"
+
+ end = self.sql(expression, "expression")
+ if kind == "FROM":
+ return f"FOR SYSTEM_TIME FROM {start} TO {end}"
+ if kind == "BETWEEN":
+ return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"
+
+ return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
+
+ def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
+ table = expression.args.get("table")
+ table = f"{table} " if table else ""
+ return f"RETURNS {table}{self.sql(expression, 'this')}"