diff options
-rw-r--r-- | sqlglot/__init__.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 6 | ||||
-rw-r--r-- | sqlglot/expressions.py | 22 | ||||
-rw-r--r-- | sqlglot/generator.py | 8 | ||||
-rw-r--r-- | sqlglot/parser.py | 27 | ||||
-rw-r--r-- | sqlglot/schema.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 24 | ||||
-rw-r--r-- | tests/test_build.py | 6 | ||||
-rw-r--r-- | tests/test_parser.py | 2 | ||||
-rw-r--r-- | tests/test_schema.py | 5 |
13 files changed, 91 insertions, 26 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 3733b20..e829517 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -30,7 +30,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.2.6" +__version__ = "10.2.9" pretty = False diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 70c1c6c..8d6e1ae 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -250,6 +250,7 @@ class Hive(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 1cb5025..f276af1 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -244,6 +244,7 @@ class Postgres(Dialect): class Parser(parser.Parser): STRICT_CAST = False + LATERAL_FUNCTION_AS_VIEW = True FUNCTIONS = { **parser.Parser.FUNCTIONS, diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 07ce38b..a552e7b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -224,6 +224,12 @@ class TSQL(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]")] + QUOTES = [ + (prefix + quote, quote) if prefix else quote + for quote in ["'", '"'] + for prefix in ["", "n", "N"] + ] + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "BIT": TokenType.BOOLEAN, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 7249574..aeed218 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3673,7 +3673,11 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: ) -def values(values, alias=None) -> Values: +def values( + values: t.Iterable[t.Tuple[t.Any, ...]], + alias: t.Optional[str] = None, + columns: t.Optional[t.Iterable[str]] = None, +) -> Values: """Build VALUES statement. Example: @@ -3681,17 +3685,23 @@ def values(values, alias=None) -> Values: "VALUES (1, '2')" Args: - values (list[tuple[str | Expression]]): values statements that will be converted to SQL - alias (str): optional alias - dialect (str): the dialect used to parse the input expression. - **opts: other options to use to parse the input expressions. + values: values statements that will be converted to SQL + alias: optional alias + columns: Optional list of ordered column names. An alias is required when providing column names. Returns: Values: the Values expression object """ + if columns and not alias: + raise ValueError("Alias is required when providing columns") + table_alias = ( + TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns]) + if columns + else TableAlias(this=to_identifier(alias) if alias else None) + ) return Values( expressions=[convert(tup) for tup in values], - alias=to_identifier(alias) if alias else None, + alias=table_alias, ) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index beffb91..2b4c575 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -795,14 +795,16 @@ class Generator: alias = expression.args["alias"] table = alias.name - table = f" {table}" if table else table columns = self.expressions(alias, key="columns", flat=True) - columns = f" AS {columns}" if columns else "" if expression.args.get("view"): + table = f" {table}" if table else table + columns = f" AS {columns}" if columns else "" op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") return f"{op_sql}{self.sep()}{this}{table}{columns}" + table = f" AS {table}" if table else table + columns = f"({columns})" if columns else "" return f"LATERAL {this}{table}{columns}" def limit_sql(self, expression: exp.Limit) -> str: @@ -889,8 +891,8 @@ class Generator: def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, - *[self.sql(sql) for sql in expression.args.get("laterals", [])], *[self.sql(sql) for sql in expression.args.get("joins", [])], + *[self.sql(sql) for sql in expression.args.get("laterals", [])], self.sql(expression, "where"), self.sql(expression, "group"), self.sql(expression, "having"), diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 55ab453..29bc9c0 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -562,6 +562,7 @@ class Parser(metaclass=_Parser): TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} STRICT_CAST = True + LATERAL_FUNCTION_AS_VIEW = False __slots__ = ( "error_level", @@ -1287,16 +1288,26 @@ class Parser(metaclass=_Parser): return None if not this: - this = self._parse_function() - - table_alias = self._parse_id_var(any_token=False) + this = self._parse_function() or self._parse_id_var(any_token=False) + while self._match(TokenType.DOT): + this = exp.Dot( + this=this, + expression=self._parse_function() or self._parse_id_var(any_token=False), + ) columns = None - if self._match(TokenType.ALIAS): - columns = self._parse_csv(self._parse_id_var) - elif self._match(TokenType.L_PAREN): - columns = self._parse_csv(self._parse_id_var) - self._match_r_paren() + table_alias = None + if view or self.LATERAL_FUNCTION_AS_VIEW: + table_alias = self._parse_id_var(any_token=False) + if self._match(TokenType.ALIAS): + columns = self._parse_csv(self._parse_id_var) + else: + self._match(TokenType.ALIAS) + table_alias = self._parse_id_var(any_token=False) + + if self._match(TokenType.L_PAREN): + columns = self._parse_csv(self._parse_id_var) + self._match_r_paren() expression = self.expression( exp.Lateral, diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 8a264a2..c223ee0 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -237,12 +237,17 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): if table_: table_schema = self.find(table_, raise_on_missing=False) if table_schema: - schema_type = table_schema.get(column_name).upper() # type: ignore - return self._convert_type(schema_type) + column_type = table_schema.get(column_name) + + if isinstance(column_type, exp.DataType): + return column_type + elif isinstance(column_type, str): + return self._to_data_type(column_type.upper()) + raise SchemaError(f"Unknown column type '{column_type}'") return exp.DataType(this=exp.DataType.Type.UNKNOWN) raise SchemaError(f"Could not convert table '{table}'") - def _convert_type(self, schema_type: str) -> exp.DataType: + def _to_data_type(self, schema_type: str) -> exp.DataType: """ Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index bca5aaa..e3d0cff 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -496,7 +496,7 @@ FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f.value AS "Contact", f1.value['type'] AS "Type", f1.value['content'] AS "Details" -FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""", +FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERAL FLATTEN(input => f.value['business']) AS f1""", }, pretty=True, ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index afdd48a..e4c6e60 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -371,13 +371,19 @@ class TestTSQL(Validator): self.validate_all( "SELECT t.x, y.z FROM x CROSS APPLY tvfTest(t.x)y(z)", write={ - "spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) y AS z", + "spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) AS y(z)", }, ) self.validate_all( "SELECT t.x, y.z FROM x OUTER APPLY tvfTest(t.x)y(z)", write={ - "spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL TVFTEST(t.x) y AS z", + "spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL TVFTEST(t.x) AS y(z)", + }, + ) + self.validate_all( + "SELECT t.x, y.z FROM x OUTER APPLY a.b.tvfTest(t.x)y(z)", + write={ + "spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL a.b.TVFTEST(t.x) AS y(z)", }, ) @@ -421,3 +427,17 @@ class TestTSQL(Validator): self.validate_all( "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"} ) + + def test_string(self): + self.validate_all( + "SELECT N'test'", + write={"spark": "SELECT 'test'"}, + ) + self.validate_all( + "SELECT n'test'", + write={"spark": "SELECT 'test'"}, + ) + self.validate_all( + "SELECT '''test'''", + write={"spark": r"SELECT '\'test\''"}, + ) diff --git a/tests/test_build.py b/tests/test_build.py index 721c868..b014a3a 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -473,6 +473,12 @@ class TestBuild(unittest.TestCase): (lambda: exp.values([("1", 2)]), "VALUES ('1', 2)"), (lambda: exp.values([("1", 2)], "alias"), "(VALUES ('1', 2)) AS alias"), (lambda: exp.values([("1", 2), ("2", 3)]), "VALUES ('1', 2), ('2', 3)"), + ( + lambda: exp.values( + [("1", 2, None), ("2", 3, None)], "alias", ["col1", "col2", "col3"] + ), + "(VALUES ('1', 2, NULL), ('2', 3, NULL)) AS alias(col1, col2, col3)", + ), (lambda: exp.delete("y", where="x > 1"), "DELETE FROM y WHERE x > 1"), (lambda: exp.delete("y", where=exp.and_("x > 1")), "DELETE FROM y WHERE x > 1"), ]: diff --git a/tests/test_parser.py b/tests/test_parser.py index fa7b589..0be15e4 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -85,7 +85,7 @@ class TestParser(unittest.TestCase): self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1) self.assertEqual( parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(), - """SELECT * FROM x, z LATERAL VIEW EXPLODE(y) CROSS JOIN y""", + """SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""", ) def test_command(self): diff --git a/tests/test_schema.py b/tests/test_schema.py index f1e12a2..6c1ca9c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,6 +1,6 @@ import unittest -from sqlglot import exp, to_table +from sqlglot import exp, parse_one, to_table from sqlglot.errors import SchemaError from sqlglot.schema import MappingSchema, ensure_schema @@ -181,3 +181,6 @@ class TestSchema(unittest.TestCase): schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this, exp.DataType.Type.VARCHAR, ) + + schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}}) + self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT) |