sqlglot.dataframe.sql
# Public surface of the sqlglot.dataframe.sql package: re-export the
# PySpark-compatible classes so callers can import everything from one place.
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

# Names exported by `from sqlglot.dataframe.sql import *`.
__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
class SparkSession:
    """SQL-generating stand-in for ``pyspark.sql.SparkSession``.

    Rather than running Spark jobs, this session builds sqlglot expression
    trees that are later rendered to SQL strings. Identifier bookkeeping
    (branch ids, sequence ids, alias mappings) lives on the class, so it is
    shared across every session instance.
    """

    # NOTE: class-level mutable registries — shared across ALL instances.
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        # Counter backing `_auto_incrementing_name` (per-instance, starts at 1).
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        # Catch-all: unknown attribute access returns the session itself —
        # presumably so PySpark-style builder chains (`.builder.appName(...)`)
        # are harmless no-ops. NOTE(review): this also silences typos; confirm.
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        # Same trick for calls, e.g. `spark.builder.getOrCreate()` -> session.
        return self

    @property
    def read(self) -> DataFrameReader:
        """Reader used to create DataFrames from tables."""
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        """Return a DataFrame representing the given table."""
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        """Build a DataFrame over an inline VALUES expression.

        `data` is a sequence of rows (dicts, lists, or tuples). `schema` may
        be a StructType, a schema string, or a non-empty list of column-name
        strings. `samplingRatio`/`verifySchema` are accepted for PySpark
        signature compatibility but not supported.

        Raises:
            NotImplementedError: unsupported options or schema form.
            ValueError: when `data` is empty.
        """
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        # Fix: an empty list schema previously crashed with IndexError on
        # `schema[0]`; it now raises the intended NotImplementedError.
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and (not schema or not isinstance(schema[0], str)))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        # Map column name -> optional type string. With no schema, dict rows
        # use the first row's keys; positional rows get _1, _2, ...
        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        # One tuple of literals per input row.
        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        # Cast each column to its declared type when the schema provided one.
        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                expressions=[
                    exp.Values(
                        expressions=data_expressions,
                        alias=exp.TableAlias(
                            this=exp.to_identifier(self._auto_incrementing_name),
                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                        ),
                    ),
                ],
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        """Parse a Spark-dialect SQL string into a DataFrame.

        SELECTs become plain DataFrames; CREATE/INSERT statements keep the
        DDL/DML container as the output expression and wrap their inner
        SELECT in the DataFrame.
        """
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                # Move the WITH clause onto the inner select before detaching it.
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        """Return 'a1', 'a2', ... — deterministic aliases for VALUES tables."""
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_name(self) -> str:
        # 'r' prefix keeps the identifier starting with a letter.
        return "r" + uuid.uuid4().hex

    @property
    def _random_branch_id(self) -> str:
        # Renamed local (was `id`) to avoid shadowing the builtin.
        new_id = self._random_id
        self.known_branch_ids.add(new_id)
        return new_id

    @property
    def _random_sequence_id(self) -> str:
        new_id = self._random_id
        self.known_sequence_ids.add(new_id)
        return new_id

    @property
    def _random_id(self) -> str:
        new_id = self._random_name
        self.known_ids.add(new_id)
        return new_id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        """Join hints that map to exp.JoinHint rather than exp.Anonymous."""
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        """Record that `name` (from DataFrame.alias) refers to `sequence_id`."""
        self.name_to_sequence_id_mapping[name].append(sequence_id)
def createDataFrame(
    self,
    data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
    schema: t.Optional[SchemaInput] = None,
    samplingRatio: t.Optional[float] = None,
    verifySchema: bool = False,
) -> DataFrame:
    """Build a DataFrame over an inline VALUES expression.

    `data` is a sequence of rows (dicts, lists, or tuples). `schema` may be a
    StructType, a schema string, or a non-empty list of column-name strings.
    `samplingRatio`/`verifySchema` exist only for PySpark signature
    compatibility and are not supported.

    Raises:
        NotImplementedError: unsupported options or schema form.
        ValueError: when `data` is empty.
    """
    from sqlglot.dataframe.sql.dataframe import DataFrame

    if samplingRatio is not None or verifySchema:
        raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
    # Fix: `schema=[]` previously crashed with IndexError on `schema[0]`;
    # an empty list now raises the intended NotImplementedError instead.
    if schema is not None and (
        not isinstance(schema, (StructType, str, list))
        or (isinstance(schema, list) and (not schema or not isinstance(schema[0], str)))
    ):
        raise NotImplementedError("Only schema of either list or string of list supported")
    if not data:
        raise ValueError("Must provide data to create into a DataFrame")

    # Map column name -> optional type string. With no schema, dict rows use
    # the first row's keys; positional rows get names _1, _2, ...
    column_mapping: t.Dict[str, t.Optional[str]]
    if schema is not None:
        column_mapping = get_column_mapping_from_schema_input(schema)
    elif isinstance(data[0], dict):
        column_mapping = {col_name.strip(): None for col_name in data[0]}
    else:
        column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

    # One tuple of literal expressions per input row.
    data_expressions = [
        exp.Tuple(
            expressions=list(
                map(
                    lambda x: F.lit(x).expression,
                    row if not isinstance(row, dict) else row.values(),
                )
            )
        )
        for row in data
    ]

    # Cast each column to its declared type when the schema provided one.
    sel_columns = [
        F.col(name).cast(data_type).alias(name).expression
        if data_type is not None
        else F.col(name).expression
        for name, data_type in column_mapping.items()
    ]

    select_kwargs = {
        "expressions": sel_columns,
        "from": exp.From(
            expressions=[
                exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ],
        ),
    }

    sel_expression = exp.Select(**select_kwargs)
    return DataFrame(self, sel_expression)
def sql(self, sqlQuery: str) -> DataFrame:
    """Parse a Spark-dialect SQL string and wrap it in a DataFrame.

    A bare SELECT becomes the DataFrame's expression directly. For CREATE
    and INSERT statements, the inner SELECT is lifted out and the DDL/DML
    wrapper is kept as the output expression container.
    """
    parsed = sqlglot.parse_one(sqlQuery, read="spark")

    if isinstance(parsed, exp.Select):
        return DataFrame(self, parsed)._convert_leaf_to_cte()

    if isinstance(parsed, (exp.Create, exp.Insert)):
        inner_select = parsed.expression.copy()
        if isinstance(parsed, exp.Insert):
            # Move any WITH clause from the statement onto the inner select.
            inner_select.set("with", parsed.args.get("with"))
            parsed.set("with", None)
        del parsed.args["expression"]
        wrapped = DataFrame(self, inner_select, output_expression_container=parsed)  # type: ignore
        return wrapped._convert_leaf_to_cte()

    raise ValueError(
        "Unknown expression type provided in the SQL. Please create an issue with the SQL."
    )
