sqlglot.dataframe.sql
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
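The module implements a PySpark-style DataFrame API on top of sqlglot expressions: each operation builds up a SELECT expression tree, and DataFrame.sql() renders it as one or more SQL strings for any sqlglot dialect. A minimal end-to-end sketch (the employee table and its schema are illustrative; the schema is registered so column references can be resolved):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register the (illustrative) table schema used below.
sqlglot.schema.add_table("employee", {"id": "INT", "fname": "STRING", "age": "INT"})

spark = SparkSession()
df = spark.table("employee").where(F.col("age") >= 18).select("fname", "age")

# Nothing is executed; sql() returns a list of SQL strings, here rendered for BigQuery.
print(df.sql(dialect="bigquery")[0])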
class SparkSession:
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        return self

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError(
                "Only a schema of type StructType, string, or list of strings is supported"
            )
        if not data:
            raise ValueError("Must provide data to create a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                expressions=[
                    exp.Values(
                        expressions=data_expressions,
                        alias=exp.TableAlias(
                            this=exp.to_identifier(self._auto_incrementing_name),
                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                        ),
                    ),
                ],
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_name(self) -> str:
        return "r" + uuid.uuid4().hex

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = self._random_name
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)
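A short sketch of the two public entry points above, createDataFrame and sql (the row values and the employee table are illustrative):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()

# With dict rows, column names are taken from the keys of the first row.
df = spark.createDataFrame([{"id": 1, "name": "Jack"}, {"id": 2, "name": "Jill"}])
print(df.sql(dialect="spark")[0])  # a SELECT over a VALUES clause

# sql() parses a Spark SQL string back into a DataFrame; rendering it later with
# the default optimize=True requires the table's schema to be registered.
sqlglot.schema.add_table("employee", {"id": "INT", "name": "STRING"})
df2 = spark.sql("SELECT id, name FROM employee")
print(df2.sql(dialect="spark")[0])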
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
        expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self.spark._random_name
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Select):
        value = expression.sql(dialect="spark").encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = optimize_func(select_expression, identify="always")
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql("spark")
                        for expression in select_expression.expressions
                    },
                )
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)
                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [
            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
        ]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If columns resolve to multiple CTE expressions, we use each CTE left-to-right
                # and therefore allow multiple columns with the same name in the result. This
                # matches the behavior of Spark.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before.
                    # If so, use it. Otherwise, use the same CTE we used before.
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine the actual "join on" expression later, so we don't provide it at first.
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and select columns to be used based on what type of columns
        # were provided for the join. The columns returned change based on how the on expression
        # is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of join on column names only:
            * The column names are put at the front of the select list
            * The column names are deduplicated across the entire select list, but only the join
              column names (other duplicates are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each of
            # the left side tables and seeing if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match Spark behavior, only the join columns get deduplicated, and they are put
            # at the front of the column list.
            select_column_names = [
                column.alias_or_name
                if not isinstance(column.expression.this, exp.Star)
                else column.sql()
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of join on expressions:
            * There is no deduplication of the results.
            * The left join dataframe columns go first and the right come after. No sort
              preference is given to join columns.
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before.
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in
        `ascending`. Spark has irregular behavior here that can result in runtime errors. Users
        shouldn't be mixing the two anyway, so this is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            x
            for x in [
                i if isinstance(col.expression, exp.Ordered) else None
                for i, col in enumerate(columns)
            ]
            if x is not None
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            exp.Ordered(this=col.expression, desc=not asc)
            if i not in pre_ordered_col_indexes
            else columns[i].column_expression
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
            if minimum_num_nulls > len(null_check_columns):
                raise RuntimeError(
                    f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                    f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
                )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality Difference: If you provide a replacement value for a null and its type
        conflicts with the type of the column, then PySpark will just ignore your replacement.
        This implementation tries to cast them to the same type in some cases, so the results
        won't always match. It's best not to mix types, so make sure the replacement value has
        the same type as the column.

        Possibility for improvement: Use the `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not, then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column.copy(), new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
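To illustrate how these operations compose, the sketch below joins two tables and renders the result as SQL; the table names and schemas are illustrative:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register the (illustrative) schemas so column references can be resolved.
sqlglot.schema.add_table("employee", {"id": "INT", "fname": "STRING", "store_id": "INT"})
sqlglot.schema.add_table("store", {"store_id": "INT", "store_name": "STRING"})

spark = SparkSession()
employee = spark.table("employee")
store = spark.table("store")

query = (
    employee.join(store, on=["store_id"], how="inner")
    .select("fname", "store_name")
    .orderBy(F.col("fname"))
    .limit(10)
)

# Each call above only builds up the expression tree; sql() renders it.
for statement in query.sql(dialect="spark"):
    print(statement)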
46 def __init__( 47 self, 48 spark: SparkSession, 49 expression: exp.Select, 50 branch_id: t.Optional[str] = None, 51 sequence_id: t.Optional[str] = None, 52 last_op: Operation = Operation.INIT, 53 pending_hints: t.Optional[t.List[exp.Expression]] = None, 54 output_expression_container: t.Optional[OutputExpressionContainer] = None, 55 **kwargs, 56 ): 57 self.spark = spark 58 self.expression = expression 59 self.branch_id = branch_id or self.spark._random_branch_id 60 self.sequence_id = sequence_id or self.spark._random_sequence_id 61 self.last_op = last_op 62 self.pending_hints = pending_hints or [] 63 self.output_expression_container = output_expression_container or exp.Select()
295 def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: 296 df = self._resolve_pending_hints() 297 select_expressions = df._get_select_expressions() 298 output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] 299 replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} 300 for expression_type, select_expression in select_expressions: 301 select_expression = select_expression.transform(replace_id_value, replacement_mapping) 302 if optimize: 303 select_expression = optimize_func(select_expression, identify="always") 304 select_expression = df._replace_cte_names_with_hashes(select_expression) 305 expression: t.Union[exp.Select, exp.Cache, exp.Drop] 306 if expression_type == exp.Cache: 307 cache_table_name = df._create_hash_from_expression(select_expression) 308 cache_table = exp.to_table(cache_table_name) 309 original_alias_name = select_expression.args["cte_alias_name"] 310 311 replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore 312 cache_table_name 313 ) 314 sqlglot.schema.add_table( 315 cache_table_name, 316 { 317 expression.alias_or_name: expression.type.sql("spark") 318 for expression in select_expression.expressions 319 }, 320 ) 321 cache_storage_level = select_expression.args["cache_storage_level"] 322 options = [ 323 exp.Literal.string("storageLevel"), 324 exp.Literal.string(cache_storage_level), 325 ] 326 expression = exp.Cache( 327 this=cache_table, expression=select_expression, lazy=True, options=options 328 ) 329 # We will drop the "view" if it exists before running the cache table 330 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) 331 elif expression_type == exp.Create: 332 expression = df.output_expression_container.copy() 333 expression.set("expression", select_expression) 334 elif expression_type == exp.Insert: 335 expression = df.output_expression_container.copy() 336 select_without_ctes = select_expression.copy() 337 select_without_ctes.set("with", None) 338 expression.set("expression", select_without_ctes) 339 if select_expression.ctes: 340 expression.set("with", exp.With(expressions=select_expression.ctes)) 341 elif expression_type == exp.Select: 342 expression = select_expression 343 else: 344 raise ValueError(f"Invalid expression type: {expression_type}") 345 output_expressions.append(expression) 346 347 return [ 348 expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions 349 ]
354 @operation(Operation.SELECT) 355 def select(self, *cols, **kwargs) -> DataFrame: 356 cols = self._ensure_and_normalize_cols(cols) 357 kwargs["append"] = kwargs.get("append", False) 358 if self.expression.args.get("joins"): 359 ambiguous_cols = [ 360 col 361 for col in cols 362 if isinstance(col.column_expression, exp.Column) and not col.column_expression.table 363 ] 364 if ambiguous_cols: 365 join_table_identifiers = [ 366 x.this for x in get_tables_from_expression_with_join(self.expression) 367 ] 368 cte_names_in_join = [x.this for x in join_table_identifiers] 369 # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right 370 # and therefore we allow multiple columns with the same name in the result. This matches the behavior 371 # of Spark. 372 resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols} 373 for ambiguous_col in ambiguous_cols: 374 ctes_with_column = [ 375 cte 376 for cte in self.expression.ctes 377 if cte.alias_or_name in cte_names_in_join 378 and ambiguous_col.alias_or_name in cte.this.named_selects 379 ] 380 # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise, 381 # use the same CTE we used before 382 cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1) 383 if cte: 384 resolved_column_position[ambiguous_col] += 1 385 else: 386 cte = ctes_with_column[resolved_column_position[ambiguous_col]] 387 ambiguous_col.expression.set("table", cte.alias_or_name) 388 return self.copy( 389 expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs 390 )
392 @operation(Operation.NO_OP) 393 def alias(self, name: str, **kwargs) -> DataFrame: 394 new_sequence_id = self.spark._random_sequence_id 395 df = self.copy() 396 for join_hint in df.pending_join_hints: 397 for expression in join_hint.expressions: 398 if expression.alias_or_name == self.sequence_id: 399 expression.set("this", Column.ensure_col(new_sequence_id).expression) 400 df.spark._add_alias_to_mapping(name, new_sequence_id) 401 return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
420 @operation(Operation.FROM) 421 def join( 422 self, 423 other_df: DataFrame, 424 on: t.Union[str, t.List[str], Column, t.List[Column]], 425 how: str = "inner", 426 **kwargs, 427 ) -> DataFrame: 428 other_df = other_df._convert_leaf_to_cte() 429 join_columns = self._ensure_list_of_columns(on) 430 # We will determine actual "join on" expression later so we don't provide it at first 431 join_expression = self.expression.join( 432 other_df.latest_cte_name, join_type=how.replace("_", " ") 433 ) 434 join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes) 435 self_columns = self._get_outer_select_columns(join_expression) 436 other_columns = self._get_outer_select_columns(other_df) 437 # Determines the join clause and select columns to be used passed on what type of columns were provided for 438 # the join. The columns returned changes based on how the on expression is provided. 439 if isinstance(join_columns[0].expression, exp.Column): 440 """ 441 Unique characteristics of join on column names only: 442 * The column names are put at the front of the select list 443 * The column names are deduplicated across the entire select list and only the column names (other dups are allowed) 444 """ 445 table_names = [ 446 table.alias_or_name 447 for table in get_tables_from_expression_with_join(join_expression) 448 ] 449 potential_ctes = [ 450 cte 451 for cte in join_expression.ctes 452 if cte.alias_or_name in table_names 453 and cte.alias_or_name != other_df.latest_cte_name 454 ] 455 # Determine the table to reference for the left side of the join by checking each of the left side 456 # tables and see if they have the column being referenced. 457 join_column_pairs = [] 458 for join_column in join_columns: 459 num_matching_ctes = 0 460 for cte in potential_ctes: 461 if join_column.alias_or_name in cte.this.named_selects: 462 left_column = join_column.copy().set_table_name(cte.alias_or_name) 463 right_column = join_column.copy().set_table_name(other_df.latest_cte_name) 464 join_column_pairs.append((left_column, right_column)) 465 num_matching_ctes += 1 466 if num_matching_ctes > 1: 467 raise ValueError( 468 f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name." 469 ) 470 elif num_matching_ctes == 0: 471 raise ValueError( 472 f"Column {join_column.alias_or_name} does not exist in any of the tables." 473 ) 474 join_clause = functools.reduce( 475 lambda x, y: x & y, 476 [left_column == right_column for left_column, right_column in join_column_pairs], 477 ) 478 join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] 479 # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list 480 select_column_names = [ 481 column.alias_or_name 482 if not isinstance(column.expression.this, exp.Star) 483 else column.sql() 484 for column in self_columns + other_columns 485 ] 486 select_column_names = [ 487 column_name 488 for column_name in select_column_names 489 if column_name not in join_column_names 490 ] 491 select_column_names = join_column_names + select_column_names 492 else: 493 """ 494 Unique characteristics of join on expressions: 495 * There is no deduplication of the results. 496 * The left join dataframe columns go first and right come after. 
No sort preference is given to join columns 497 """ 498 join_columns = self._ensure_and_normalize_cols(join_columns, join_expression) 499 if len(join_columns) > 1: 500 join_columns = [functools.reduce(lambda x, y: x & y, join_columns)] 501 join_clause = join_columns[0] 502 select_column_names = [column.alias_or_name for column in self_columns + other_columns] 503 504 # Update the on expression with the actual join clause to replace the dummy one from before 505 join_expression.args["joins"][-1].set("on", join_clause.expression) 506 new_df = self.copy(expression=join_expression) 507 new_df.pending_join_hints.extend(self.pending_join_hints) 508 new_df.pending_hints.extend(other_df.pending_hints) 509 new_df = new_df.select.__wrapped__(new_df, *select_column_names) 510 return new_df
512 @operation(Operation.ORDER_BY) 513 def orderBy( 514 self, 515 *cols: t.Union[str, Column], 516 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 517 ) -> DataFrame: 518 """ 519 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 520 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 521 is unlikely to come up. 522 """ 523 columns = self._ensure_and_normalize_cols(cols) 524 pre_ordered_col_indexes = [ 525 x 526 for x in [ 527 i if isinstance(col.expression, exp.Ordered) else None 528 for i, col in enumerate(columns) 529 ] 530 if x is not None 531 ] 532 if ascending is None: 533 ascending = [True] * len(columns) 534 elif not isinstance(ascending, list): 535 ascending = [ascending] * len(columns) 536 ascending = [bool(x) for i, x in enumerate(ascending)] 537 assert len(columns) == len( 538 ascending 539 ), "The length of items in ascending must equal the number of columns provided" 540 col_and_ascending = list(zip(columns, ascending)) 541 order_by_columns = [ 542 exp.Ordered(this=col.expression, desc=not asc) 543 if i not in pre_ordered_col_indexes 544 else columns[i].column_expression 545 for i, (col, asc) in enumerate(col_and_ascending) 546 ] 547 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in ascending
. Spark
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
is unlikely to come up.
512 @operation(Operation.ORDER_BY) 513 def orderBy( 514 self, 515 *cols: t.Union[str, Column], 516 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 517 ) -> DataFrame: 518 """ 519 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 520 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 521 is unlikely to come up. 522 """ 523 columns = self._ensure_and_normalize_cols(cols) 524 pre_ordered_col_indexes = [ 525 x 526 for x in [ 527 i if isinstance(col.expression, exp.Ordered) else None 528 for i, col in enumerate(columns) 529 ] 530 if x is not None 531 ] 532 if ascending is None: 533 ascending = [True] * len(columns) 534 elif not isinstance(ascending, list): 535 ascending = [ascending] * len(columns) 536 ascending = [bool(x) for i, x in enumerate(ascending)] 537 assert len(columns) == len( 538 ascending 539 ), "The length of items in ascending must equal the number of columns provided" 540 col_and_ascending = list(zip(columns, ascending)) 541 order_by_columns = [ 542 exp.Ordered(this=col.expression, desc=not asc) 543 if i not in pre_ordered_col_indexes 544 else columns[i].column_expression 545 for i, (col, asc) in enumerate(col_and_ascending) 546 ] 547 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in ascending
. Spark
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
is unlikely to come up.
557 @operation(Operation.FROM) 558 def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): 559 l_columns = self.columns 560 r_columns = other.columns 561 if not allowMissingColumns: 562 l_expressions = l_columns 563 r_expressions = l_columns 564 else: 565 l_expressions = [] 566 r_expressions = [] 567 r_columns_unused = copy(r_columns) 568 for l_column in l_columns: 569 l_expressions.append(l_column) 570 if l_column in r_columns: 571 r_expressions.append(l_column) 572 r_columns_unused.remove(l_column) 573 else: 574 r_expressions.append(exp.alias_(exp.Null(), l_column)) 575 for r_column in r_columns_unused: 576 l_expressions.append(exp.alias_(exp.Null(), r_column)) 577 r_expressions.append(r_column) 578 r_df = ( 579 other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) 580 ) 581 l_df = self.copy() 582 if allowMissingColumns: 583 l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) 584 return l_df._set_operation(exp.Union, r_df, False)
602 @operation(Operation.SELECT) 603 def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): 604 if not subset: 605 return self.distinct() 606 column_names = ensure_list(subset) 607 window = Window.partitionBy(*column_names).orderBy(*column_names) 608 return ( 609 self.copy() 610 .withColumn("row_num", F.row_number().over(window)) 611 .where(F.col("row_num") == F.lit(1)) 612 .drop("row_num") 613 )
615 @operation(Operation.FROM) 616 def dropna( 617 self, 618 how: str = "any", 619 thresh: t.Optional[int] = None, 620 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 621 ) -> DataFrame: 622 minimum_non_null = thresh or 0 # will be determined later if thresh is null 623 new_df = self.copy() 624 all_columns = self._get_outer_select_columns(new_df.expression) 625 if subset: 626 null_check_columns = self._ensure_and_normalize_cols(subset) 627 else: 628 null_check_columns = all_columns 629 if thresh is None: 630 minimum_num_nulls = 1 if how == "any" else len(null_check_columns) 631 else: 632 minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 633 if minimum_num_nulls > len(null_check_columns): 634 raise RuntimeError( 635 f"The minimum num nulls for dropna must be less than or equal to the number of columns. " 636 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" 637 ) 638 if_null_checks = [ 639 F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns 640 ] 641 nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) 642 num_nulls = nulls_added_together.alias("num_nulls") 643 new_df = new_df.select(num_nulls, append=True) 644 filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) 645 final_df = filtered_df.select(*all_columns) 646 return final_df
648 @operation(Operation.FROM) 649 def fillna( 650 self, 651 value: t.Union[ColumnLiterals], 652 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 653 ) -> DataFrame: 654 """ 655 Functionality Difference: If you provide a value to replace a null and that type conflicts 656 with the type of the column then PySpark will just ignore your replacement. 657 This will try to cast them to be the same in some cases. So they won't always match. 658 Best to not mix types so make sure replacement is the same type as the column 659 660 Possibility for improvement: Use `typeof` function to get the type of the column 661 and check if it matches the type of the value provided. If not then make it null. 662 """ 663 from sqlglot.dataframe.sql.functions import lit 664 665 values = None 666 columns = None 667 new_df = self.copy() 668 all_columns = self._get_outer_select_columns(new_df.expression) 669 all_column_mapping = {column.alias_or_name: column for column in all_columns} 670 if isinstance(value, dict): 671 values = list(value.values()) 672 columns = self._ensure_and_normalize_cols(list(value)) 673 if not columns: 674 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 675 if not values: 676 values = [value] * len(columns) 677 value_columns = [lit(value) for value in values] 678 679 null_replacement_mapping = { 680 column.alias_or_name: ( 681 F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) 682 ) 683 for column, value in zip(columns, value_columns) 684 } 685 null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} 686 null_replacement_columns = [ 687 null_replacement_mapping[column.alias_or_name] for column in all_columns 688 ] 689 new_df = new_df.select(*null_replacement_columns) 690 return new_df
Functionality Difference: If you provide a value to replace a null and that type conflicts with the type of the column then PySpark will just ignore your replacement. This will try to cast them to be the same in some cases. So they won't always match. Best to not mix types so make sure replacement is the same type as the column
Possibility for improvement: Use typeof
function to get the type of the column
and check if it matches the type of the value provided. If not then make it null.
692 @operation(Operation.FROM) 693 def replace( 694 self, 695 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 696 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 697 subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, 698 ) -> DataFrame: 699 from sqlglot.dataframe.sql.functions import lit 700 701 old_values = None 702 new_df = self.copy() 703 all_columns = self._get_outer_select_columns(new_df.expression) 704 all_column_mapping = {column.alias_or_name: column for column in all_columns} 705 706 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 707 if isinstance(to_replace, dict): 708 old_values = list(to_replace) 709 new_values = list(to_replace.values()) 710 elif not old_values and isinstance(to_replace, list): 711 assert isinstance(value, list), "value must be a list since the replacements are a list" 712 assert len(to_replace) == len( 713 value 714 ), "the replacements and values must be the same length" 715 old_values = to_replace 716 new_values = value 717 else: 718 old_values = [to_replace] * len(columns) 719 new_values = [value] * len(columns) 720 old_values = [lit(value) for value in old_values] 721 new_values = [lit(value) for value in new_values] 722 723 replacement_mapping = {} 724 for column in columns: 725 expression = Column(None) 726 for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): 727 if i == 0: 728 expression = F.when(column == old_value, new_value) 729 else: 730 expression = expression.when(column == old_value, new_value) # type: ignore 731 replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( 732 column.expression.alias_or_name 733 ) 734 735 replacement_mapping = {**all_column_mapping, **replacement_mapping} 736 replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] 737 new_df = new_df.select(*replacement_columns) 738 return new_df
@operation(Operation.SELECT)
def withColumn(self, colName: str, col: Column) -> DataFrame:
    col = self._ensure_and_normalize_col(col)
    existing_col_names = self.expression.named_selects
    existing_col_index = (
        existing_col_names.index(colName) if colName in existing_col_names else None
    )
    # Note: this must compare against None rather than rely on truthiness,
    # otherwise a match at index 0 would incorrectly fall through to append.
    if existing_col_index is not None:
        expression = self.expression.copy()
        expression.expressions[existing_col_index] = col.expression
        return self.copy(expression=expression)
    return self.copy().select(col.alias(colName), append=True)
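A minimal sketch showing both branches, with invented names: a new name appends a projection, while an existing name replaces it in place:

from sqlglot.dataframe.sql import SparkSession, functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 2)], ["a", "b"])

with_new = df.withColumn("c", F.col("a") + F.col("b"))  # appends `c`
replaced = with_new.withColumn("b", F.lit(0))           # replaces existing `b`
print(replaced.sql()[0])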
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str) -> DataFrame:
    expression = self.expression.copy()
    existing_columns = [
        select_expression
        for select_expression in expression.expressions
        if select_expression.alias_or_name == existing
    ]
    if not existing_columns:
        raise ValueError("Tried to rename a column that doesn't exist")
    for existing_column in existing_columns:
        if isinstance(existing_column, exp.Column):
            existing_column.replace(exp.alias_(existing_column.copy(), new))
        else:
            existing_column.set("alias", exp.to_identifier(new))
    return self.copy(expression=expression)
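For illustration, renaming simply re-aliases the matching projection (a sketch with invented names):

from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, 2)], ["a", "b"])

# `b` is re-aliased in the outer SELECT; renaming a missing column raises ValueError
print(df.withColumnRenamed("b", "b_new").sql()[0])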
@operation(Operation.SELECT)
def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
    all_columns = self._get_outer_select_columns(self.expression)
    drop_cols = self._ensure_and_normalize_cols(cols)
    drop_col_names = {drop_col.alias_or_name for drop_col in drop_cols}
    new_columns = [col for col in all_columns if col.alias_or_name not in drop_col_names]
    return self.copy().select(*new_columns, append=False)
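Dropping is implemented as a re-select of the surviving columns; a sketch with invented names:

from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, 2, 3)], ["a", "b", "c"])

# Re-selects only `a`, since `b` and `c` are filtered out of the projection list
print(df.drop("b", "c").sql()[0])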
@operation(Operation.NO_OP)
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
    parameter_list = ensure_list(parameters)
    parameter_columns = (
        self._ensure_list_of_columns(parameter_list)
        if parameters
        else Column.ensure_cols([self.sequence_id])
    )
    return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition(
    self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
) -> DataFrame:
    num_partition_cols = self._ensure_list_of_columns(numPartitions)
    columns = self._ensure_and_normalize_cols(cols)
    args = num_partition_cols + columns
    return self._hint("repartition", args)
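Since no cluster is involved, repartition only survives as a query hint in the generated SQL. A sketch (the rendered hint comment may vary by sqlglot version):

from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, 2)], ["a", "b"])

# Expected to render a Spark hint comment, e.g. SELECT /*+ REPARTITION(8, a) */ ...
print(df.repartition(8, "a").sql()[0])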
@operation(Operation.NO_OP)
def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
    """
    Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
    """
    return self._cache(storageLevel)
Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
class GroupedData:
    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
        self._df = df.copy()
        self.spark = df.spark
        self.last_op = last_op
        self.group_by_cols = group_by_cols

    def _get_function_applied_columns(
        self, func_name: str, cols: t.Tuple[str, ...]
    ) -> t.List[Column]:
        func_name = func_name.lower()
        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]

    @operation(Operation.SELECT)
    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
        columns = (
            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
            if isinstance(exprs[0], dict)
            else exprs
        )
        cols = self._df._ensure_and_normalize_cols(columns)

        expression = self._df.expression.group_by(
            *[x.expression for x in self.group_by_cols]
        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
        return self._df.copy(expression=expression)

    def count(self) -> DataFrame:
        return self.agg(F.count("*").alias("count"))

    def mean(self, *cols: str) -> DataFrame:
        return self.avg(*cols)

    def avg(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("avg", cols))

    def max(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("max", cols))

    def min(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("min", cols))

    def sum(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("sum", cols))

    def pivot(self, *cols: str) -> DataFrame:
        raise NotImplementedError("Pivot is not currently implemented")
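A sketch of the two agg calling conventions, assuming the usual df.groupBy entry point (data and names are invented):

from sqlglot.dataframe.sql import SparkSession, functions as F

spark = SparkSession()
df = spark.createDataFrame([(30, 1), (30, 2), (40, 3)], ["age", "employee_id"])

# Column form
agg_df = df.groupBy(F.col("age")).agg(
    F.countDistinct(F.col("employee_id")).alias("num_employees")
)
print(agg_df.sql()[0])

# Dict form: {column_name: aggregate_function_name}
print(df.groupBy(F.col("age")).agg({"employee_id": "count"}).sql()[0])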
class Column:
    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
        if isinstance(expression, Column):
            expression = expression.expression  # type: ignore
        elif expression is None or not isinstance(expression, (str, exp.Expression)):
            expression = self._lit(expression).expression  # type: ignore

        expression = sqlglot.maybe_parse(expression, dialect="spark")
        if expression is None:
            raise ValueError(f"Could not parse {expression}")
        self.expression: exp.Expression = expression

    def __repr__(self):
        return repr(self.expression)

    def __hash__(self):
        return hash(self.expression)

    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.EQ, other)

    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.NEQ, other)

    def __gt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GT, other)

    def __ge__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GTE, other)

    def __lt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LT, other)

    def __le__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LTE, other)

    def __and__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.And, other)

    def __or__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Or, other)

    def __mod__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mod, other)

    def __add__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Add, other)

    def __sub__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Sub, other)

    def __mul__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mul, other)

    def __truediv__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __div__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __neg__(self) -> Column:
        return self.unary_op(exp.Neg)

    def __radd__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Add, other)

    def __rsub__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Sub, other)

    def __rmul__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mul, other)

    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rmod__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mod, other)

    def __pow__(self, power: ColumnOrLiteral, modulo=None):
        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))

    def __rpow__(self, power: ColumnOrLiteral):
        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))

    def __invert__(self):
        return self.unary_op(exp.Not)

    def __rand__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.And, other)

    def __ror__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Or, other)

    @classmethod
    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
        return cls(value)

    @classmethod
    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]

    @classmethod
    def _lit(cls, value: ColumnOrLiteral) -> Column:
        if isinstance(value, dict):
            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
            return cls(exp.Struct(expressions=columns))
        return cls(exp.convert(value))

    @classmethod
    def invoke_anonymous_function(
        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
    ) -> Column:
        columns = [] if column is None else [cls.ensure_col(column)]
        column_args = [cls.ensure_col(arg) for arg in args]
        expressions = [x.expression for x in columns + column_args]
        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
        return Column(new_expression)

    @classmethod
    def invoke_expression_over_column(
        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
    ) -> Column:
        ensured_column = None if column is None else cls.ensure_col(column)
        ensure_expression_values = {
            k: [Column.ensure_col(x).expression for x in v]
            if is_iterable(v)
            else Column.ensure_col(v).expression
            for k, v in kwargs.items()
            if v is not None
        }
        new_expression = (
            callable_expression(**ensure_expression_values)
            if ensured_column is None
            else callable_expression(
                this=ensured_column.column_expression, **ensure_expression_values
            )
        )
        return Column(new_expression)

    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
        )

    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
        )

    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, **kwargs))

    @property
    def is_alias(self):
        return isinstance(self.expression, exp.Alias)

    @property
    def is_column(self):
        return isinstance(self.expression, exp.Column)

    @property
    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
        return self.expression.unalias()

    @property
    def alias_or_name(self) -> str:
        return self.expression.alias_or_name

    @classmethod
    def ensure_literal(cls, value) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        if isinstance(value, cls):
            value = value.expression
        if not isinstance(value, exp.Literal):
            return lit(value)
        return Column(value)

    def copy(self) -> Column:
        return Column(self.expression.copy())

    def set_table_name(self, table_name: str, copy=False) -> Column:
        expression = self.expression.copy() if copy else self.expression
        expression.set("table", exp.to_identifier(table_name))
        return Column(expression)

    def sql(self, **kwargs) -> str:
        return self.expression.sql(**{"dialect": "spark", **kwargs})

    def alias(self, name: str) -> Column:
        new_expression = exp.alias_(self.column_expression, name)
        return Column(new_expression)

    def asc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
        return Column(new_expression)

    def desc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
        return Column(new_expression)

    asc_nulls_first = asc

    def asc_nulls_last(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
        return Column(new_expression)

    def desc_nulls_first(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
        return Column(new_expression)

    desc_nulls_last = desc

    def when(self, condition: Column, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import when

        column_with_if = when(condition, value)
        if not isinstance(self.expression, exp.Case):
            return column_with_if
        new_column = self.copy()
        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
        return new_column

    def otherwise(self, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        true_value = value if isinstance(value, Column) else lit(value)
        new_column = self.copy()
        new_column.expression.set("default", true_value.column_expression)
        return new_column

    def isNull(self) -> Column:
        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
        return Column(new_expression)

    def isNotNull(self) -> Column:
        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
        return Column(new_expression)

    def cast(self, dataType: t.Union[str, DataType]):
        """
        Functionality difference: PySpark's cast accepts an instance of the DataType class;
        sqlglot doesn't currently replicate that class hierarchy, so only strings (and the
        DataType wrapper handled here) are accepted.
        """
        if isinstance(dataType, DataType):
            dataType = dataType.simpleString()
        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))

    def startswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "STARTSWITH", value)

    def endswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "ENDSWITH", value)

    def rlike(self, regexp: str) -> Column:
        return self.invoke_expression_over_column(
            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
        )

    def like(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.Like, expression=self._lit(other).expression
        )

    def ilike(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.ILike, expression=self._lit(other).expression
        )

    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
        length = self._lit(length) if not isinstance(length, Column) else length
        return Column.invoke_expression_over_column(
            self, exp.Substring, start=startPos.expression, length=length.expression
        )

    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [self._lit(x).expression for x in columns]
        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore

    def between(
        self,
        lowerBound: ColumnOrLiteral,
        upperBound: ColumnOrLiteral,
    ) -> Column:
        lower_bound_exp = (
            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
        )
        upper_bound_exp = (
            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
        )
        return Column(
            exp.Between(
                this=self.column_expression,
                low=lower_bound_exp.expression,
                high=upper_bound_exp.expression,
            )
        )

    def over(self, window: WindowSpec) -> Column:
        window_expression = window.expression.copy()
        window_expression.set("this", self.column_expression)
        return Column(window_expression)
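Because every operator simply wraps a sqlglot expression, a Column can be rendered to Spark SQL directly. A small sketch (column name invented; comments show the expected output shape):

from sqlglot.dataframe.sql import functions as F

print((F.col("a") + 1).alias("a_plus_one").sql())  # a + 1 AS a_plus_one
print(F.col("a").isNull().sql())                   # a IS NULL
print(F.col("a").between(1, 10).sql())             # a BETWEEN 1 AND 10
print(F.col("a").isin(1, 2, 3).sql())              # a IN (1, 2, 3)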
class DataFrameNaFunctions:
    def __init__(self, df: DataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
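These wrappers exist purely for PySpark parity; assuming the usual df.na property, each call forwards to the corresponding DataFrame method. A sketch:

from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, None)], ["a", "b"])

df.na.fill(0, subset=["a"])          # same as df.fillna(0, subset=["a"])
df.na.drop(how="any")                # same as df.dropna(how="any")
df.na.replace(0, -1, subset=["a"])   # same as df.replace(0, -1, subset=["a"])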
class Window:
    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    unboundedPreceding: int = _JAVA_MIN_LONG

    unboundedFollowing: int = _JAVA_MAX_LONG

    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rangeBetween(start, end)
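The classmethods are conveniences that start from an empty WindowSpec. A sketch of typical use (dept and salary are invented names; the comment shows the expected output shape):

from sqlglot.dataframe.sql import Window, functions as F

spec = Window.partitionBy("dept").orderBy("salary")

# e.g. SUM(salary) OVER (PARTITION BY dept ORDER BY salary)
print(F.sum("salary").over(spec).sql())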
class WindowSpec:
    def __init__(self, expression: exp.Expression = exp.Window()):
        self.expression = expression

    def copy(self):
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        return self.expression.sql(dialect="spark", **kwargs)

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
        partition_by_expressions.extend(expressions)
        window_spec.expression.set("partition_by", partition_by_expressions)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order_by = window_spec.expression.args["order"].expressions
        order_by.extend(expressions)
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
            "start_side": None,
            "end_side": None,
        }
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": "UNBOUNDED"
                    if start <= Window.unboundedPreceding
                    else F.lit(start).expression,
                },
            }
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": "UNBOUNDED"
                    if end >= Window.unboundedFollowing
                    else F.lit(end).expression,
                },
            }
        return kwargs

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "ROWS"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "RANGE"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec
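rowsBetween and rangeBetween differ only in the frame kind they stamp onto the spec. A sketch using the Window sentinels (names invented; the comment shows the expected output shape):

from sqlglot.dataframe.sql import Window, functions as F

frame = (
    Window.partitionBy("dept")
    .orderBy("salary")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

# e.g. SUM(salary) OVER (PARTITION BY dept ORDER BY salary
#      ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
print(F.sum("salary").over(frame).sql())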
class DataFrameReader:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        sqlglot.schema.add_table(tableName)

        return DataFrame(
            self.spark,
            exp.Select()
            .from_(tableName)
            .select(
                *(
                    column if should_identify(column, "safe") else f'"{column}"'
                    for column in sqlglot.schema.column_names(tableName)
                )
            ),
        )
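Registering column types up front lets the generated SELECT enumerate concrete column names rather than fall back to *. A sketch (table and columns invented):

import sqlglot
from sqlglot.dataframe.sql import SparkSession

sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING"})

spark = SparkSession()
df = spark.read.table("employee")  # spark.table("employee") is equivalent

# A SELECT over the registered columns of `employee`
print(df.sql()[0])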
class DataFrameWriter:
    def __init__(
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        return DataFrameWriter(
            **{
                k[1:] if k.startswith("_") else k: v
                for k, v in object_to_dict(self, **kwargs).items()
            }
        )

    def sql(self, **kwargs) -> t.List[str]:
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        output_expression_container = exp.Insert(
            **{
                "this": exp.to_table(tableName),
                "overwrite": overwrite,
            }
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            columns = sqlglot.schema.column_names(tableName, only_visible=True)
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        if format is not None:
            raise NotImplementedError("Providing Format in the save as table is not supported")
        exists, replace, mode = None, None, mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)
        if mode == "ignore":
            exists = True
        if mode == "overwrite":
            replace = True
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
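A sketch of the writer round trip, assuming the usual df.write property (table names are invented): append mode routes through insertInto, while overwrite produces a CREATE OR REPLACE TABLE statement.

import sqlglot
from sqlglot.dataframe.sql import SparkSession

sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING"})

spark = SparkSession()
df = spark.table("employee")

# e.g. CREATE OR REPLACE TABLE employee_copy AS SELECT ...
print(df.write.mode("overwrite").saveAsTable("employee_copy").sql()[0])

# e.g. INSERT INTO employee_archive SELECT ...
print(df.write.mode("append").saveAsTable("employee_archive").sql()[0])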