diff options
Diffstat (limited to 'sqlglot/dialects/clickhouse.py')
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 279 |
1 file changed, 201 insertions, 78 deletions
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 2a49066..c8a9525 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -3,11 +3,16 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql +from sqlglot.dialects.dialect import ( + Dialect, + inline_array_sql, + no_pivot_sql, + rename_func, + var_map_sql, +) from sqlglot.errors import ParseError -from sqlglot.helper import ensure_list, seq_get from sqlglot.parser import parse_var_map -from sqlglot.tokens import TokenType +from sqlglot.tokens import Token, TokenType def _lower_func(sql: str) -> str: @@ -28,65 +33,122 @@ class ClickHouse(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ASOF": TokenType.ASOF, - "GLOBAL": TokenType.GLOBAL, - "DATETIME64": TokenType.DATETIME, + "ATTACH": TokenType.COMMAND, + "DATETIME64": TokenType.DATETIME64, "FINAL": TokenType.FINAL, "FLOAT32": TokenType.FLOAT, "FLOAT64": TokenType.DOUBLE, - "INT8": TokenType.TINYINT, - "UINT8": TokenType.UTINYINT, + "GLOBAL": TokenType.GLOBAL, + "INT128": TokenType.INT128, "INT16": TokenType.SMALLINT, - "UINT16": TokenType.USMALLINT, + "INT256": TokenType.INT256, "INT32": TokenType.INT, - "UINT32": TokenType.UINT, "INT64": TokenType.BIGINT, - "UINT64": TokenType.UBIGINT, - "INT128": TokenType.INT128, + "INT8": TokenType.TINYINT, + "MAP": TokenType.MAP, + "TUPLE": TokenType.STRUCT, "UINT128": TokenType.UINT128, - "INT256": TokenType.INT256, + "UINT16": TokenType.USMALLINT, "UINT256": TokenType.UINT256, - "TUPLE": TokenType.STRUCT, + "UINT32": TokenType.UINT, + "UINT64": TokenType.UBIGINT, + "UINT8": TokenType.UTINYINT, } class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore - "EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg( - this=seq_get(args, 0), - time=seq_get(args, 1), - decay=seq_get(params, 
0), - ), - "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray( - this=seq_get(args, 0), size=seq_get(params, 0) - ), - "HISTOGRAM": lambda params, args: exp.Histogram( - this=seq_get(args, 0), bins=seq_get(params, 0) - ), + **parser.Parser.FUNCTIONS, + "ANY": exp.AnyValue.from_arg_list, "MAP": parse_var_map, "MATCH": exp.RegexpLike.from_arg_list, - "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params), - "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args), - "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args), + "UNIQ": exp.ApproxDistinct.from_arg_list, + } + + FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"} + + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, + "QUANTILE": lambda self: self._parse_quantile(), } - FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() FUNCTION_PARSERS.pop("MATCH") + NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy() + NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY) + RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN) and self._parse_in(this, is_global=True), } - JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore + # The PLACEHOLDER entry is popped because 1) it doesn't affect Clickhouse (it corresponds to + # the postgres-specific JSONBContains parser) and 2) it makes parsing the ternary op simpler. 
+ COLUMN_OPERATORS = parser.Parser.COLUMN_OPERATORS.copy() + COLUMN_OPERATORS.pop(TokenType.PLACEHOLDER) + + JOIN_KINDS = { + *parser.Parser.JOIN_KINDS, + TokenType.ANY, + TokenType.ASOF, + TokenType.ANTI, + TokenType.SEMI, + } - TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore + TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - { + TokenType.ANY, + TokenType.ASOF, + TokenType.SEMI, + TokenType.ANTI, + TokenType.SETTINGS, + TokenType.FORMAT, + } LOG_DEFAULTS_TO_LN = True - def _parse_in( - self, this: t.Optional[exp.Expression], is_global: bool = False - ) -> exp.Expression: + QUERY_MODIFIER_PARSERS = { + **parser.Parser.QUERY_MODIFIER_PARSERS, + "settings": lambda self: self._parse_csv(self._parse_conjunction) + if self._match(TokenType.SETTINGS) + else None, + "format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None, + } + + def _parse_conjunction(self) -> t.Optional[exp.Expression]: + this = super()._parse_conjunction() + + if self._match(TokenType.PLACEHOLDER): + return self.expression( + exp.If, + this=this, + true=self._parse_conjunction(), + false=self._match(TokenType.COLON) and self._parse_conjunction(), + ) + + return this + + def _parse_placeholder(self) -> t.Optional[exp.Expression]: + """ + Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier} + https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters + """ + if not self._match(TokenType.L_BRACE): + return None + + this = self._parse_id_var() + self._match(TokenType.COLON) + kind = self._parse_types(check_func=False) or ( + self._match_text_seq("IDENTIFIER") and "Identifier" + ) + + if not kind: + self.raise_error("Expecting a placeholder type or 'Identifier' for tables") + elif not self._match(TokenType.R_BRACE): + self.raise_error("Expecting }") + + return self.expression(exp.Placeholder, this=this, kind=kind) + + def _parse_in(self, this: 
t.Optional[exp.Expression], is_global: bool = False) -> exp.In: this = super()._parse_in(this) this.set("is_global", is_global) return this @@ -120,81 +182,142 @@ class ClickHouse(Dialect): return self.expression(exp.CTE, this=statement, alias=statement and statement.this) + def _parse_join_side_and_kind( + self, + ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: + is_global = self._match(TokenType.GLOBAL) and self._prev + kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev + if kind_pre: + kind = self._match_set(self.JOIN_KINDS) and self._prev + side = self._match_set(self.JOIN_SIDES) and self._prev + return is_global, side, kind + return ( + is_global, + self._match_set(self.JOIN_SIDES) and self._prev, + self._match_set(self.JOIN_KINDS) and self._prev, + ) + + def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: + join = super()._parse_join(skip_join_token) + + if join: + join.set("global", join.args.pop("natural", None)) + return join + + def _parse_function( + self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False + ) -> t.Optional[exp.Expression]: + func = super()._parse_function(functions, anonymous) + + if isinstance(func, exp.Anonymous): + params = self._parse_func_params(func) + + if params: + return self.expression( + exp.ParameterizedAgg, + this=func.this, + expressions=func.expressions, + params=params, + ) + + return func + + def _parse_func_params( + self, this: t.Optional[exp.Func] = None + ) -> t.Optional[t.List[t.Optional[exp.Expression]]]: + if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): + return self._parse_csv(self._parse_lambda) + if self._match(TokenType.L_PAREN): + params = self._parse_csv(self._parse_lambda) + self._match_r_paren(this) + return params + return None + + def _parse_quantile(self) -> exp.Quantile: + this = self._parse_lambda() + params = self._parse_func_params() + if params: + return 
self.expression(exp.Quantile, this=params[0], quantile=this) + return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5)) + + def _parse_wrapped_id_vars( + self, optional: bool = False + ) -> t.List[t.Optional[exp.Expression]]: + return super()._parse_wrapped_id_vars(optional=True) + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore - exp.DataType.Type.NULLABLE: "Nullable", - exp.DataType.Type.DATETIME: "DateTime64", - exp.DataType.Type.MAP: "Map", + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.ARRAY: "Array", + exp.DataType.Type.BIGINT: "Int64", + exp.DataType.Type.DATETIME64: "DateTime64", + exp.DataType.Type.DOUBLE: "Float64", + exp.DataType.Type.FLOAT: "Float32", + exp.DataType.Type.INT: "Int32", + exp.DataType.Type.INT128: "Int128", + exp.DataType.Type.INT256: "Int256", + exp.DataType.Type.MAP: "Map", + exp.DataType.Type.NULLABLE: "Nullable", + exp.DataType.Type.SMALLINT: "Int16", exp.DataType.Type.STRUCT: "Tuple", exp.DataType.Type.TINYINT: "Int8", - exp.DataType.Type.UTINYINT: "UInt8", - exp.DataType.Type.SMALLINT: "Int16", - exp.DataType.Type.USMALLINT: "UInt16", - exp.DataType.Type.INT: "Int32", - exp.DataType.Type.UINT: "UInt32", - exp.DataType.Type.BIGINT: "Int64", exp.DataType.Type.UBIGINT: "UInt64", - exp.DataType.Type.INT128: "Int128", + exp.DataType.Type.UINT: "UInt32", exp.DataType.Type.UINT128: "UInt128", - exp.DataType.Type.INT256: "Int256", exp.DataType.Type.UINT256: "UInt256", - exp.DataType.Type.FLOAT: "Float32", - exp.DataType.Type.DOUBLE: "Float64", + exp.DataType.Type.USMALLINT: "UInt16", + exp.DataType.Type.UTINYINT: "UInt8", } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, + exp.AnyValue: rename_func("any"), + exp.ApproxDistinct: rename_func("uniq"), exp.Array: inline_array_sql, - exp.ExponentialTimeDecayedAvg: lambda self, e: 
f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}", + exp.CastToStrType: rename_func("CAST"), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", - exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}", - exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), - exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}", - exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}", - exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.Pivot: no_pivot_sql, + exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile")) + + f"({self.sql(e, 'this')})", exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, } JOIN_HINTS = False TABLE_HINTS = False EXPLICIT_UNION = True - - def _param_args_sql( - self, - expression: exp.Expression, - param_names: str | t.List[str], - arg_names: str | t.List[str], - ) -> str: - params = self.format_args( - *( - arg - for name in ensure_list(param_names) - for arg in ensure_list(expression.args.get(name)) - ) - ) - args = self.format_args( - *( - arg - for name in ensure_list(arg_names) - for arg in ensure_list(expression.args.get(name)) - ) - ) - return f"({params})({args})" + GROUPINGS_SEP = "" 
def cte_sql(self, expression: exp.CTE) -> str: if isinstance(expression.this, exp.Alias): return self.sql(expression, "this") return super().cte_sql(expression) + + def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: + return super().after_limit_modifiers(expression) + [ + self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True) + if expression.args.get("settings") + else "", + self.seg("FORMAT ") + self.sql(expression, "format") + if expression.args.get("format") + else "", + ] + + def parameterizedagg_sql(self, expression: exp.Anonymous) -> str: + params = self.expressions(expression, "params", flat=True) + return self.func(expression.name, *expression.expressions) + f"({params})" + + def placeholder_sql(self, expression: exp.Placeholder) -> str: + return f"{{{expression.name}: {self.sql(expression, 'kind')}}}" |