class DataFrame:
    """PySpark-style DataFrame that builds a sqlglot ``exp.Select`` tree.

    Every transformation returns a new DataFrame wrapping an updated
    expression; nothing executes until ``sql()`` renders the expression(s)
    to SQL strings. ``branch_id``/``sequence_id`` are bookkeeping ids used
    to resolve join hints and ``alias()`` calls back to the CTE that
    produced them.
    """

    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        # Unknown attributes are treated as column references (df.my_col).
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        # Qualify with this DataFrame's branch id so joins can later tell
        # which side a bare column reference came from.
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        """Name of the most recently added CTE (or the FROM alias if none).

        Raises:
            RuntimeError: when no alias can be found on the FROM clause.
        """
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        """Rename every CTE to a hash of its SQL so names are deterministic."""
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
        expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        """Wrap `expression` in a CTE with a fresh random name; return (cte, name)."""
        name = self.spark._random_name
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        # Tag the CTE so hint resolution can find it again later.
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols):
        # Normalization mutates the Column expressions in place.
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        """Push the current select into a new trailing CTE and select from it."""
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        """Attach queued partition/join hints to the current expression.

        Join hints are matched to CTEs by sequence id (newest CTE first) and
        only applied when that CTE actually participates in a join.
        """
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    # Newest CTEs first so the most recent branch wins.
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        """Queue a hint; known join hints become JoinHint, others Anonymous."""
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        """Combine self and `other` with a set operation (UNION/INTERSECT/EXCEPT)."""
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        # Hoist both sides' CTEs to the top-level WITH of the set operation.
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        # Mark the latest CTE so sql() emits a CACHE TABLE statement for it.
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        """Return a copy of `expression` with `ctes` merged in (no duplicates by name)."""
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            # Renamed local (was misspelled `existsing_cte_names`).
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        """Columns of the outermost SELECT of an expression or DataFrame."""
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in expression.find(exp.Select).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Select) -> str:
        """Short deterministic name ('t' + crc32 prefix) for an expression."""
        value = expression.sql(dialect="spark").encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        """Split cached CTEs out into standalone (exp.Cache, select) pairs.

        The final pair is the main select tagged with the output container
        type (Select/Create/Insert).
        """
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
        """Render this DataFrame to one or more SQL strings.

        Cached sub-selects produce DROP VIEW + CACHE TABLE statements ahead
        of the main statement; CREATE/INSERT containers are re-wrapped
        around the final select.
        """
        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
        for expression_type, select_expression in select_expressions:
            # Rewrite references to earlier cache tables before optimizing.
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = optimize_func(select_expression)
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                # Register the cache table's schema so later optimization passes
                # can resolve its columns.
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        sel.alias_or_name: sel.type.sql("spark")
                        for sel in select_expression.expressions
                    },
                )
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)
                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [
            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
        ]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        """Project columns, qualifying ambiguous ones when a join is present."""
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [col for col in cols if not col.column_expression.table]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # If the select column does not specify a table and there is a join
                    # then we assume they are referring to the left table
                    if len(ctes_with_column) > 1:
                        table_identifier = self.expression.args["from"].args["expressions"][0].this
                    else:
                        table_identifier = ctes_with_column[0].args["alias"].this
                    ambiguous_col.expression.set("table", table_identifier)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        """Register `name` as an alias for this DataFrame's lineage."""
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        """Join with `other_df` on columns or a boolean Column expression."""
        other_df = other_df._convert_leaf_to_cte()
        pre_join_self_latest_cte_name = self.latest_cte_name
        columns = self._ensure_and_normalize_cols(on)
        join_type = how.replace("_", " ")
        if isinstance(columns[0].expression, exp.Column):
            # `on` is plain column name(s): equi-join each from left to right.
            join_columns = [
                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
            ]
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [
                    col.copy().set_table_name(pre_join_self_latest_cte_name)
                    == col.copy().set_table_name(other_df.latest_cte_name)
                    for col in columns
                ],
            )
        else:
            # `on` is boolean expression(s): AND them together and alternate
            # the referenced columns between the left and right tables.
            if len(columns) > 1:
                columns = [functools.reduce(lambda x, y: x & y, columns)]
            join_clause = columns[0]
            join_columns = [
                Column(x).set_table_name(pre_join_self_latest_cte_name)
                if i % 2 == 0
                else Column(x).set_table_name(other_df.latest_cte_name)
                for i, x in enumerate(join_clause.expression.find_all(exp.Column))
            ]
        self_columns = [
            column.set_table_name(pre_join_self_latest_cte_name, copy=True)
            for column in self._get_outer_select_columns(self)
        ]
        other_columns = [
            column.set_table_name(other_df.latest_cte_name, copy=True)
            for column in self._get_outer_select_columns(other_df)
        ]
        # De-duplicate by name; join columns take priority, then left, then right.
        column_value_mapping = {
            column.alias_or_name
            if not isinstance(column.expression.this, exp.Star)
            else column.sql(): column
            for column in other_columns + self_columns + join_columns
        }
        all_columns = [
            column_value_mapping[name]
            for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
        ]
        new_df = self.copy(
            expression=self.expression.join(
                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
            )
        )
        new_df.expression = new_df._add_ctes_to_expression(
            new_df.expression, other_df.expression.ctes
        )
        new_df.pending_hints.extend(other_df.pending_hints)
        # Bypass the @operation wrapper to avoid double bookkeeping.
        new_df = new_df.select.__wrapped__(new_df, *all_columns)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
        is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            x
            for x in [
                i if isinstance(col.expression, exp.Ordered) else None
                for i, col in enumerate(columns)
            ]
            if x is not None
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        # Coerce each flag to bool (previously used a pointless enumerate).
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            exp.Ordered(this=col.expression, desc=not asc)
            if i not in pre_ordered_col_indexes
            else columns[i].column_expression
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False) -> DataFrame:
        """Union by column name; optionally null-fill columns missing on a side."""
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None) -> DataFrame:
        """Drop duplicate rows, optionally considering only `subset` columns."""
        if not subset:
            return self.distinct()
        # Keep row_number() == 1 per partition of the subset columns.
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """Drop rows with nulls, mirroring PySpark's how/thresh/subset semantics."""
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
            if minimum_num_nulls > len(null_check_columns):
                raise RuntimeError(
                    f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                    f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
                )
        # Count nulls per row and keep rows below the threshold.
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality Difference: If you provide a value to replace a null and that type conflicts
        with the type of the column then PySpark will just ignore your replacement.
        This will try to cast them to be the same in some cases. So they won't always match.
        Best to not mix types so make sure replacement is the same type as the column

        Possibility for improvement: Use `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            # dict form: per-column replacement values; keys pick the columns.
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        """Replace `to_replace` values with `value` across `subset` columns."""
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        # `old_values` is always None here, so the previous `not old_values and`
        # guard was redundant and has been removed.
        elif isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        # Build CASE WHEN chains: one per column, one WHEN per replacement pair.
        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        """Add a column, or replace it in place if `colName` already exists."""
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        # BUG FIX: compare against None — the previous truthiness test treated a
        # match at index 0 as "not present" and appended a duplicate column.
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str) -> DataFrame:
        """Rename column `existing` to `new`.

        Raises:
            ValueError: when no select expression matches `existing`.
        """
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column.copy(), new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        """Return a DataFrame without the given columns."""
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        """Queue a query hint; without parameters it targets this DataFrame."""
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
46 def __init__( 47 self, 48 spark: SparkSession, 49 expression: exp.Select, 50 branch_id: t.Optional[str] = None, 51 sequence_id: t.Optional[str] = None, 52 last_op: Operation = Operation.INIT, 53 pending_hints: t.Optional[t.List[exp.Expression]] = None, 54 output_expression_container: t.Optional[OutputExpressionContainer] = None, 55 **kwargs, 56 ): 57 self.spark = spark 58 self.expression = expression 59 self.branch_id = branch_id or self.spark._random_branch_id 60 self.sequence_id = sequence_id or self.spark._random_sequence_id 61 self.last_op = last_op 62 self.pending_hints = pending_hints or [] 63 self.output_expression_container = output_expression_container or exp.Select()
    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
        """
        Render this DataFrame's lineage as a list of SQL statements.

        Cached selects expand into a DROP VIEW + CACHE TABLE pair; Create/Insert output
        containers wrap the select; plain selects render directly.
        """
        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        # Maps original CTE alias identifiers to their cache-table replacements so
        # later statements reference the cached table instead.
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = optimize_func(select_expression)
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                # Cache table name is a content hash so identical plans share a table.
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                # Register the cached table's schema globally so later statements resolve types.
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql("spark")
                        for expression in select_expression.expressions
                    },
                )
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                # Hoist CTEs above the INSERT: they belong to the statement, not the select.
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)
                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [
            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
        ]
354 @operation(Operation.SELECT) 355 def select(self, *cols, **kwargs) -> DataFrame: 356 cols = self._ensure_and_normalize_cols(cols) 357 kwargs["append"] = kwargs.get("append", False) 358 if self.expression.args.get("joins"): 359 ambiguous_cols = [col for col in cols if not col.column_expression.table] 360 if ambiguous_cols: 361 join_table_identifiers = [ 362 x.this for x in get_tables_from_expression_with_join(self.expression) 363 ] 364 cte_names_in_join = [x.this for x in join_table_identifiers] 365 for ambiguous_col in ambiguous_cols: 366 ctes_with_column = [ 367 cte 368 for cte in self.expression.ctes 369 if cte.alias_or_name in cte_names_in_join 370 and ambiguous_col.alias_or_name in cte.this.named_selects 371 ] 372 # If the select column does not specify a table and there is a join 373 # then we assume they are referring to the left table 374 if len(ctes_with_column) > 1: 375 table_identifier = self.expression.args["from"].args["expressions"][0].this 376 else: 377 table_identifier = ctes_with_column[0].args["alias"].this 378 ambiguous_col.expression.set("table", table_identifier) 379 return self.copy( 380 expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs 381 )
383 @operation(Operation.NO_OP) 384 def alias(self, name: str, **kwargs) -> DataFrame: 385 new_sequence_id = self.spark._random_sequence_id 386 df = self.copy() 387 for join_hint in df.pending_join_hints: 388 for expression in join_hint.expressions: 389 if expression.alias_or_name == self.sequence_id: 390 expression.set("this", Column.ensure_col(new_sequence_id).expression) 391 df.spark._add_alias_to_mapping(name, new_sequence_id) 392 return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        """
        Join this DataFrame with `other_df`.

        `on` may be column name(s) (equi-join on the shared names) or Column
        condition(s). `how` follows PySpark's join-type names; underscores become
        spaces for SQL (e.g. "left_outer" -> "left outer").
        """
        other_df = other_df._convert_leaf_to_cte()
        pre_join_self_latest_cte_name = self.latest_cte_name
        columns = self._ensure_and_normalize_cols(on)
        join_type = how.replace("_", " ")
        if isinstance(columns[0].expression, exp.Column):
            # Name-based join: equate each named column across both CTEs.
            join_columns = [
                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
            ]
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [
                    col.copy().set_table_name(pre_join_self_latest_cte_name)
                    == col.copy().set_table_name(other_df.latest_cte_name)
                    for col in columns
                ],
            )
        else:
            # Condition-based join: AND the conditions together; qualify the referenced
            # columns alternating left/right tables by position.
            if len(columns) > 1:
                columns = [functools.reduce(lambda x, y: x & y, columns)]
            join_clause = columns[0]
            join_columns = [
                Column(x).set_table_name(pre_join_self_latest_cte_name)
                if i % 2 == 0
                else Column(x).set_table_name(other_df.latest_cte_name)
                for i, x in enumerate(join_clause.expression.find_all(exp.Column))
            ]
        self_columns = [
            column.set_table_name(pre_join_self_latest_cte_name, copy=True)
            for column in self._get_outer_select_columns(self)
        ]
        other_columns = [
            column.set_table_name(other_df.latest_cte_name, copy=True)
            for column in self._get_outer_select_columns(other_df)
        ]
        # Later entries win: self's columns shadow other's; join columns shadow both.
        column_value_mapping = {
            column.alias_or_name
            if not isinstance(column.expression.this, exp.Star)
            else column.sql(): column
            for column in other_columns + self_columns + join_columns
        }
        # Dict keys keep first-seen order: join columns, then self, then other.
        all_columns = [
            column_value_mapping[name]
            for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
        ]
        new_df = self.copy(
            expression=self.expression.join(
                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
            )
        )
        new_df.expression = new_df._add_ctes_to_expression(
            new_df.expression, other_df.expression.ctes
        )
        new_df.pending_hints.extend(other_df.pending_hints)
        # Bypass the @operation wrapper so this select doesn't register as a user op.
        new_df = new_df.select.__wrapped__(new_df, *all_columns)
        return new_df
475 @operation(Operation.ORDER_BY) 476 def orderBy( 477 self, 478 *cols: t.Union[str, Column], 479 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 480 ) -> DataFrame: 481 """ 482 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 483 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 484 is unlikely to come up. 485 """ 486 columns = self._ensure_and_normalize_cols(cols) 487 pre_ordered_col_indexes = [ 488 x 489 for x in [ 490 i if isinstance(col.expression, exp.Ordered) else None 491 for i, col in enumerate(columns) 492 ] 493 if x is not None 494 ] 495 if ascending is None: 496 ascending = [True] * len(columns) 497 elif not isinstance(ascending, list): 498 ascending = [ascending] * len(columns) 499 ascending = [bool(x) for i, x in enumerate(ascending)] 500 assert len(columns) == len( 501 ascending 502 ), "The length of items in ascending must equal the number of columns provided" 503 col_and_ascending = list(zip(columns, ascending)) 504 order_by_columns = [ 505 exp.Ordered(this=col.expression, desc=not asc) 506 if i not in pre_ordered_col_indexes 507 else columns[i].column_expression 508 for i, (col, asc) in enumerate(col_and_ascending) 509 ] 510 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyway, so this is unlikely to come up.
475 @operation(Operation.ORDER_BY) 476 def orderBy( 477 self, 478 *cols: t.Union[str, Column], 479 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 480 ) -> DataFrame: 481 """ 482 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 483 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 484 is unlikely to come up. 485 """ 486 columns = self._ensure_and_normalize_cols(cols) 487 pre_ordered_col_indexes = [ 488 x 489 for x in [ 490 i if isinstance(col.expression, exp.Ordered) else None 491 for i, col in enumerate(columns) 492 ] 493 if x is not None 494 ] 495 if ascending is None: 496 ascending = [True] * len(columns) 497 elif not isinstance(ascending, list): 498 ascending = [ascending] * len(columns) 499 ascending = [bool(x) for i, x in enumerate(ascending)] 500 assert len(columns) == len( 501 ascending 502 ), "The length of items in ascending must equal the number of columns provided" 503 col_and_ascending = list(zip(columns, ascending)) 504 order_by_columns = [ 505 exp.Ordered(this=col.expression, desc=not asc) 506 if i not in pre_ordered_col_indexes 507 else columns[i].column_expression 508 for i, (col, asc) in enumerate(col_and_ascending) 509 ] 510 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyway, so this is unlikely to come up.
520 @operation(Operation.FROM) 521 def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): 522 l_columns = self.columns 523 r_columns = other.columns 524 if not allowMissingColumns: 525 l_expressions = l_columns 526 r_expressions = l_columns 527 else: 528 l_expressions = [] 529 r_expressions = [] 530 r_columns_unused = copy(r_columns) 531 for l_column in l_columns: 532 l_expressions.append(l_column) 533 if l_column in r_columns: 534 r_expressions.append(l_column) 535 r_columns_unused.remove(l_column) 536 else: 537 r_expressions.append(exp.alias_(exp.Null(), l_column)) 538 for r_column in r_columns_unused: 539 l_expressions.append(exp.alias_(exp.Null(), r_column)) 540 r_expressions.append(r_column) 541 r_df = ( 542 other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) 543 ) 544 l_df = self.copy() 545 if allowMissingColumns: 546 l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) 547 return l_df._set_operation(exp.Union, r_df, False)
565 @operation(Operation.SELECT) 566 def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): 567 if not subset: 568 return self.distinct() 569 column_names = ensure_list(subset) 570 window = Window.partitionBy(*column_names).orderBy(*column_names) 571 return ( 572 self.copy() 573 .withColumn("row_num", F.row_number().over(window)) 574 .where(F.col("row_num") == F.lit(1)) 575 .drop("row_num") 576 )
578 @operation(Operation.FROM) 579 def dropna( 580 self, 581 how: str = "any", 582 thresh: t.Optional[int] = None, 583 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 584 ) -> DataFrame: 585 minimum_non_null = thresh or 0 # will be determined later if thresh is null 586 new_df = self.copy() 587 all_columns = self._get_outer_select_columns(new_df.expression) 588 if subset: 589 null_check_columns = self._ensure_and_normalize_cols(subset) 590 else: 591 null_check_columns = all_columns 592 if thresh is None: 593 minimum_num_nulls = 1 if how == "any" else len(null_check_columns) 594 else: 595 minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 596 if minimum_num_nulls > len(null_check_columns): 597 raise RuntimeError( 598 f"The minimum num nulls for dropna must be less than or equal to the number of columns. " 599 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" 600 ) 601 if_null_checks = [ 602 F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns 603 ] 604 nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) 605 num_nulls = nulls_added_together.alias("num_nulls") 606 new_df = new_df.select(num_nulls, append=True) 607 filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) 608 final_df = filtered_df.select(*all_columns) 609 return final_df
611 @operation(Operation.FROM) 612 def fillna( 613 self, 614 value: t.Union[ColumnLiterals], 615 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 616 ) -> DataFrame: 617 """ 618 Functionality Difference: If you provide a value to replace a null and that type conflicts 619 with the type of the column then PySpark will just ignore your replacement. 620 This will try to cast them to be the same in some cases. So they won't always match. 621 Best to not mix types so make sure replacement is the same type as the column 622 623 Possibility for improvement: Use `typeof` function to get the type of the column 624 and check if it matches the type of the value provided. If not then make it null. 625 """ 626 from sqlglot.dataframe.sql.functions import lit 627 628 values = None 629 columns = None 630 new_df = self.copy() 631 all_columns = self._get_outer_select_columns(new_df.expression) 632 all_column_mapping = {column.alias_or_name: column for column in all_columns} 633 if isinstance(value, dict): 634 values = list(value.values()) 635 columns = self._ensure_and_normalize_cols(list(value)) 636 if not columns: 637 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 638 if not values: 639 values = [value] * len(columns) 640 value_columns = [lit(value) for value in values] 641 642 null_replacement_mapping = { 643 column.alias_or_name: ( 644 F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) 645 ) 646 for column, value in zip(columns, value_columns) 647 } 648 null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} 649 null_replacement_columns = [ 650 null_replacement_mapping[column.alias_or_name] for column in all_columns 651 ] 652 new_df = new_df.select(*null_replacement_columns) 653 return new_df
Functionality Difference: If you provide a value to replace a null and that type conflicts with the type of the column then PySpark will just ignore your replacement. This will try to cast them to be the same in some cases. So they won't always match. Best to not mix types so make sure replacement is the same type as the column
Possibility for improvement: use the `typeof` function to get the type of the column and check whether it matches the type of the value provided; if not, make it null.
655 @operation(Operation.FROM) 656 def replace( 657 self, 658 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 659 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 660 subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, 661 ) -> DataFrame: 662 from sqlglot.dataframe.sql.functions import lit 663 664 old_values = None 665 new_df = self.copy() 666 all_columns = self._get_outer_select_columns(new_df.expression) 667 all_column_mapping = {column.alias_or_name: column for column in all_columns} 668 669 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 670 if isinstance(to_replace, dict): 671 old_values = list(to_replace) 672 new_values = list(to_replace.values()) 673 elif not old_values and isinstance(to_replace, list): 674 assert isinstance(value, list), "value must be a list since the replacements are a list" 675 assert len(to_replace) == len( 676 value 677 ), "the replacements and values must be the same length" 678 old_values = to_replace 679 new_values = value 680 else: 681 old_values = [to_replace] * len(columns) 682 new_values = [value] * len(columns) 683 old_values = [lit(value) for value in old_values] 684 new_values = [lit(value) for value in new_values] 685 686 replacement_mapping = {} 687 for column in columns: 688 expression = Column(None) 689 for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): 690 if i == 0: 691 expression = F.when(column == old_value, new_value) 692 else: 693 expression = expression.when(column == old_value, new_value) # type: ignore 694 replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( 695 column.expression.alias_or_name 696 ) 697 698 replacement_mapping = {**all_column_mapping, **replacement_mapping} 699 replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] 700 new_df = new_df.select(*replacement_columns) 701 return new_df
703 @operation(Operation.SELECT) 704 def withColumn(self, colName: str, col: Column) -> DataFrame: 705 col = self._ensure_and_normalize_col(col) 706 existing_col_names = self.expression.named_selects 707 existing_col_index = ( 708 existing_col_names.index(colName) if colName in existing_col_names else None 709 ) 710 if existing_col_index: 711 expression = self.expression.copy() 712 expression.expressions[existing_col_index] = col.expression 713 return self.copy(expression=expression) 714 return self.copy().select(col.alias(colName), append=True)
716 @operation(Operation.SELECT) 717 def withColumnRenamed(self, existing: str, new: str): 718 expression = self.expression.copy() 719 existing_columns = [ 720 expression 721 for expression in expression.expressions 722 if expression.alias_or_name == existing 723 ] 724 if not existing_columns: 725 raise ValueError("Tried to rename a column that doesn't exist") 726 for existing_column in existing_columns: 727 if isinstance(existing_column, exp.Column): 728 existing_column.replace(exp.alias_(existing_column.copy(), new)) 729 else: 730 existing_column.set("alias", exp.to_identifier(new)) 731 return self.copy(expression=expression)
733 @operation(Operation.SELECT) 734 def drop(self, *cols: t.Union[str, Column]) -> DataFrame: 735 all_columns = self._get_outer_select_columns(self.expression) 736 drop_cols = self._ensure_and_normalize_cols(cols) 737 new_columns = [ 738 col 739 for col in all_columns 740 if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] 741 ] 742 return self.copy().select(*new_columns, append=False)
748 @operation(Operation.NO_OP) 749 def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: 750 parameter_list = ensure_list(parameters) 751 parameter_columns = ( 752 self._ensure_list_of_columns(parameter_list) 753 if parameters 754 else Column.ensure_cols([self.sequence_id]) 755 ) 756 return self._hint(name, parameter_columns)
758 @operation(Operation.NO_OP) 759 def repartition( 760 self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName 761 ) -> DataFrame: 762 num_partition_cols = self._ensure_list_of_columns(numPartitions) 763 columns = self._ensure_and_normalize_cols(cols) 764 args = num_partition_cols + columns 765 return self._hint("repartition", args)
776 @operation(Operation.NO_OP) 777 def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame: 778 """ 779 Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html 780 """ 781 return self._cache(storageLevel)
Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
14class GroupedData: 15 def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation): 16 self._df = df.copy() 17 self.spark = df.spark 18 self.last_op = last_op 19 self.group_by_cols = group_by_cols 20 21 def _get_function_applied_columns( 22 self, func_name: str, cols: t.Tuple[str, ...] 23 ) -> t.List[Column]: 24 func_name = func_name.lower() 25 return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] 26 27 @operation(Operation.SELECT) 28 def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: 29 columns = ( 30 [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] 31 if isinstance(exprs[0], dict) 32 else exprs 33 ) 34 cols = self._df._ensure_and_normalize_cols(columns) 35 36 expression = self._df.expression.group_by( 37 *[x.expression for x in self.group_by_cols] 38 ).select(*[x.expression for x in self.group_by_cols + cols], append=False) 39 return self._df.copy(expression=expression) 40 41 def count(self) -> DataFrame: 42 return self.agg(F.count("*").alias("count")) 43 44 def mean(self, *cols: str) -> DataFrame: 45 return self.avg(*cols) 46 47 def avg(self, *cols: str) -> DataFrame: 48 return self.agg(*self._get_function_applied_columns("avg", cols)) 49 50 def max(self, *cols: str) -> DataFrame: 51 return self.agg(*self._get_function_applied_columns("max", cols)) 52 53 def min(self, *cols: str) -> DataFrame: 54 return self.agg(*self._get_function_applied_columns("min", cols)) 55 56 def sum(self, *cols: str) -> DataFrame: 57 return self.agg(*self._get_function_applied_columns("sum", cols)) 58 59 def pivot(self, *cols: str) -> DataFrame: 60 raise NotImplementedError("Sum distinct is not currently implemented")
27 @operation(Operation.SELECT) 28 def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: 29 columns = ( 30 [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] 31 if isinstance(exprs[0], dict) 32 else exprs 33 ) 34 cols = self._df._ensure_and_normalize_cols(columns) 35 36 expression = self._df.expression.group_by( 37 *[x.expression for x in self.group_by_cols] 38 ).select(*[x.expression for x in self.group_by_cols + cols], append=False) 39 return self._df.copy(expression=expression)
16class Column: 17 def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): 18 if isinstance(expression, Column): 19 expression = expression.expression # type: ignore 20 elif expression is None or not isinstance(expression, (str, exp.Expression)): 21 expression = self._lit(expression).expression # type: ignore 22 23 expression = sqlglot.maybe_parse(expression, dialect="spark") 24 if expression is None: 25 raise ValueError(f"Could not parse {expression}") 26 self.expression: exp.Expression = expression 27 28 def __repr__(self): 29 return repr(self.expression) 30 31 def __hash__(self): 32 return hash(self.expression) 33 34 def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore 35 return self.binary_op(exp.EQ, other) 36 37 def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore 38 return self.binary_op(exp.NEQ, other) 39 40 def __gt__(self, other: ColumnOrLiteral) -> Column: 41 return self.binary_op(exp.GT, other) 42 43 def __ge__(self, other: ColumnOrLiteral) -> Column: 44 return self.binary_op(exp.GTE, other) 45 46 def __lt__(self, other: ColumnOrLiteral) -> Column: 47 return self.binary_op(exp.LT, other) 48 49 def __le__(self, other: ColumnOrLiteral) -> Column: 50 return self.binary_op(exp.LTE, other) 51 52 def __and__(self, other: ColumnOrLiteral) -> Column: 53 return self.binary_op(exp.And, other) 54 55 def __or__(self, other: ColumnOrLiteral) -> Column: 56 return self.binary_op(exp.Or, other) 57 58 def __mod__(self, other: ColumnOrLiteral) -> Column: 59 return self.binary_op(exp.Mod, other) 60 61 def __add__(self, other: ColumnOrLiteral) -> Column: 62 return self.binary_op(exp.Add, other) 63 64 def __sub__(self, other: ColumnOrLiteral) -> Column: 65 return self.binary_op(exp.Sub, other) 66 67 def __mul__(self, other: ColumnOrLiteral) -> Column: 68 return self.binary_op(exp.Mul, other) 69 70 def __truediv__(self, other: ColumnOrLiteral) -> Column: 71 return self.binary_op(exp.Div, other) 72 73 def __div__(self, 
other: ColumnOrLiteral) -> Column: 74 return self.binary_op(exp.Div, other) 75 76 def __neg__(self) -> Column: 77 return self.unary_op(exp.Neg) 78 79 def __radd__(self, other: ColumnOrLiteral) -> Column: 80 return self.inverse_binary_op(exp.Add, other) 81 82 def __rsub__(self, other: ColumnOrLiteral) -> Column: 83 return self.inverse_binary_op(exp.Sub, other) 84 85 def __rmul__(self, other: ColumnOrLiteral) -> Column: 86 return self.inverse_binary_op(exp.Mul, other) 87 88 def __rdiv__(self, other: ColumnOrLiteral) -> Column: 89 return self.inverse_binary_op(exp.Div, other) 90 91 def __rtruediv__(self, other: ColumnOrLiteral) -> Column: 92 return self.inverse_binary_op(exp.Div, other) 93 94 def __rmod__(self, other: ColumnOrLiteral) -> Column: 95 return self.inverse_binary_op(exp.Mod, other) 96 97 def __pow__(self, power: ColumnOrLiteral, modulo=None): 98 return Column(exp.Pow(this=self.expression, expression=Column(power).expression)) 99 100 def __rpow__(self, power: ColumnOrLiteral): 101 return Column(exp.Pow(this=Column(power).expression, expression=self.expression)) 102 103 def __invert__(self): 104 return self.unary_op(exp.Not) 105 106 def __rand__(self, other: ColumnOrLiteral) -> Column: 107 return self.inverse_binary_op(exp.And, other) 108 109 def __ror__(self, other: ColumnOrLiteral) -> Column: 110 return self.inverse_binary_op(exp.Or, other) 111 112 @classmethod 113 def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): 114 return cls(value) 115 116 @classmethod 117 def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]: 118 return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args] 119 120 @classmethod 121 def _lit(cls, value: ColumnOrLiteral) -> Column: 122 if isinstance(value, dict): 123 columns = [cls._lit(v).alias(k).expression for k, v in value.items()] 124 return cls(exp.Struct(expressions=columns)) 125 return cls(exp.convert(value)) 126 127 @classmethod 128 def 
invoke_anonymous_function( 129 cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] 130 ) -> Column: 131 columns = [] if column is None else [cls.ensure_col(column)] 132 column_args = [cls.ensure_col(arg) for arg in args] 133 expressions = [x.expression for x in columns + column_args] 134 new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) 135 return Column(new_expression) 136 137 @classmethod 138 def invoke_expression_over_column( 139 cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs 140 ) -> Column: 141 ensured_column = None if column is None else cls.ensure_col(column) 142 ensure_expression_values = { 143 k: [Column.ensure_col(x).expression for x in v] 144 if is_iterable(v) 145 else Column.ensure_col(v).expression 146 for k, v in kwargs.items() 147 if v is not None 148 } 149 new_expression = ( 150 callable_expression(**ensure_expression_values) 151 if ensured_column is None 152 else callable_expression( 153 this=ensured_column.column_expression, **ensure_expression_values 154 ) 155 ) 156 return Column(new_expression) 157 158 def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: 159 return Column( 160 klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs) 161 ) 162 163 def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: 164 return Column( 165 klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs) 166 ) 167 168 def unary_op(self, klass: t.Callable, **kwargs) -> Column: 169 return Column(klass(this=self.column_expression, **kwargs)) 170 171 @property 172 def is_alias(self): 173 return isinstance(self.expression, exp.Alias) 174 175 @property 176 def is_column(self): 177 return isinstance(self.expression, exp.Column) 178 179 @property 180 def column_expression(self) -> exp.Column: 181 return self.expression.unalias() 182 183 
@property
def alias_or_name(self) -> str:
    # Name used to reference this column in projections (an alias wins over the raw name).
    return self.expression.alias_or_name

@classmethod
def ensure_literal(cls, value) -> Column:
    """Coerce ``value`` into a Column wrapping a literal expression."""
    from sqlglot.dataframe.sql.functions import lit

    if isinstance(value, cls):
        value = value.expression
    if not isinstance(value, exp.Literal):
        return lit(value)
    return Column(value)

def copy(self) -> Column:
    """Return a Column backed by a deep copy of the underlying expression."""
    return Column(self.expression.copy())

def set_table_name(self, table_name: str, copy=False) -> Column:
    """Qualify this column with ``table_name``; mutates in place unless ``copy=True``."""
    expression = self.expression.copy() if copy else self.expression
    expression.set("table", exp.to_identifier(table_name))
    return Column(expression)

def sql(self, **kwargs) -> str:
    """Render this column as Spark SQL (caller kwargs may override the dialect)."""
    return self.expression.sql(**{"dialect": "spark", **kwargs})

def alias(self, name: str) -> Column:
    """Return this column aliased as ``name`` (built on the unaliased expression)."""
    new_expression = exp.alias_(self.column_expression, name)
    return Column(new_expression)

def asc(self) -> Column:
    # Ascending order with nulls first (the flags are set explicitly below).
    new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
    return Column(new_expression)

def desc(self) -> Column:
    # Descending order with nulls last.
    new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
    return Column(new_expression)

# asc() already places nulls first, so it doubles as asc_nulls_first.
asc_nulls_first = asc

def asc_nulls_last(self) -> Column:
    new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
    return Column(new_expression)

def desc_nulls_first(self) -> Column:
    new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
    return Column(new_expression)

# desc() already places nulls last, so it doubles as desc_nulls_last.
desc_nulls_last = desc

def when(self, condition: Column, value: t.Any) -> Column:
    """Append a WHEN/THEN branch to this CASE, or start a new CASE expression."""
    from sqlglot.dataframe.sql.functions import when

    column_with_if = when(condition, value)
    # If self is not already a CASE, this call starts a fresh one.
    if not isinstance(self.expression, exp.Case):
        return column_with_if
    new_column = self.copy()
    new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
    return new_column

def otherwise(self, value: t.Any) -> Column:
    """Set the ELSE (default) branch of this CASE expression."""
    from sqlglot.dataframe.sql.functions import lit

    true_value = value if isinstance(value, Column) else lit(value)
    new_column = self.copy()
    new_column.expression.set("default", true_value.column_expression)
    return new_column

def isNull(self) -> Column:
    # Renders as `col IS NULL`.
    new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
    return Column(new_expression)

def isNotNull(self) -> Column:
    # Renders as `NOT col IS NULL`.
    new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
    return Column(new_expression)

def cast(self, dataType: t.Union[str, DataType]) -> Column:
    """
    Functionality difference: PySpark's cast accepts a DataType instance;
    sqlglot does not currently replicate those classes, so only a string
    type name (or a DataType reduced to one) is accepted.
    """
    if isinstance(dataType, DataType):
        dataType = dataType.simpleString()
    return Column(exp.cast(self.column_expression, dataType, dialect="spark"))

def startswith(self, value: t.Union[str, Column]) -> Column:
    # Emitted as a bare STARTSWITH() function call via exp.Anonymous.
    value = self._lit(value) if not isinstance(value, Column) else value
    return self.invoke_anonymous_function(self, "STARTSWITH", value)

def endswith(self, value: t.Union[str, Column]) -> Column:
    # Emitted as a bare ENDSWITH() function call via exp.Anonymous.
    value = self._lit(value) if not isinstance(value, Column) else value
    return self.invoke_anonymous_function(self, "ENDSWITH", value)

def rlike(self, regexp: str) -> Column:
    """Regex match predicate (exp.RegexpLike)."""
    return self.invoke_expression_over_column(
        column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
    )

def like(self, other: str) -> Column:
    """SQL LIKE predicate."""
    return self.invoke_expression_over_column(
        self, exp.Like, expression=self._lit(other).expression
    )

def ilike(self, other: str) -> Column:
    """Case-insensitive LIKE predicate."""
    return self.invoke_expression_over_column(
        self, exp.ILike, expression=self._lit(other).expression
    )

def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
    """Substring of this column starting at ``startPos`` with ``length`` characters."""
    startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
    length = self._lit(length) if not isinstance(length, Column) else length
    return Column.invoke_expression_over_column(
        self, exp.Substring, start=startPos.expression, length=length.expression
    )

def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
    """IN predicate; accepts either varargs or a single iterable of values."""
    columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
    expressions = [self._lit(x).expression for x in columns]
    return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore

def between(
    self,
    lowerBound: t.Union[ColumnOrLiteral],
    upperBound: t.Union[ColumnOrLiteral],
) -> Column:
    """BETWEEN predicate; non-Column bounds are converted to literals."""
    lower_bound_exp = (
        self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
    )
    upper_bound_exp = (
        self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
    )
    return Column(
        exp.Between(
            this=self.column_expression,
            low=lower_bound_exp.expression,
            high=upper_bound_exp.expression,
        )
    )

def over(self, window: WindowSpec) -> Column:
    """Apply this column (e.g. an aggregate) over the given window specification."""
    window_expression = window.expression.copy()
    window_expression.set("this", self.column_expression)
    return Column(window_expression)
def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
    """Wrap a Column, sqlglot expression, SQL string, or literal as a Column.

    Raises:
        ValueError: if the input cannot be parsed into a sqlglot expression.
    """
    if isinstance(expression, Column):
        expression = expression.expression  # type: ignore
    elif expression is None or not isinstance(expression, (str, exp.Expression)):
        # Raw literals (ints, floats, dicts, ...) are converted via _lit.
        expression = self._lit(expression).expression  # type: ignore

    # Bug fix: keep the pre-parse value for the error message. The original
    # interpolated the post-parse result, which is always None on failure,
    # producing the useless message "Could not parse None".
    parsed = sqlglot.maybe_parse(expression, dialect="spark")
    if parsed is None:
        raise ValueError(f"Could not parse {expression}")
    self.expression: exp.Expression = parsed
@classmethod
def invoke_anonymous_function(
    cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
) -> Column:
    """Build a pass-through SQL function call (exp.Anonymous) over the operands."""
    operands = list(args) if column is None else [column, *args]
    expressions = [cls.ensure_col(operand).expression for operand in operands]
    return Column(exp.Anonymous(this=func_name.upper(), expressions=expressions))
@classmethod
def invoke_expression_over_column(
    cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
) -> Column:
    """Instantiate a sqlglot expression class over an optional column plus kwargs.

    None-valued kwargs are dropped; iterable values are converted element-wise.
    """
    expression_kwargs = {}
    for key, value in kwargs.items():
        if value is None:
            continue
        if is_iterable(value):
            expression_kwargs[key] = [Column.ensure_col(item).expression for item in value]
        else:
            expression_kwargs[key] = Column.ensure_col(value).expression

    if column is None:
        return Column(callable_expression(**expression_kwargs))
    this = cls.ensure_col(column).column_expression
    return Column(callable_expression(this=this, **expression_kwargs))
def when(self, condition: Column, value: t.Any) -> Column:
    """Append a WHEN/THEN branch; starts a new CASE if self is not one already."""
    from sqlglot.dataframe.sql.functions import when

    branch = when(condition, value)
    if not isinstance(self.expression, exp.Case):
        # Nothing to extend: the new single-branch CASE is the result.
        return branch
    result = self.copy()
    result.expression.args["ifs"].extend(branch.expression.args["ifs"])
    return result
def cast(self, dataType: t.Union[str, DataType]) -> Column:
    """
    Functionality difference: PySpark's cast accepts a DataType instance;
    sqlglot does not currently replicate those classes, so only a string
    type name (or a DataType reduced to one) is accepted.
    """
    # DataType instances are reduced to their simple string form (e.g. "int").
    if isinstance(dataType, DataType):
        dataType = dataType.simpleString()
    return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
Functionality difference: PySpark's `cast` accepts an instance of a `DataType` class. sqlglot does not currently replicate these classes, so only a string type name is accepted.
def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
    """Substring of this column starting at startPos with the given length."""
    start_col = startPos if isinstance(startPos, Column) else self._lit(startPos)
    length_col = length if isinstance(length, Column) else self._lit(length)
    return Column.invoke_expression_over_column(
        self, exp.Substring, start=start_col.expression, length=length_col.expression
    )
def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
    """IN predicate; accepts either varargs or a single iterable, mirroring PySpark.

    Raises:
        ValueError: if called with no values.
    """
    if not cols:
        # Previously an empty call crashed with a bare IndexError on cols[0].
        raise ValueError("isin requires at least one value")
    columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
    # Bug fix: pass Column arguments through unchanged and only wrap raw
    # literals. The original ran every value through _lit, which cannot
    # convert a Column instance; this matches the handling in
    # startswith/endswith/substr/between.
    expressions = [
        x.expression if isinstance(x, Column) else self._lit(x).expression for x in columns
    ]
    return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
def between(
    self,
    lowerBound: t.Union[ColumnOrLiteral],
    upperBound: t.Union[ColumnOrLiteral],
) -> Column:
    """BETWEEN predicate (inclusive); non-Column bounds become literals."""
    low = lowerBound if isinstance(lowerBound, Column) else self._lit(lowerBound)
    high = upperBound if isinstance(upperBound, Column) else self._lit(upperBound)
    return Column(
        exp.Between(
            this=self.column_expression,
            low=low.expression,
            high=high.expression,
        )
    )
class DataFrameNaFunctions:
    """Thin PySpark-compatible facade that delegates null handling to a DataFrame."""

    def __init__(self, df: DataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """Drop rows containing nulls; delegates to DataFrame.dropna."""
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """Replace nulls with ``value``; delegates to DataFrame.fillna."""
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        """Replace matching values; delegates to DataFrame.replace."""
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
def replace(
    self,
    to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
    value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
    subset: t.Optional[t.Union[str, t.List[str]]] = None,
) -> DataFrame:
    """Replace matching values; pure delegation to DataFrame.replace."""
    return self.df.replace(to_replace=to_replace, value=value, subset=subset)
class Window:
    """PySpark-compatible entry points for building window specifications."""

    # Java long range; PySpark uses these values as unbounded-frame sentinels.
    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    unboundedPreceding: int = _JAVA_MIN_LONG

    unboundedFollowing: int = _JAVA_MAX_LONG

    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Start a new WindowSpec partitioned by the given columns."""
        return WindowSpec().partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Start a new WindowSpec ordered by the given columns."""
        return WindowSpec().orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        """Start a new WindowSpec with a ROWS frame from start to end."""
        return WindowSpec().rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        """Start a new WindowSpec with a RANGE frame from start to end."""
        return WindowSpec().rangeBetween(start, end)
class WindowSpec:
    """Builder for a sqlglot exp.Window expression, mirroring PySpark's WindowSpec."""

    def __init__(self, expression: t.Optional[exp.Expression] = None):
        # Bug fix: the previous signature was `expression: exp.Expression = exp.Window()`,
        # a mutable default evaluated once at definition time and shared by every
        # WindowSpec() call. Build a fresh Window node per instance instead; passing
        # None (or nothing) is backward compatible with the old default.
        self.expression = expression if expression is not None else exp.Window()

    def copy(self):
        """Deep-copy so builder methods never mutate the spec they were called on."""
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        """Render the window clause as Spark SQL."""
        return self.expression.sql(dialect="spark", **kwargs)

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a new spec with the given partition columns appended."""
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
        partition_by_expressions.extend(expressions)
        window_spec.expression.set("partition_by", partition_by_expressions)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a new spec with the given ordering columns appended."""
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order_by = window_spec.expression.args["order"].expressions
        order_by.extend(expressions)
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        """Translate PySpark frame offsets into sqlglot WindowSpec kwargs.

        Offsets at/beyond Window.unboundedPreceding/unboundedFollowing map to
        UNBOUNDED; Window.currentRow maps to CURRENT ROW; anything else becomes
        a literal offset with the appropriate PRECEDING/FOLLOWING side.
        """
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
            "start_side": None,
            "end_side": None,
        }
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": "UNBOUNDED"
                    if start <= Window.unboundedPreceding
                    else F.lit(start).expression,
                },
            }
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": "UNBOUNDED"
                    if end >= Window.unboundedFollowing
                    else F.lit(end).expression,
                },
            }
        return kwargs

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        """Return a new spec with a ROWS frame spanning start to end."""
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "ROWS"
        # Merge into any existing frame spec rather than replacing it wholesale.
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        """Return a new spec with a RANGE frame spanning start to end."""
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "RANGE"
        # Merge into any existing frame spec rather than replacing it wholesale.
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec
def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
    """Return a copy of this spec with the given partition columns appended."""
    from sqlglot.dataframe.sql.column import Column

    # A single list/set/tuple argument is unpacked, mirroring PySpark.
    if isinstance(cols[0], (list, set, tuple)):
        cols = flatten(cols)  # type: ignore
    new_spec = self.copy()
    partition_by = new_spec.expression.args.get("partition_by", [])
    partition_by.extend(Column.ensure_col(col).expression for col in cols)
    new_spec.expression.set("partition_by", partition_by)
    return new_spec
def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
    """Return a copy of this spec with the given ordering columns appended."""
    from sqlglot.dataframe.sql.column import Column

    # A single list/set/tuple argument is unpacked, mirroring PySpark.
    if isinstance(cols[0], (list, set, tuple)):
        cols = flatten(cols)  # type: ignore
    new_spec = self.copy()
    if new_spec.expression.args.get("order") is None:
        new_spec.expression.set("order", exp.Order(expressions=[]))
    order_node = new_spec.expression.args["order"]
    order_expressions = order_node.expressions
    order_expressions.extend(Column.ensure_col(col).expression for col in cols)
    order_node.set("expressions", order_expressions)
    return new_spec
def rowsBetween(self, start: int, end: int) -> WindowSpec:
    """Return a copy of this spec with a ROWS frame between start and end."""
    new_spec = self.copy()
    frame = self._calc_start_end(start, end)
    frame["kind"] = "ROWS"
    # Merge with any existing frame spec instead of discarding it.
    existing_spec = new_spec.expression.args.get("spec", exp.WindowSpec())
    new_spec.expression.set("spec", exp.WindowSpec(**{**existing_spec.args, **frame}))
    return new_spec
def rangeBetween(self, start: int, end: int) -> WindowSpec:
    """Return a copy of this spec with a RANGE frame between start and end."""
    new_spec = self.copy()
    frame = self._calc_start_end(start, end)
    frame["kind"] = "RANGE"
    # Merge with any existing frame spec instead of discarding it.
    existing_spec = new_spec.expression.args.get("spec", exp.WindowSpec())
    new_spec.expression.set("spec", exp.WindowSpec(**{**existing_spec.args, **frame}))
    return new_spec
class DataFrameReader:
    """Minimal PySpark-style reader that builds DataFrames from catalog tables."""

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        """Select every known column of ``tableName`` as a new DataFrame."""
        from sqlglot.dataframe.sql.dataframe import DataFrame

        # Registering the table lets the shared sqlglot schema resolve its columns.
        sqlglot.schema.add_table(tableName)
        column_names = sqlglot.schema.column_names(tableName)
        select_expression = exp.Select().from_(tableName).select(*column_names)
        return DataFrame(self.spark, select_expression)
class DataFrameWriter:
    """PySpark-style writer facade: wraps a DataFrame's expression in an
    Insert/Create output container; every builder method returns a new writer."""

    def __init__(
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        # When True, insertInto matches columns by name instead of position.
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        """Clone this writer, overriding any attributes supplied in kwargs.

        Stored attributes are underscore-prefixed; the prefix is stripped so
        they can be fed back into __init__ as keyword arguments.
        """
        return DataFrameWriter(
            **{
                k[1:] if k.startswith("_") else k: v
                for k, v in object_to_dict(self, **kwargs).items()
            }
        )

    def sql(self, **kwargs) -> t.List[str]:
        """Render the pending write as SQL statements (delegates to the DataFrame)."""
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        """Return a copy of this writer with the save mode set (e.g. append/ignore/overwrite)."""
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        """Return a copy of this writer that matches insert columns by name."""
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        """Stage an INSERT INTO ``tableName`` for this writer's DataFrame."""
        output_expression_container = exp.Insert(
            **{
                "this": exp.to_table(tableName),
                "overwrite": overwrite,
            }
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            # Only visible (insertable) columns participate in by-name matching.
            columns = sqlglot.schema.column_names(tableName, only_visible=True)
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        """Stage a CREATE TABLE for this writer's DataFrame.

        Modes: "append" delegates to insertInto; "ignore" adds IF NOT EXISTS;
        "overwrite" adds OR REPLACE; any other mode creates a plain table.
        """
        if format is not None:
            raise NotImplementedError("Providing Format in the save as table is not supported")
        # NOTE(review): str(self._mode) turns a None mode into the string "None",
        # which matches none of the branches below — presumably intentional so the
        # comparisons never see None; confirm.
        exists, replace, mode = None, None, mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)
        if mode == "ignore":
            exists = True
        if mode == "overwrite":
            replace = True
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
    """Stage an INSERT INTO ``tableName`` for this writer's DataFrame."""
    insert_expression = exp.Insert(this=exp.to_table(tableName), overwrite=overwrite)
    new_df = self._df.copy(output_expression_container=insert_expression)
    if self._by_name:
        # By-name matching projects only the table's visible columns.
        visible_columns = sqlglot.schema.column_names(tableName, only_visible=True)
        new_df = new_df._convert_leaf_to_cte().select(*visible_columns)
    return self.copy(_df=new_df)
def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
    """Stage a CREATE TABLE for this writer's DataFrame, honoring the save mode.

    Modes: "append" delegates to insertInto; "ignore" adds IF NOT EXISTS;
    "overwrite" adds OR REPLACE; any other mode creates a plain table.
    """
    if format is not None:
        raise NotImplementedError("Providing Format in the save as table is not supported")
    exists = None
    replace = None
    mode = mode or str(self._mode)
    if mode == "append":
        # Append is expressed as INSERT INTO rather than CREATE TABLE.
        return self.insertInto(name)
    if mode == "ignore":
        exists = True
    if mode == "overwrite":
        replace = True
    create_expression = exp.Create(
        this=exp.to_table(name),
        kind="TABLE",
        exists=exists,
        replace=replace,
    )
    return self.copy(_df=self._df.copy(output_expression_container=create_expression))