From 7ff5bab54e3298dd89132706f6adee17f5164f6d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 5 Nov 2022 19:41:12 +0100 Subject: Merging upstream version 9.0.6. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 4 +- sqlglot/dialects/__init__.py | 1 + sqlglot/dialects/databricks.py | 21 +++++ sqlglot/dialects/dialect.py | 13 +++ sqlglot/dialects/hive.py | 2 + sqlglot/dialects/presto.py | 1 + sqlglot/dialects/snowflake.py | 2 + sqlglot/dialects/sqlite.py | 1 + sqlglot/dialects/tsql.py | 78 ++++++++++++++---- sqlglot/expressions.py | 151 ++++++++++++++++++++++++----------- sqlglot/generator.py | 20 +++-- sqlglot/optimizer/qualify_columns.py | 21 +++-- sqlglot/optimizer/scope.py | 2 +- sqlglot/parser.py | 83 +++++++++++++------ sqlglot/time.py | 1 - sqlglot/tokens.py | 4 + 16 files changed, 301 insertions(+), 104 deletions(-) create mode 100644 sqlglot/dialects/databricks.py (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index a780f96..d6e18fd 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -1,3 +1,5 @@ +"""## Python SQL parser, transpiler and optimizer.""" + from sqlglot import expressions as exp from sqlglot.dialects import Dialect, Dialects from sqlglot.diff import diff @@ -24,7 +26,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "9.0.3" +__version__ = "9.0.6" pretty = False diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 0f80723..0816831 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -1,5 +1,6 @@ from sqlglot.dialects.bigquery import BigQuery from sqlglot.dialects.clickhouse import ClickHouse +from sqlglot.dialects.databricks import Databricks from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.dialects.duckdb import DuckDB from sqlglot.dialects.hive import Hive diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py new file mode 100644 index 0000000..9dc3c38 --- /dev/null +++ b/sqlglot/dialects/databricks.py @@ -0,0 +1,21 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import parse_date_delta +from sqlglot.dialects.spark import Spark +from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql + + +class Databricks(Spark): + class Parser(Spark.Parser): + FUNCTIONS = { + **Spark.Parser.FUNCTIONS, + "DATEADD": parse_date_delta(exp.DateAdd), + "DATE_ADD": parse_date_delta(exp.DateAdd), + "DATEDIFF": parse_date_delta(exp.DateDiff), + } + + class Generator(Spark.Generator): + TRANSFORMS = { + **Spark.Generator.TRANSFORMS, + exp.DateAdd: generate_date_delta_with_unit_sql, + exp.DateDiff: generate_date_delta_with_unit_sql, + } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 46661cf..33985a7 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -28,6 +28,7 @@ class Dialects(str, Enum): TABLEAU = "tableau" TRINO = "trino" TSQL = "tsql" + DATABRICKS = "databricks" class _Dialect(type): @@ -331,3 +332,15 @@ def create_with_partitions_sql(self, expression): expression.set("this", schema) return self.create_sql(expression) + + +def parse_date_delta(exp_class, unit_mapping=None): + def inner_func(args): + unit_based = len(args) == 3 + this = list_get(args, 2) if unit_based else list_get(args, 0) + expression = list_get(args, 1) if unit_based else list_get(args, 1) + unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY") + unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit + return exp_class(this=this, expression=expression, unit=unit) + + return inner_func diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 63fdb85..03049ff 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -111,6 +111,7 @@ def _unnest_to_explode_sql(self, expression): self.sql( exp.Lateral( this=udtf(this=expression), + view=True, alias=exp.TableAlias(this=alias.this, columns=[column]), ) ) @@ -283,6 +284,7 @@ class Hive(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}", + exp.NumberToStr: rename_func("FORMAT_NUMBER"), } WITH_PROPERTIES = {exp.AnonymousProperty} diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 41c0db1..a2d392c 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -115,6 +115,7 @@ class Presto(Dialect): class Tokenizer(Tokenizer): KEYWORDS = { **Tokenizer.KEYWORDS, + "VARBINARY": TokenType.BINARY, "ROW": TokenType.STRUCT, } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 627258f..3b97e6d 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -188,6 +188,8 @@ class Snowflake(Dialect): } class Generator(Generator): + CREATE_TRANSIENT = True + TRANSFORMS = { **Generator.TRANSFORMS, exp.ArrayConcat: rename_func("ARRAY_CAT"), diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index cfdbe1b..62b7617 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -20,6 +20,7 @@ class SQLite(Dialect): KEYWORDS = { **Tokenizer.KEYWORDS, + "VARBINARY": TokenType.BINARY, "AUTOINCREMENT": TokenType.AUTO_INCREMENT, } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 107ace7..0f93c75 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,5 +1,7 @@ +import re + from sqlglot import exp -from sqlglot.dialects.dialect import Dialect, rename_func +from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func from sqlglot.expressions import DataType from sqlglot.generator import Generator from sqlglot.helper import list_get @@ -27,6 +29,11 @@ DATE_DELTA_INTERVAL = { } +DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})") +# N = Numeric, C=Currency +TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} + + def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): def _format_time(args): return exp_class( @@ -42,26 +49,40 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): return _format_time -def parse_date_delta(exp_class): - def inner_func(args): - unit = DATE_DELTA_INTERVAL.get(list_get(args, 0).name.lower(), "day") - return exp_class(this=list_get(args, 2), expression=list_get(args, 1), unit=unit) - - return inner_func +def parse_format(args): + fmt = list_get(args, 1) + number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) + if number_fmt: + return exp.NumberToStr(this=list_get(args, 0), format=fmt) + return exp.TimeToStr( + this=list_get(args, 0), + format=exp.Literal.string( + format_time(fmt.name, TSQL.format_time_mapping) + if len(fmt.name) == 1 + else format_time(fmt.name, TSQL.time_mapping) + ), + ) -def generate_date_delta(self, e): +def generate_date_delta_with_unit_sql(self, e): func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" +def generate_format_sql(self, e): + fmt = ( + e.args["format"] + if isinstance(e, exp.NumberToStr) + else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping)) + ) + return f"FORMAT({self.format_args(e.this, fmt)})" + + class TSQL(Dialect): null_ordering = "nulls_are_small" time_format = "'yyyy-mm-dd hh:mm:ss'" time_mapping = { - "yyyy": "%Y", - "yy": "%y", "year": "%Y", "qq": "%q", "q": "%q", @@ -101,6 +122,8 @@ class TSQL(Dialect): "H": "%-H", "h": "%-I", "S": "%f", + "yyyy": "%Y", + "yy": "%y", } convert_format_mapping = { @@ -143,6 +166,27 @@ class TSQL(Dialect): "120": "%Y-%m-%d %H:%M:%S", "121": "%Y-%m-%d %H:%M:%S.%f", } + # not sure if complete + format_time_mapping = { + "y": "%B %Y", + "d": "%m/%d/%Y", + "H": "%-H", + "h": "%-I", + "s": "%Y-%m-%d %H:%M:%S", + "D": "%A,%B,%Y", + "f": "%A,%B,%Y %-I:%M %p", + "F": "%A,%B,%Y %-I:%M:%S %p", + "g": "%m/%d/%Y %-I:%M %p", + "G": "%m/%d/%Y %-I:%M:%S %p", + "M": "%B %-d", + "m": "%B %-d", + "O": "%Y-%m-%dT%H:%M:%S", + "u": "%Y-%M-%D %H:%M:%S%z", + "U": "%A, %B %D, %Y %H:%M:%S%z", + "T": "%-I:%M:%S %p", + "t": "%-I:%M", + "Y": "%a %Y", + } class Tokenizer(Tokenizer): IDENTIFIERS = ['"', ("[", "]")] @@ -166,6 +210,7 @@ class TSQL(Dialect): "SQL_VARIANT": TokenType.VARIANT, "NVARCHAR(MAX)": TokenType.TEXT, "VARCHAR(MAX)": TokenType.TEXT, + "TOP": TokenType.TOP, } class Parser(Parser): @@ -173,8 +218,8 @@ class TSQL(Dialect): **Parser.FUNCTIONS, "CHARINDEX": exp.StrPosition.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, - "DATEADD": parse_date_delta(exp.DateAdd), - "DATEDIFF": parse_date_delta(exp.DateDiff), + "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), + "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": tsql_format_time_lambda(exp.TimeToStr), "GETDATE": exp.CurrentDate.from_arg_list, @@ -182,6 +227,7 @@ class TSQL(Dialect): "LEN": exp.Length.from_arg_list, "REPLICATE": exp.Repeat.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, + "FORMAT": parse_format, } VAR_LENGTH_DATATYPES = { @@ -194,7 +240,7 @@ class TSQL(Dialect): def _parse_convert(self, strict): to = self._parse_types() self._match(TokenType.COMMA) - this = self._parse_field() + this = self._parse_column() # Retrieve length of datatype and override to default if not specified if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: @@ -238,8 +284,10 @@ class TSQL(Dialect): TRANSFORMS = { **Generator.TRANSFORMS, - exp.DateAdd: lambda self, e: generate_date_delta(self, e), - exp.DateDiff: lambda self, e: generate_date_delta(self, e), + exp.DateAdd: generate_date_delta_with_unit_sql, + exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), exp.If: rename_func("IIF"), + exp.NumberToStr: generate_format_sql, + exp.TimeToStr: generate_format_sql, } diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index eb7854a..1691d85 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -443,7 +443,7 @@ class Condition(Expression): 'x = 1 AND y = 1' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. dialect (str): the dialect used to parse the input expression. opts (kwargs): other options to use to parse the input expressions. @@ -462,7 +462,7 @@ class Condition(Expression): 'x = 1 OR y = 1' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. dialect (str): the dialect used to parse the input expression. opts (kwargs): other options to use to parse the input expressions. @@ -523,7 +523,7 @@ class Unionable(Expression): 'SELECT * FROM foo UNION SELECT * FROM bla' Args: - expression (str or Expression): the SQL code string. + expression (str | Expression): the SQL code string. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -543,7 +543,7 @@ class Unionable(Expression): 'SELECT * FROM foo INTERSECT SELECT * FROM bla' Args: - expression (str or Expression): the SQL code string. + expression (str | Expression): the SQL code string. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -563,7 +563,7 @@ class Unionable(Expression): 'SELECT * FROM foo EXCEPT SELECT * FROM bla' Args: - expression (str or Expression): the SQL code string. + expression (str | Expression): the SQL code string. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -612,6 +612,7 @@ class Create(Expression): "exists": False, "properties": False, "temporary": False, + "transient": False, "replace": False, "unique": False, "materialized": False, @@ -910,7 +911,7 @@ class Join(Expression): 'JOIN x ON y = 1' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. append (bool): if `True`, AND the new expressions to any existing expression. @@ -937,9 +938,45 @@ class Join(Expression): return join + def using(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the USING expressions. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("JOIN x", into=Join).using("foo", "bla").sql() + 'JOIN x USING (foo, bla)' + + Args: + *expressions (str | Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append (bool): if `True`, concatenate the new expressions to the existing "using" list. + Otherwise, this resets the expression. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Join: the modified join expression. + """ + join = _apply_list_builder( + *expressions, + instance=self, + arg="using", + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + if join.kind == "CROSS": + join.set("kind", None) + + return join + class Lateral(UDTF): - arg_types = {"this": True, "outer": False, "alias": False} + arg_types = {"this": True, "view": False, "outer": False, "alias": False} # Clickhouse FROM FINAL modifier @@ -1093,7 +1130,7 @@ class Subqueryable(Unionable): 'SELECT x FROM (SELECT x FROM tbl)' Args: - alias (str or Identifier): an optional alias for the subquery + alias (str | Identifier): an optional alias for the subquery copy (bool): if `False`, modify this expression instance in-place. Returns: @@ -1138,9 +1175,9 @@ class Subqueryable(Unionable): 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' Args: - alias (str or Expression): the SQL code string to parse as the table name. + alias (str | Expression): the SQL code string to parse as the table name. If an `Expression` instance is passed, this is used as-is. - as_ (str or Expression): the SQL code string to parse as the table expression. + as_ (str | Expression): the SQL code string to parse as the table expression. If an `Expression` instance is passed, it will be used as-is. recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`. append (bool): if `True`, add to any existing expressions. @@ -1295,7 +1332,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `From` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `From`. append (bool): if `True`, add to any existing expressions. @@ -1328,7 +1365,7 @@ class Select(Subqueryable): 'SELECT x, COUNT(1) FROM tbl GROUP BY x' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Group`. If nothing is passed in then a group by is not applied to the expression @@ -1364,7 +1401,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl ORDER BY x DESC' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Order`. append (bool): if `True`, add to any existing expressions. @@ -1397,7 +1434,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl SORT BY x DESC' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `SORT`. append (bool): if `True`, add to any existing expressions. @@ -1430,7 +1467,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl CLUSTER BY x DESC' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Cluster`. append (bool): if `True`, add to any existing expressions. @@ -1463,7 +1500,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl LIMIT 10' Args: - expression (str or int or Expression): the SQL code string to parse. + expression (str | int | Expression): the SQL code string to parse. This can also be an integer. If a `Limit` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Limit`. @@ -1494,7 +1531,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl OFFSET 10' Args: - expression (str or int or Expression): the SQL code string to parse. + expression (str | int | Expression): the SQL code string to parse. This can also be an integer. If a `Offset` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Offset`. @@ -1525,7 +1562,7 @@ class Select(Subqueryable): 'SELECT x, y' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. append (bool): if `True`, add to any existing expressions. Otherwise, this resets the expressions. @@ -1555,7 +1592,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. append (bool): if `True`, add to any existing expressions. Otherwise, this resets the expressions. @@ -1582,6 +1619,7 @@ class Select(Subqueryable): self, expression, on=None, + using=None, append=True, join_type=None, join_alias=None, @@ -1596,15 +1634,20 @@ class Select(Subqueryable): >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y").sql() 'SELECT * FROM tbl JOIN tbl2 ON tbl1.y = tbl2.y' + >>> Select().select("1").from_("a").join("b", using=["x", "y", "z"]).sql() + 'SELECT 1 FROM a JOIN b USING (x, y, z)' + Use `join_type` to change the type of join: >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y", join_type="left outer").sql() 'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y' Args: - expression (str or Expression): the SQL code string to parse. + expression (str | Expression): the SQL code string to parse. If an `Expression` instance is passed, it will be used as-is. - on (str or Expression): optionally specify the join criteria as a SQL string. + on (str | Expression): optionally specify the join "on" criteria as a SQL string. + If an `Expression` instance is passed, it will be used as-is. + using (str | Expression): optionally specify the join "using" criteria as a SQL string. If an `Expression` instance is passed, it will be used as-is. append (bool): if `True`, add to any existing expressions. Otherwise, this resets the expressions. @@ -1641,6 +1684,16 @@ class Select(Subqueryable): on = and_(*ensure_list(on), dialect=dialect, **opts) join.set("on", on) + if using: + join = _apply_list_builder( + *ensure_list(using), + instance=join, + arg="using", + append=append, + copy=copy, + **opts, + ) + if join_alias: join.set("this", alias_(join.args["this"], join_alias, table=True)) return _apply_list_builder( @@ -1661,7 +1714,7 @@ class Select(Subqueryable): "SELECT x FROM tbl WHERE x = 'a' OR x < 'b'" Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. append (bool): if `True`, AND the new expressions to any existing expression. @@ -1693,7 +1746,7 @@ class Select(Subqueryable): 'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. append (bool): if `True`, AND the new expressions to any existing expression. @@ -1744,7 +1797,7 @@ class Select(Subqueryable): 'CREATE TABLE x AS SELECT * FROM tbl' Args: - table (str or Expression): the SQL code string to parse as the table name. + table (str | Expression): the SQL code string to parse as the table name. If another `Expression` instance is passed, it will be used as-is. properties (dict): an optional mapping of table properties dialect (str): the dialect used to parse the input table. @@ -2620,6 +2673,10 @@ class StrToUnix(Func): arg_types = {"this": True, "format": True} +class NumberToStr(Func): + arg_types = {"this": True, "format": True} + + class Struct(Func): arg_types = {"expressions": True} is_var_len_args = True @@ -2775,7 +2832,7 @@ def maybe_parse( (IDENTIFIER this: x, quoted: False) Args: - sql_or_expression (str or Expression): the SQL code string or an expression + sql_or_expression (str | Expression): the SQL code string or an expression into (Expression): the SQLGlot Expression to parse into dialect (str): the dialect used to parse the input expressions (in the case that an input expression is a SQL string). @@ -2950,9 +3007,9 @@ def union(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo UNION SELECT * FROM bla' Args: - left (str or Expression): the SQL code string corresponding to the left-hand side. + left (str | Expression): the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str or Expression): the SQL code string corresponding to the right-hand side. + right (str | Expression): the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -2975,9 +3032,9 @@ def intersect(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo INTERSECT SELECT * FROM bla' Args: - left (str or Expression): the SQL code string corresponding to the left-hand side. + left (str | Expression): the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str or Expression): the SQL code string corresponding to the right-hand side. + right (str | Expression): the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -3000,9 +3057,9 @@ def except_(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo EXCEPT SELECT * FROM bla' Args: - left (str or Expression): the SQL code string corresponding to the left-hand side. + left (str | Expression): the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str or Expression): the SQL code string corresponding to the right-hand side. + right (str | Expression): the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -3025,7 +3082,7 @@ def select(*expressions, dialect=None, **opts): 'SELECT col1, col2 FROM tbl' Args: - *expressions (str or Expression): the SQL code string to parse as the expressions of a + *expressions (str | Expression): the SQL code string to parse as the expressions of a SELECT statement. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expressions (in the case that an input expression is a SQL string). @@ -3047,7 +3104,7 @@ def from_(*expressions, dialect=None, **opts): 'SELECT col1, col2 FROM tbl' Args: - *expressions (str or Expression): the SQL code string to parse as the FROM expressions of a + *expressions (str | Expression): the SQL code string to parse as the FROM expressions of a SELECT statement. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression (in the case that the input expression is a SQL string). @@ -3132,7 +3189,7 @@ def condition(expression, dialect=None, **opts): 'SELECT * FROM tbl WHERE x = 1 AND y = 1' Args: - *expression (str or Expression): the SQL code string to parse. + *expression (str | Expression): the SQL code string to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression (in the case that the input expression is a SQL string). @@ -3159,7 +3216,7 @@ def and_(*expressions, dialect=None, **opts): 'x = 1 AND (y = 1 AND z = 1)' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3179,7 +3236,7 @@ def or_(*expressions, dialect=None, **opts): 'x = 1 OR (y = 1 OR z = 1)' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3199,7 +3256,7 @@ def not_(expression, dialect=None, **opts): "NOT this_suit = 'black'" Args: - expression (str or Expression): the SQL code strings to parse. + expression (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3283,9 +3340,9 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): 'foo AS bar' Args: - expression (str or Expression): the SQL code strings to parse. + expression (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - alias (str or Identifier): the alias name to use. If the name has + alias (str | Identifier): the alias name to use. If the name has special characters it is quoted. table (bool): create a table alias, default false dialect (str): the dialect used to parse the input expression. @@ -3322,9 +3379,9 @@ def subquery(expression, alias=None, dialect=None, **opts): 'SELECT x FROM (SELECT x FROM tbl) AS bar' Args: - expression (str or Expression): the SQL code strings to parse. + expression (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - alias (str or Expression): the alias name to use. + alias (str | Expression): the alias name to use. dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3340,8 +3397,8 @@ def column(col, table=None, quoted=None): """ Build a Column. Args: - col (str or Expression): column name - table (str or Expression): table name + col (str | Expression): column name + table (str | Expression): table name Returns: Column: column instance """ @@ -3355,9 +3412,9 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None): """Build a Table. Args: - table (str or Expression): column name - db (str or Expression): db name - catalog (str or Expression): catalog name + table (str | Expression): column name + db (str | Expression): db name + catalog (str | Expression): catalog name Returns: Table: table instance @@ -3423,7 +3480,7 @@ def convert(value): values=[convert(v) for v in value.values()], ) if isinstance(value, datetime.datetime): - datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S")) + datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z")) return TimeStrToTime(this=datetime_literal) if isinstance(value, datetime.date): date_literal = Literal.string(value.strftime("%Y-%m-%d")) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 1784287..ca14425 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -65,6 +65,9 @@ class Generator: exp.VolatilityProperty: lambda self, e: self.sql(e.name), } + # whether 'CREATE ... TRANSIENT ... TABLE' is allowed + # can override in dialects + CREATE_TRANSIENT = False # whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True # always do union distinct or union all @@ -368,15 +371,14 @@ class Generator: expression_sql = self.sql(expression, "expression") expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" temporary = " TEMPORARY" if expression.args.get("temporary") else "" + transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" replace = " OR REPLACE" if expression.args.get("replace") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" properties = self.sql(expression, "properties") - expression_sql = ( - f"CREATE{replace}{temporary}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}" - ) + expression_sql = f"CREATE{replace}{temporary}{transient}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}" return self.prepend_ctes(expression, expression_sql) def describe_sql(self, expression): @@ -716,15 +718,21 @@ class Generator: def lateral_sql(self, expression): this = self.sql(expression, "this") + if isinstance(expression.this, exp.Subquery): - return f"LATERAL{self.sep()}{this}" - op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") + return f"LATERAL {this}" + alias = expression.args["alias"] table = alias.name table = f" {table}" if table else table columns = self.expressions(alias, key="columns", flat=True) columns = f" AS {columns}" if columns else "" - return f"{op_sql}{self.sep()}{this}{table}{columns}" + + if expression.args.get("view"): + op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") + return f"{op_sql}{self.sep()}{this}{table}{columns}" + + return f"LATERAL {this}{table}{columns}" def limit_sql(self, expression): this = self.sql(expression, "this") diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 36ba028..ebee92a 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -211,21 +211,26 @@ def _qualify_columns(scope, resolver): if column_table: column.set("table", exp.to_identifier(column_table)) + columns_missing_from_scope = [] # Determine whether each reference in the order by clause is to a column or an alias. for ordered in scope.find_all(exp.Ordered): for column in ordered.find_all(exp.Column): - column_table = column.table - column_name = column.name + if not column.table and column.parent is not ordered and column.name in resolver.all_columns: + columns_missing_from_scope.append(column) - if column_table or column.parent is ordered or column_name not in resolver.all_columns: - continue + # Determine whether each reference in the having clause is to a column or an alias. + for having in scope.find_all(exp.Having): + for column in having.find_all(exp.Column): + if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns: + columns_missing_from_scope.append(column) - column_table = resolver.get_table(column_name) + for column in columns_missing_from_scope: + column_table = resolver.get_table(column.name) - if column_table is None: - raise OptimizeError(f"Ambiguous column: {column_name}") + if column_table is None: + raise OptimizeError(f"Ambiguous column: {column.name}") - column.set("table", exp.to_identifier(column_table)) + column.set("table", exp.to_identifier(column_table)) def _expand_stars(scope, resolver): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index b7eb6c2..5a75ee2 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -232,7 +232,7 @@ class Scope: self._columns = [] for column in columns + external_columns: - ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint) + ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) if ( not ancestor or column.table diff --git a/sqlglot/parser.py b/sqlglot/parser.py index b94313a..79a1d90 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -131,6 +131,7 @@ class Parser: TokenType.ALTER, TokenType.ALWAYS, TokenType.ANTI, + TokenType.APPLY, TokenType.BEGIN, TokenType.BOTH, TokenType.BUCKET, @@ -190,6 +191,7 @@ class Parser: TokenType.TABLE, TokenType.TABLE_FORMAT, TokenType.TEMPORARY, + TokenType.TRANSIENT, TokenType.TOP, TokenType.TRAILING, TokenType.TRUNCATE, @@ -204,7 +206,7 @@ class Parser: *TYPE_TOKENS, } - TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL} + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} @@ -685,6 +687,7 @@ class Parser: def _parse_create(self): replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) + transient = self._match(TokenType.TRANSIENT) unique = self._match(TokenType.UNIQUE) materialized = self._match(TokenType.MATERIALIZED) @@ -723,6 +726,7 @@ class Parser: exists=exists, properties=properties, temporary=temporary, + transient=transient, replace=replace, unique=unique, materialized=materialized, @@ -1057,8 +1061,8 @@ class Parser: return self._parse_set_operations(this) if this else None - def _parse_with(self): - if not self._match(TokenType.WITH): + def _parse_with(self, skip_with_token=False): + if not skip_with_token and not self._match(TokenType.WITH): return None recursive = self._match(TokenType.RECURSIVE) @@ -1167,28 +1171,53 @@ class Parser: return self.expression(exp.From, expressions=self._parse_csv(self._parse_table)) def _parse_lateral(self): - if not self._match(TokenType.LATERAL): + outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) + cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) + + if outer_apply or cross_apply: + this = self._parse_select(table=True) + view = None + outer = not cross_apply + elif self._match(TokenType.LATERAL): + this = self._parse_select(table=True) + view = self._match(TokenType.VIEW) + outer = self._match(TokenType.OUTER) + else: return None - subquery = self._parse_select(table=True) + if not this: + this = self._parse_function() - if subquery: - return self.expression(exp.Lateral, this=subquery) + table_alias = self._parse_id_var(any_token=False) - self._match(TokenType.VIEW) - outer = self._match(TokenType.OUTER) + columns = None + if self._match(TokenType.ALIAS): + columns = self._parse_csv(self._parse_id_var) + elif self._match(TokenType.L_PAREN): + columns = self._parse_csv(self._parse_id_var) + self._match(TokenType.R_PAREN) - return self.expression( + expression = self.expression( exp.Lateral, - this=self._parse_function(), + this=this, + view=view, outer=outer, alias=self.expression( exp.TableAlias, - this=self._parse_id_var(any_token=False), - columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None), + this=table_alias, + columns=columns, ), ) + if outer_apply or cross_apply: + return self.expression( + exp.Join, + this=expression, + side=None if cross_apply else "LEFT", + ) + + return expression + def _parse_join_side_and_kind(self): return ( self._match(TokenType.NATURAL) and self._prev, @@ -1196,10 +1225,10 @@ class Parser: self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self): + def _parse_join(self, skip_join_token=False): natural, side, kind = self._parse_join_side_and_kind() - if not self._match(TokenType.JOIN): + if not skip_join_token and not self._match(TokenType.JOIN): return None kwargs = {"this": self._parse_table()} @@ -1425,13 +1454,13 @@ class Parser: unpivot=unpivot, ) - def _parse_where(self): - if not self._match(TokenType.WHERE): + def _parse_where(self, skip_where_token=False): + if not skip_where_token and not self._match(TokenType.WHERE): return None return self.expression(exp.Where, this=self._parse_conjunction()) - def _parse_group(self): - if not self._match(TokenType.GROUP_BY): + def _parse_group(self, skip_group_by_token=False): + if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None return self.expression( exp.Group, @@ -1457,8 +1486,8 @@ class Parser: return self.expression(exp.Tuple, expressions=grouping_set) return self._parse_id_var() - def _parse_having(self): - if not self._match(TokenType.HAVING): + def _parse_having(self, skip_having_token=False): + if not skip_having_token and not self._match(TokenType.HAVING): return None return self.expression(exp.Having, this=self._parse_conjunction()) @@ -1467,8 +1496,8 @@ class Parser: return None return self.expression(exp.Qualify, this=self._parse_conjunction()) - def _parse_order(self, this=None): - if not self._match(TokenType.ORDER_BY): + def _parse_order(self, this=None, skip_order_token=False): + if not skip_order_token and not self._match(TokenType.ORDER_BY): return this return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)) @@ -1502,7 +1531,11 @@ class Parser: def _parse_limit(self, this=None, top=False): if self._match(TokenType.TOP if top else TokenType.LIMIT): - return self.expression(exp.Limit, this=this, expression=self._parse_number()) + limit_paren = self._match(TokenType.L_PAREN) + limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number()) + if limit_paren: + self._match(TokenType.R_PAREN) + return limit_exp if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._prev.text if direction else "FIRST" @@ -2136,7 +2169,7 @@ class Parser: return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) def _parse_convert(self, strict): - this = self._parse_field() + this = self._parse_column() if self._match(TokenType.USING): to = self.expression(exp.CharacterSet, this=self._parse_var()) elif self._match(TokenType.COMMA): diff --git a/sqlglot/time.py b/sqlglot/time.py index de28ac0..729b50d 100644 --- a/sqlglot/time.py +++ b/sqlglot/time.py @@ -43,5 +43,4 @@ def format_time(string, mapping, trie=None): if result and end > size: chunks.append(chars) - return "".join(mapping.get(chars, chars) for chars in chunks) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 1a9d72e..766c01a 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -107,6 +107,7 @@ class TokenType(AutoName): ANALYZE = auto() ANTI = auto() ANY = auto() + APPLY = auto() ARRAY = auto() ASC = auto() AT_TIME_ZONE = auto() @@ -256,6 +257,7 @@ class TokenType(AutoName): TABLE_FORMAT = auto() TABLE_SAMPLE = auto() TEMPORARY = auto() + TRANSIENT = auto() TOP = auto() THEN = auto() TRUE = auto() @@ -560,6 +562,7 @@ class Tokenizer(metaclass=_Tokenizer): "TABLESAMPLE": TokenType.TABLE_SAMPLE, "TEMP": TokenType.TEMPORARY, "TEMPORARY": TokenType.TEMPORARY, + "TRANSIENT": TokenType.TRANSIENT, "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, "TRAILING": TokenType.TRAILING, @@ -582,6 +585,7 @@ class Tokenizer(metaclass=_Tokenizer): "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE, "WITHIN GROUP": TokenType.WITHIN_GROUP, "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE, + "APPLY": TokenType.APPLY, "ARRAY": TokenType.ARRAY, "BOOL": TokenType.BOOLEAN, "BOOLEAN": TokenType.BOOLEAN, -- cgit v1.2.3