sqlglot.dataframe.sql
1from sqlglot.dataframe.sql.column import Column 2from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions 3from sqlglot.dataframe.sql.group import GroupedData 4from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter 5from sqlglot.dataframe.sql.session import SparkSession 6from sqlglot.dataframe.sql.window import Window, WindowSpec 7 8__all__ = [ 9 "SparkSession", 10 "DataFrame", 11 "GroupedData", 12 "Column", 13 "DataFrameNaFunctions", 14 "Window", 15 "WindowSpec", 16 "DataFrameReader", 17 "DataFrameWriter", 18]
20class SparkSession: 21 known_ids: t.ClassVar[t.Set[str]] = set() 22 known_branch_ids: t.ClassVar[t.Set[str]] = set() 23 known_sequence_ids: t.ClassVar[t.Set[str]] = set() 24 name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list) 25 26 def __init__(self): 27 self.incrementing_id = 1 28 29 def __getattr__(self, name: str) -> SparkSession: 30 return self 31 32 def __call__(self, *args, **kwargs) -> SparkSession: 33 return self 34 35 @property 36 def read(self) -> DataFrameReader: 37 return DataFrameReader(self) 38 39 def table(self, tableName: str) -> DataFrame: 40 return self.read.table(tableName) 41 42 def createDataFrame( 43 self, 44 data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]], 45 schema: t.Optional[SchemaInput] = None, 46 samplingRatio: t.Optional[float] = None, 47 verifySchema: bool = False, 48 ) -> DataFrame: 49 from sqlglot.dataframe.sql.dataframe import DataFrame 50 51 if samplingRatio is not None or verifySchema: 52 raise NotImplementedError("Sampling Ratio and Verify Schema are not supported") 53 if schema is not None and ( 54 not isinstance(schema, (StructType, str, list)) 55 or (isinstance(schema, list) and not isinstance(schema[0], str)) 56 ): 57 raise NotImplementedError("Only schema of either list or string of list supported") 58 if not data: 59 raise ValueError("Must provide data to create into a DataFrame") 60 61 column_mapping: t.Dict[str, t.Optional[str]] 62 if schema is not None: 63 column_mapping = get_column_mapping_from_schema_input(schema) 64 elif isinstance(data[0], dict): 65 column_mapping = {col_name.strip(): None for col_name in data[0]} 66 else: 67 column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)} 68 69 data_expressions = [ 70 exp.Tuple( 71 expressions=list( 72 map( 73 lambda x: F.lit(x).expression, 74 row if not isinstance(row, dict) else row.values(), 75 ) 76 ) 77 ) 78 for row in data 79 ] 80 81 sel_columns = [ 82 F.col(name).cast(data_type).alias(name).expression 83 if data_type is not None 84 else F.col(name).expression 85 for name, data_type in column_mapping.items() 86 ] 87 88 select_kwargs = { 89 "expressions": sel_columns, 90 "from": exp.From( 91 expressions=[ 92 exp.Values( 93 expressions=data_expressions, 94 alias=exp.TableAlias( 95 this=exp.to_identifier(self._auto_incrementing_name), 96 columns=[exp.to_identifier(col_name) for col_name in column_mapping], 97 ), 98 ), 99 ], 100 ), 101 } 102 103 sel_expression = exp.Select(**select_kwargs) 104 return DataFrame(self, sel_expression) 105 106 def sql(self, sqlQuery: str) -> DataFrame: 107 expression = sqlglot.parse_one(sqlQuery, read="spark") 108 if isinstance(expression, exp.Select): 109 df = DataFrame(self, expression) 110 df = df._convert_leaf_to_cte() 111 elif isinstance(expression, (exp.Create, exp.Insert)): 112 select_expression = expression.expression.copy() 113 if isinstance(expression, exp.Insert): 114 select_expression.set("with", expression.args.get("with")) 115 expression.set("with", None) 116 del expression.args["expression"] 117 df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore 118 df = df._convert_leaf_to_cte() 119 else: 120 raise ValueError( 121 "Unknown expression type provided in the SQL. Please create an issue with the SQL." 
122 ) 123 return df 124 125 @property 126 def _auto_incrementing_name(self) -> str: 127 name = f"a{self.incrementing_id}" 128 self.incrementing_id += 1 129 return name 130 131 @property 132 def _random_name(self) -> str: 133 return "r" + uuid.uuid4().hex 134 135 @property 136 def _random_branch_id(self) -> str: 137 id = self._random_id 138 self.known_branch_ids.add(id) 139 return id 140 141 @property 142 def _random_sequence_id(self): 143 id = self._random_id 144 self.known_sequence_ids.add(id) 145 return id 146 147 @property 148 def _random_id(self) -> str: 149 id = self._random_name 150 self.known_ids.add(id) 151 return id 152 153 @property 154 def _join_hint_names(self) -> t.Set[str]: 155 return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"} 156 157 def _add_alias_to_mapping(self, name: str, sequence_id: str): 158 self.name_to_sequence_id_mapping[name].append(sequence_id)
42 def createDataFrame( 43 self, 44 data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]], 45 schema: t.Optional[SchemaInput] = None, 46 samplingRatio: t.Optional[float] = None, 47 verifySchema: bool = False, 48 ) -> DataFrame: 49 from sqlglot.dataframe.sql.dataframe import DataFrame 50 51 if samplingRatio is not None or verifySchema: 52 raise NotImplementedError("Sampling Ratio and Verify Schema are not supported") 53 if schema is not None and ( 54 not isinstance(schema, (StructType, str, list)) 55 or (isinstance(schema, list) and not isinstance(schema[0], str)) 56 ): 57 raise NotImplementedError("Only schema of either list or string of list supported") 58 if not data: 59 raise ValueError("Must provide data to create into a DataFrame") 60 61 column_mapping: t.Dict[str, t.Optional[str]] 62 if schema is not None: 63 column_mapping = get_column_mapping_from_schema_input(schema) 64 elif isinstance(data[0], dict): 65 column_mapping = {col_name.strip(): None for col_name in data[0]} 66 else: 67 column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)} 68 69 data_expressions = [ 70 exp.Tuple( 71 expressions=list( 72 map( 73 lambda x: F.lit(x).expression, 74 row if not isinstance(row, dict) else row.values(), 75 ) 76 ) 77 ) 78 for row in data 79 ] 80 81 sel_columns = [ 82 F.col(name).cast(data_type).alias(name).expression 83 if data_type is not None 84 else F.col(name).expression 85 for name, data_type in column_mapping.items() 86 ] 87 88 select_kwargs = { 89 "expressions": sel_columns, 90 "from": exp.From( 91 expressions=[ 92 exp.Values( 93 expressions=data_expressions, 94 alias=exp.TableAlias( 95 this=exp.to_identifier(self._auto_incrementing_name), 96 columns=[exp.to_identifier(col_name) for col_name in column_mapping], 97 ), 98 ), 99 ], 100 ), 101 } 102 103 sel_expression = exp.Select(**select_kwargs) 104 return DataFrame(self, sel_expression)
106 def sql(self, sqlQuery: str) -> DataFrame: 107 expression = sqlglot.parse_one(sqlQuery, read="spark") 108 if isinstance(expression, exp.Select): 109 df = DataFrame(self, expression) 110 df = df._convert_leaf_to_cte() 111 elif isinstance(expression, (exp.Create, exp.Insert)): 112 select_expression = expression.expression.copy() 113 if isinstance(expression, exp.Insert): 114 select_expression.set("with", expression.args.get("with")) 115 expression.set("with", None) 116 del expression.args["expression"] 117 df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore 118 df = df._convert_leaf_to_cte() 119 else: 120 raise ValueError( 121 "Unknown expression type provided in the SQL. Please create an issue with the SQL." 122 ) 123 return df
46class DataFrame: 47 def __init__( 48 self, 49 spark: SparkSession, 50 expression: exp.Select, 51 branch_id: t.Optional[str] = None, 52 sequence_id: t.Optional[str] = None, 53 last_op: Operation = Operation.INIT, 54 pending_hints: t.Optional[t.List[exp.Expression]] = None, 55 output_expression_container: t.Optional[OutputExpressionContainer] = None, 56 **kwargs, 57 ): 58 self.spark = spark 59 self.expression = expression 60 self.branch_id = branch_id or self.spark._random_branch_id 61 self.sequence_id = sequence_id or self.spark._random_sequence_id 62 self.last_op = last_op 63 self.pending_hints = pending_hints or [] 64 self.output_expression_container = output_expression_container or exp.Select() 65 66 def __getattr__(self, column_name: str) -> Column: 67 return self[column_name] 68 69 def __getitem__(self, column_name: str) -> Column: 70 column_name = f"{self.branch_id}.{column_name}" 71 return Column(column_name) 72 73 def __copy__(self): 74 return self.copy() 75 76 @property 77 def sparkSession(self): 78 return self.spark 79 80 @property 81 def write(self): 82 return DataFrameWriter(self) 83 84 @property 85 def latest_cte_name(self) -> str: 86 if not self.expression.ctes: 87 from_exp = self.expression.args["from"] 88 if from_exp.alias_or_name: 89 return from_exp.alias_or_name 90 table_alias = from_exp.find(exp.TableAlias) 91 if not table_alias: 92 raise RuntimeError( 93 f"Could not find an alias name for this expression: {self.expression}" 94 ) 95 return table_alias.alias_or_name 96 return self.expression.ctes[-1].alias 97 98 @property 99 def pending_join_hints(self): 100 return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)] 101 102 @property 103 def pending_partition_hints(self): 104 return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)] 105 106 @property 107 def columns(self) -> t.List[str]: 108 return self.expression.named_selects 109 110 @property 111 def na(self) -> DataFrameNaFunctions: 112 return DataFrameNaFunctions(self) 113 114 def _replace_cte_names_with_hashes(self, expression: exp.Select): 115 replacement_mapping = {} 116 for cte in expression.ctes: 117 old_name_id = cte.args["alias"].this 118 new_hashed_id = exp.to_identifier( 119 self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"] 120 ) 121 replacement_mapping[old_name_id] = new_hashed_id 122 expression = expression.transform(replace_id_value, replacement_mapping) 123 return expression 124 125 def _create_cte_from_expression( 126 self, 127 expression: exp.Expression, 128 branch_id: t.Optional[str] = None, 129 sequence_id: t.Optional[str] = None, 130 **kwargs, 131 ) -> t.Tuple[exp.CTE, str]: 132 name = self.spark._random_name 133 expression_to_cte = expression.copy() 134 expression_to_cte.set("with", None) 135 cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0] 136 cte.set("branch_id", branch_id or self.branch_id) 137 cte.set("sequence_id", sequence_id or self.sequence_id) 138 return cte, name 139 140 @t.overload 141 def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: 142 ... 143 144 @t.overload 145 def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: 146 ... 
147 148 def _ensure_list_of_columns(self, cols): 149 return Column.ensure_cols(ensure_list(cols)) 150 151 def _ensure_and_normalize_cols(self, cols): 152 cols = self._ensure_list_of_columns(cols) 153 normalize(self.spark, self.expression, cols) 154 return cols 155 156 def _ensure_and_normalize_col(self, col): 157 col = Column.ensure_col(col) 158 normalize(self.spark, self.expression, col) 159 return col 160 161 def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame: 162 df = self._resolve_pending_hints() 163 sequence_id = sequence_id or df.sequence_id 164 expression = df.expression.copy() 165 cte_expression, cte_name = df._create_cte_from_expression( 166 expression=expression, sequence_id=sequence_id 167 ) 168 new_expression = df._add_ctes_to_expression( 169 exp.Select(), expression.ctes + [cte_expression] 170 ) 171 sel_columns = df._get_outer_select_columns(cte_expression) 172 new_expression = new_expression.from_(cte_name).select( 173 *[x.alias_or_name for x in sel_columns] 174 ) 175 return df.copy(expression=new_expression, sequence_id=sequence_id) 176 177 def _resolve_pending_hints(self) -> DataFrame: 178 df = self.copy() 179 if not self.pending_hints: 180 return df 181 expression = df.expression 182 hint_expression = expression.args.get("hint") or exp.Hint(expressions=[]) 183 for hint in df.pending_partition_hints: 184 hint_expression.append("expressions", hint) 185 df.pending_hints.remove(hint) 186 187 join_aliases = { 188 join_table.alias_or_name 189 for join_table in get_tables_from_expression_with_join(expression) 190 } 191 if join_aliases: 192 for hint in df.pending_join_hints: 193 for sequence_id_expression in hint.expressions: 194 sequence_id_or_name = sequence_id_expression.alias_or_name 195 sequence_ids_to_match = [sequence_id_or_name] 196 if sequence_id_or_name in df.spark.name_to_sequence_id_mapping: 197 sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[ 198 sequence_id_or_name 199 ] 200 matching_ctes = [ 201 cte 202 for cte in reversed(expression.ctes) 203 if cte.args["sequence_id"] in sequence_ids_to_match 204 ] 205 for matching_cte in matching_ctes: 206 if matching_cte.alias_or_name in join_aliases: 207 sequence_id_expression.set("this", matching_cte.args["alias"].this) 208 df.pending_hints.remove(hint) 209 break 210 hint_expression.append("expressions", hint) 211 if hint_expression.expressions: 212 expression.set("hint", hint_expression) 213 return df 214 215 def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame: 216 hint_name = hint_name.upper() 217 hint_expression = ( 218 exp.JoinHint( 219 this=hint_name, 220 expressions=[exp.to_table(parameter.alias_or_name) for parameter in args], 221 ) 222 if hint_name in JOIN_HINTS 223 else exp.Anonymous( 224 this=hint_name, expressions=[parameter.expression for parameter in args] 225 ) 226 ) 227 new_df = self.copy() 228 new_df.pending_hints.append(hint_expression) 229 return new_df 230 231 def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool): 232 other_df = other._convert_leaf_to_cte() 233 base_expression = self.expression.copy() 234 base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes) 235 all_ctes = base_expression.ctes 236 other_df.expression.set("with", None) 237 base_expression.set("with", None) 238 operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression) 239 operation.set("with", exp.With(expressions=all_ctes)) 240 return self.copy(expression=operation)._convert_leaf_to_cte() 241 
242 def _cache(self, storage_level: str): 243 df = self._convert_leaf_to_cte() 244 df.expression.ctes[-1].set("cache_storage_level", storage_level) 245 return df 246 247 @classmethod 248 def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select: 249 expression = expression.copy() 250 with_expression = expression.args.get("with") 251 if with_expression: 252 existing_ctes = with_expression.expressions 253 existsing_cte_names = {x.alias_or_name for x in existing_ctes} 254 for cte in ctes: 255 if cte.alias_or_name not in existsing_cte_names: 256 existing_ctes.append(cte) 257 else: 258 existing_ctes = ctes 259 expression.set("with", exp.With(expressions=existing_ctes)) 260 return expression 261 262 @classmethod 263 def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]: 264 expression = item.expression if isinstance(item, DataFrame) else item 265 return [Column(x) for x in expression.find(exp.Select).expressions] 266 267 @classmethod 268 def _create_hash_from_expression(cls, expression: exp.Select): 269 value = expression.sql(dialect="spark").encode("utf-8") 270 return f"t{zlib.crc32(value)}"[:6] 271 272 def _get_select_expressions( 273 self, 274 ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]: 275 select_expressions: t.List[ 276 t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select] 277 ] = [] 278 main_select_ctes: t.List[exp.CTE] = [] 279 for cte in self.expression.ctes: 280 cache_storage_level = cte.args.get("cache_storage_level") 281 if cache_storage_level: 282 select_expression = cte.this.copy() 283 select_expression.set("with", exp.With(expressions=copy(main_select_ctes))) 284 select_expression.set("cte_alias_name", cte.alias_or_name) 285 select_expression.set("cache_storage_level", cache_storage_level) 286 select_expressions.append((exp.Cache, select_expression)) 287 else: 288 main_select_ctes.append(cte) 289 main_select = self.expression.copy() 290 if main_select_ctes: 291 main_select.set("with", exp.With(expressions=main_select_ctes)) 292 expression_select_pair = (type(self.output_expression_container), main_select) 293 select_expressions.append(expression_select_pair) # type: ignore 294 return select_expressions 295 296 def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: 297 df = self._resolve_pending_hints() 298 select_expressions = df._get_select_expressions() 299 output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] 300 replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} 301 for expression_type, select_expression in select_expressions: 302 select_expression = select_expression.transform(replace_id_value, replacement_mapping) 303 if optimize: 304 select_expression = optimize_func(select_expression) 305 select_expression = df._replace_cte_names_with_hashes(select_expression) 306 expression: t.Union[exp.Select, exp.Cache, exp.Drop] 307 if expression_type == exp.Cache: 308 cache_table_name = df._create_hash_from_expression(select_expression) 309 cache_table = exp.to_table(cache_table_name) 310 original_alias_name = select_expression.args["cte_alias_name"] 311 312 replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore 313 cache_table_name 314 ) 315 sqlglot.schema.add_table( 316 cache_table_name, 317 { 318 expression.alias_or_name: expression.type.sql("spark") 319 for expression in select_expression.expressions 320 }, 321 ) 322 cache_storage_level = 
select_expression.args["cache_storage_level"] 323 options = [ 324 exp.Literal.string("storageLevel"), 325 exp.Literal.string(cache_storage_level), 326 ] 327 expression = exp.Cache( 328 this=cache_table, expression=select_expression, lazy=True, options=options 329 ) 330 # We will drop the "view" if it exists before running the cache table 331 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) 332 elif expression_type == exp.Create: 333 expression = df.output_expression_container.copy() 334 expression.set("expression", select_expression) 335 elif expression_type == exp.Insert: 336 expression = df.output_expression_container.copy() 337 select_without_ctes = select_expression.copy() 338 select_without_ctes.set("with", None) 339 expression.set("expression", select_without_ctes) 340 if select_expression.ctes: 341 expression.set("with", exp.With(expressions=select_expression.ctes)) 342 elif expression_type == exp.Select: 343 expression = select_expression 344 else: 345 raise ValueError(f"Invalid expression type: {expression_type}") 346 output_expressions.append(expression) 347 348 return [ 349 expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions 350 ] 351 352 def copy(self, **kwargs) -> DataFrame: 353 return DataFrame(**object_to_dict(self, **kwargs)) 354 355 @operation(Operation.SELECT) 356 def select(self, *cols, **kwargs) -> DataFrame: 357 cols = self._ensure_and_normalize_cols(cols) 358 kwargs["append"] = kwargs.get("append", False) 359 if self.expression.args.get("joins"): 360 ambiguous_cols = [col for col in cols if not col.column_expression.table] 361 if ambiguous_cols: 362 join_table_identifiers = [ 363 x.this for x in get_tables_from_expression_with_join(self.expression) 364 ] 365 cte_names_in_join = [x.this for x in join_table_identifiers] 366 for ambiguous_col in ambiguous_cols: 367 ctes_with_column = [ 368 cte 369 for cte in self.expression.ctes 370 if cte.alias_or_name in cte_names_in_join 371 and ambiguous_col.alias_or_name in cte.this.named_selects 372 ] 373 # If the select column does not specify a table and there is a join 374 # then we assume they are referring to the left table 375 if len(ctes_with_column) > 1: 376 table_identifier = self.expression.args["from"].args["expressions"][0].this 377 else: 378 table_identifier = ctes_with_column[0].args["alias"].this 379 ambiguous_col.expression.set("table", table_identifier) 380 expression = self.expression.select(*[x.expression for x in cols], **kwargs) 381 qualify_columns(expression, sqlglot.schema) 382 return self.copy(expression=expression, **kwargs) 383 384 @operation(Operation.NO_OP) 385 def alias(self, name: str, **kwargs) -> DataFrame: 386 new_sequence_id = self.spark._random_sequence_id 387 df = self.copy() 388 for join_hint in df.pending_join_hints: 389 for expression in join_hint.expressions: 390 if expression.alias_or_name == self.sequence_id: 391 expression.set("this", Column.ensure_col(new_sequence_id).expression) 392 df.spark._add_alias_to_mapping(name, new_sequence_id) 393 return df._convert_leaf_to_cte(sequence_id=new_sequence_id) 394 395 @operation(Operation.WHERE) 396 def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame: 397 col = self._ensure_and_normalize_col(column) 398 return self.copy(expression=self.expression.where(col.expression)) 399 400 filter = where 401 402 @operation(Operation.GROUP_BY) 403 def groupBy(self, *cols, **kwargs) -> GroupedData: 404 columns = self._ensure_and_normalize_cols(cols) 405 return GroupedData(self, 
columns, self.last_op) 406 407 @operation(Operation.SELECT) 408 def agg(self, *exprs, **kwargs) -> DataFrame: 409 cols = self._ensure_and_normalize_cols(exprs) 410 return self.groupBy().agg(*cols) 411 412 @operation(Operation.FROM) 413 def join( 414 self, 415 other_df: DataFrame, 416 on: t.Union[str, t.List[str], Column, t.List[Column]], 417 how: str = "inner", 418 **kwargs, 419 ) -> DataFrame: 420 other_df = other_df._convert_leaf_to_cte() 421 pre_join_self_latest_cte_name = self.latest_cte_name 422 columns = self._ensure_and_normalize_cols(on) 423 join_type = how.replace("_", " ") 424 if isinstance(columns[0].expression, exp.Column): 425 join_columns = [ 426 Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns 427 ] 428 join_clause = functools.reduce( 429 lambda x, y: x & y, 430 [ 431 col.copy().set_table_name(pre_join_self_latest_cte_name) 432 == col.copy().set_table_name(other_df.latest_cte_name) 433 for col in columns 434 ], 435 ) 436 else: 437 if len(columns) > 1: 438 columns = [functools.reduce(lambda x, y: x & y, columns)] 439 join_clause = columns[0] 440 join_columns = [ 441 Column(x).set_table_name(pre_join_self_latest_cte_name) 442 if i % 2 == 0 443 else Column(x).set_table_name(other_df.latest_cte_name) 444 for i, x in enumerate(join_clause.expression.find_all(exp.Column)) 445 ] 446 self_columns = [ 447 column.set_table_name(pre_join_self_latest_cte_name, copy=True) 448 for column in self._get_outer_select_columns(self) 449 ] 450 other_columns = [ 451 column.set_table_name(other_df.latest_cte_name, copy=True) 452 for column in self._get_outer_select_columns(other_df) 453 ] 454 column_value_mapping = { 455 column.alias_or_name 456 if not isinstance(column.expression.this, exp.Star) 457 else column.sql(): column 458 for column in other_columns + self_columns + join_columns 459 } 460 all_columns = [ 461 column_value_mapping[name] 462 for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns} 463 ] 464 new_df = self.copy( 465 expression=self.expression.join( 466 other_df.latest_cte_name, on=join_clause.expression, join_type=join_type 467 ) 468 ) 469 new_df.expression = new_df._add_ctes_to_expression( 470 new_df.expression, other_df.expression.ctes 471 ) 472 new_df.pending_hints.extend(other_df.pending_hints) 473 new_df = new_df.select.__wrapped__(new_df, *all_columns) 474 return new_df 475 476 @operation(Operation.ORDER_BY) 477 def orderBy( 478 self, 479 *cols: t.Union[str, Column], 480 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 481 ) -> DataFrame: 482 """ 483 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 484 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 485 is unlikely to come up. 
486 """ 487 columns = self._ensure_and_normalize_cols(cols) 488 pre_ordered_col_indexes = [ 489 x 490 for x in [ 491 i if isinstance(col.expression, exp.Ordered) else None 492 for i, col in enumerate(columns) 493 ] 494 if x is not None 495 ] 496 if ascending is None: 497 ascending = [True] * len(columns) 498 elif not isinstance(ascending, list): 499 ascending = [ascending] * len(columns) 500 ascending = [bool(x) for i, x in enumerate(ascending)] 501 assert len(columns) == len( 502 ascending 503 ), "The length of items in ascending must equal the number of columns provided" 504 col_and_ascending = list(zip(columns, ascending)) 505 order_by_columns = [ 506 exp.Ordered(this=col.expression, desc=not asc) 507 if i not in pre_ordered_col_indexes 508 else columns[i].column_expression 509 for i, (col, asc) in enumerate(col_and_ascending) 510 ] 511 return self.copy(expression=self.expression.order_by(*order_by_columns)) 512 513 sort = orderBy 514 515 @operation(Operation.FROM) 516 def union(self, other: DataFrame) -> DataFrame: 517 return self._set_operation(exp.Union, other, False) 518 519 unionAll = union 520 521 @operation(Operation.FROM) 522 def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): 523 l_columns = self.columns 524 r_columns = other.columns 525 if not allowMissingColumns: 526 l_expressions = l_columns 527 r_expressions = l_columns 528 else: 529 l_expressions = [] 530 r_expressions = [] 531 r_columns_unused = copy(r_columns) 532 for l_column in l_columns: 533 l_expressions.append(l_column) 534 if l_column in r_columns: 535 r_expressions.append(l_column) 536 r_columns_unused.remove(l_column) 537 else: 538 r_expressions.append(exp.alias_(exp.Null(), l_column)) 539 for r_column in r_columns_unused: 540 l_expressions.append(exp.alias_(exp.Null(), r_column)) 541 r_expressions.append(r_column) 542 r_df = ( 543 other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) 544 ) 545 l_df = self.copy() 546 if allowMissingColumns: 547 l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) 548 return l_df._set_operation(exp.Union, r_df, False) 549 550 @operation(Operation.FROM) 551 def intersect(self, other: DataFrame) -> DataFrame: 552 return self._set_operation(exp.Intersect, other, True) 553 554 @operation(Operation.FROM) 555 def intersectAll(self, other: DataFrame) -> DataFrame: 556 return self._set_operation(exp.Intersect, other, False) 557 558 @operation(Operation.FROM) 559 def exceptAll(self, other: DataFrame) -> DataFrame: 560 return self._set_operation(exp.Except, other, False) 561 562 @operation(Operation.SELECT) 563 def distinct(self) -> DataFrame: 564 return self.copy(expression=self.expression.distinct()) 565 566 @operation(Operation.SELECT) 567 def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): 568 if not subset: 569 return self.distinct() 570 column_names = ensure_list(subset) 571 window = Window.partitionBy(*column_names).orderBy(*column_names) 572 return ( 573 self.copy() 574 .withColumn("row_num", F.row_number().over(window)) 575 .where(F.col("row_num") == F.lit(1)) 576 .drop("row_num") 577 ) 578 579 @operation(Operation.FROM) 580 def dropna( 581 self, 582 how: str = "any", 583 thresh: t.Optional[int] = None, 584 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 585 ) -> DataFrame: 586 minimum_non_null = thresh or 0 # will be determined later if thresh is null 587 new_df = self.copy() 588 all_columns = self._get_outer_select_columns(new_df.expression) 589 if 
subset: 590 null_check_columns = self._ensure_and_normalize_cols(subset) 591 else: 592 null_check_columns = all_columns 593 if thresh is None: 594 minimum_num_nulls = 1 if how == "any" else len(null_check_columns) 595 else: 596 minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 597 if minimum_num_nulls > len(null_check_columns): 598 raise RuntimeError( 599 f"The minimum num nulls for dropna must be less than or equal to the number of columns. " 600 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" 601 ) 602 if_null_checks = [ 603 F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns 604 ] 605 nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) 606 num_nulls = nulls_added_together.alias("num_nulls") 607 new_df = new_df.select(num_nulls, append=True) 608 filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) 609 final_df = filtered_df.select(*all_columns) 610 return final_df 611 612 @operation(Operation.FROM) 613 def fillna( 614 self, 615 value: t.Union[ColumnLiterals], 616 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 617 ) -> DataFrame: 618 """ 619 Functionality Difference: If you provide a value to replace a null and that type conflicts 620 with the type of the column then PySpark will just ignore your replacement. 621 This will try to cast them to be the same in some cases. So they won't always match. 622 Best to not mix types so make sure replacement is the same type as the column 623 624 Possibility for improvement: Use `typeof` function to get the type of the column 625 and check if it matches the type of the value provided. If not then make it null. 626 """ 627 from sqlglot.dataframe.sql.functions import lit 628 629 values = None 630 columns = None 631 new_df = self.copy() 632 all_columns = self._get_outer_select_columns(new_df.expression) 633 all_column_mapping = {column.alias_or_name: column for column in all_columns} 634 if isinstance(value, dict): 635 values = list(value.values()) 636 columns = self._ensure_and_normalize_cols(list(value)) 637 if not columns: 638 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 639 if not values: 640 values = [value] * len(columns) 641 value_columns = [lit(value) for value in values] 642 643 null_replacement_mapping = { 644 column.alias_or_name: ( 645 F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) 646 ) 647 for column, value in zip(columns, value_columns) 648 } 649 null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} 650 null_replacement_columns = [ 651 null_replacement_mapping[column.alias_or_name] for column in all_columns 652 ] 653 new_df = new_df.select(*null_replacement_columns) 654 return new_df 655 656 @operation(Operation.FROM) 657 def replace( 658 self, 659 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 660 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 661 subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, 662 ) -> DataFrame: 663 from sqlglot.dataframe.sql.functions import lit 664 665 old_values = None 666 new_df = self.copy() 667 all_columns = self._get_outer_select_columns(new_df.expression) 668 all_column_mapping = {column.alias_or_name: column for column in all_columns} 669 670 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 671 if isinstance(to_replace, dict): 672 old_values = list(to_replace) 673 new_values = 
list(to_replace.values()) 674 elif not old_values and isinstance(to_replace, list): 675 assert isinstance(value, list), "value must be a list since the replacements are a list" 676 assert len(to_replace) == len( 677 value 678 ), "the replacements and values must be the same length" 679 old_values = to_replace 680 new_values = value 681 else: 682 old_values = [to_replace] * len(columns) 683 new_values = [value] * len(columns) 684 old_values = [lit(value) for value in old_values] 685 new_values = [lit(value) for value in new_values] 686 687 replacement_mapping = {} 688 for column in columns: 689 expression = Column(None) 690 for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): 691 if i == 0: 692 expression = F.when(column == old_value, new_value) 693 else: 694 expression = expression.when(column == old_value, new_value) # type: ignore 695 replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( 696 column.expression.alias_or_name 697 ) 698 699 replacement_mapping = {**all_column_mapping, **replacement_mapping} 700 replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] 701 new_df = new_df.select(*replacement_columns) 702 return new_df 703 704 @operation(Operation.SELECT) 705 def withColumn(self, colName: str, col: Column) -> DataFrame: 706 col = self._ensure_and_normalize_col(col) 707 existing_col_names = self.expression.named_selects 708 existing_col_index = ( 709 existing_col_names.index(colName) if colName in existing_col_names else None 710 ) 711 if existing_col_index: 712 expression = self.expression.copy() 713 expression.expressions[existing_col_index] = col.expression 714 return self.copy(expression=expression) 715 return self.copy().select(col.alias(colName), append=True) 716 717 @operation(Operation.SELECT) 718 def withColumnRenamed(self, existing: str, new: str): 719 expression = self.expression.copy() 720 existing_columns = [ 721 expression 722 for expression in expression.expressions 723 if expression.alias_or_name == existing 724 ] 725 if not existing_columns: 726 raise ValueError("Tried to rename a column that doesn't exist") 727 for existing_column in existing_columns: 728 if isinstance(existing_column, exp.Column): 729 existing_column.replace(exp.alias_(existing_column.copy(), new)) 730 else: 731 existing_column.set("alias", exp.to_identifier(new)) 732 return self.copy(expression=expression) 733 734 @operation(Operation.SELECT) 735 def drop(self, *cols: t.Union[str, Column]) -> DataFrame: 736 all_columns = self._get_outer_select_columns(self.expression) 737 drop_cols = self._ensure_and_normalize_cols(cols) 738 new_columns = [ 739 col 740 for col in all_columns 741 if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] 742 ] 743 return self.copy().select(*new_columns, append=False) 744 745 @operation(Operation.LIMIT) 746 def limit(self, num: int) -> DataFrame: 747 return self.copy(expression=self.expression.limit(num)) 748 749 @operation(Operation.NO_OP) 750 def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: 751 parameter_list = ensure_list(parameters) 752 parameter_columns = ( 753 self._ensure_list_of_columns(parameter_list) 754 if parameters 755 else Column.ensure_cols([self.sequence_id]) 756 ) 757 return self._hint(name, parameter_columns) 758 759 @operation(Operation.NO_OP) 760 def repartition( 761 self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName 762 ) -> DataFrame: 763 num_partition_cols = 
self._ensure_list_of_columns(numPartitions) 764 columns = self._ensure_and_normalize_cols(cols) 765 args = num_partition_cols + columns 766 return self._hint("repartition", args) 767 768 @operation(Operation.NO_OP) 769 def coalesce(self, numPartitions: int) -> DataFrame: 770 num_partitions = Column.ensure_cols([numPartitions]) 771 return self._hint("coalesce", num_partitions) 772 773 @operation(Operation.NO_OP) 774 def cache(self) -> DataFrame: 775 return self._cache(storage_level="MEMORY_AND_DISK") 776 777 @operation(Operation.NO_OP) 778 def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame: 779 """ 780 Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html 781 """ 782 return self._cache(storageLevel)
47 def __init__( 48 self, 49 spark: SparkSession, 50 expression: exp.Select, 51 branch_id: t.Optional[str] = None, 52 sequence_id: t.Optional[str] = None, 53 last_op: Operation = Operation.INIT, 54 pending_hints: t.Optional[t.List[exp.Expression]] = None, 55 output_expression_container: t.Optional[OutputExpressionContainer] = None, 56 **kwargs, 57 ): 58 self.spark = spark 59 self.expression = expression 60 self.branch_id = branch_id or self.spark._random_branch_id 61 self.sequence_id = sequence_id or self.spark._random_sequence_id 62 self.last_op = last_op 63 self.pending_hints = pending_hints or [] 64 self.output_expression_container = output_expression_container or exp.Select()
296 def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: 297 df = self._resolve_pending_hints() 298 select_expressions = df._get_select_expressions() 299 output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] 300 replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} 301 for expression_type, select_expression in select_expressions: 302 select_expression = select_expression.transform(replace_id_value, replacement_mapping) 303 if optimize: 304 select_expression = optimize_func(select_expression) 305 select_expression = df._replace_cte_names_with_hashes(select_expression) 306 expression: t.Union[exp.Select, exp.Cache, exp.Drop] 307 if expression_type == exp.Cache: 308 cache_table_name = df._create_hash_from_expression(select_expression) 309 cache_table = exp.to_table(cache_table_name) 310 original_alias_name = select_expression.args["cte_alias_name"] 311 312 replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore 313 cache_table_name 314 ) 315 sqlglot.schema.add_table( 316 cache_table_name, 317 { 318 expression.alias_or_name: expression.type.sql("spark") 319 for expression in select_expression.expressions 320 }, 321 ) 322 cache_storage_level = select_expression.args["cache_storage_level"] 323 options = [ 324 exp.Literal.string("storageLevel"), 325 exp.Literal.string(cache_storage_level), 326 ] 327 expression = exp.Cache( 328 this=cache_table, expression=select_expression, lazy=True, options=options 329 ) 330 # We will drop the "view" if it exists before running the cache table 331 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) 332 elif expression_type == exp.Create: 333 expression = df.output_expression_container.copy() 334 expression.set("expression", select_expression) 335 elif expression_type == exp.Insert: 336 expression = df.output_expression_container.copy() 337 select_without_ctes = select_expression.copy() 338 select_without_ctes.set("with", None) 339 expression.set("expression", select_without_ctes) 340 if select_expression.ctes: 341 expression.set("with", exp.With(expressions=select_expression.ctes)) 342 elif expression_type == exp.Select: 343 expression = select_expression 344 else: 345 raise ValueError(f"Invalid expression type: {expression_type}") 346 output_expressions.append(expression) 347 348 return [ 349 expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions 350 ]
355 @operation(Operation.SELECT) 356 def select(self, *cols, **kwargs) -> DataFrame: 357 cols = self._ensure_and_normalize_cols(cols) 358 kwargs["append"] = kwargs.get("append", False) 359 if self.expression.args.get("joins"): 360 ambiguous_cols = [col for col in cols if not col.column_expression.table] 361 if ambiguous_cols: 362 join_table_identifiers = [ 363 x.this for x in get_tables_from_expression_with_join(self.expression) 364 ] 365 cte_names_in_join = [x.this for x in join_table_identifiers] 366 for ambiguous_col in ambiguous_cols: 367 ctes_with_column = [ 368 cte 369 for cte in self.expression.ctes 370 if cte.alias_or_name in cte_names_in_join 371 and ambiguous_col.alias_or_name in cte.this.named_selects 372 ] 373 # If the select column does not specify a table and there is a join 374 # then we assume they are referring to the left table 375 if len(ctes_with_column) > 1: 376 table_identifier = self.expression.args["from"].args["expressions"][0].this 377 else: 378 table_identifier = ctes_with_column[0].args["alias"].this 379 ambiguous_col.expression.set("table", table_identifier) 380 expression = self.expression.select(*[x.expression for x in cols], **kwargs) 381 qualify_columns(expression, sqlglot.schema) 382 return self.copy(expression=expression, **kwargs)
384 @operation(Operation.NO_OP) 385 def alias(self, name: str, **kwargs) -> DataFrame: 386 new_sequence_id = self.spark._random_sequence_id 387 df = self.copy() 388 for join_hint in df.pending_join_hints: 389 for expression in join_hint.expressions: 390 if expression.alias_or_name == self.sequence_id: 391 expression.set("this", Column.ensure_col(new_sequence_id).expression) 392 df.spark._add_alias_to_mapping(name, new_sequence_id) 393 return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
412 @operation(Operation.FROM) 413 def join( 414 self, 415 other_df: DataFrame, 416 on: t.Union[str, t.List[str], Column, t.List[Column]], 417 how: str = "inner", 418 **kwargs, 419 ) -> DataFrame: 420 other_df = other_df._convert_leaf_to_cte() 421 pre_join_self_latest_cte_name = self.latest_cte_name 422 columns = self._ensure_and_normalize_cols(on) 423 join_type = how.replace("_", " ") 424 if isinstance(columns[0].expression, exp.Column): 425 join_columns = [ 426 Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns 427 ] 428 join_clause = functools.reduce( 429 lambda x, y: x & y, 430 [ 431 col.copy().set_table_name(pre_join_self_latest_cte_name) 432 == col.copy().set_table_name(other_df.latest_cte_name) 433 for col in columns 434 ], 435 ) 436 else: 437 if len(columns) > 1: 438 columns = [functools.reduce(lambda x, y: x & y, columns)] 439 join_clause = columns[0] 440 join_columns = [ 441 Column(x).set_table_name(pre_join_self_latest_cte_name) 442 if i % 2 == 0 443 else Column(x).set_table_name(other_df.latest_cte_name) 444 for i, x in enumerate(join_clause.expression.find_all(exp.Column)) 445 ] 446 self_columns = [ 447 column.set_table_name(pre_join_self_latest_cte_name, copy=True) 448 for column in self._get_outer_select_columns(self) 449 ] 450 other_columns = [ 451 column.set_table_name(other_df.latest_cte_name, copy=True) 452 for column in self._get_outer_select_columns(other_df) 453 ] 454 column_value_mapping = { 455 column.alias_or_name 456 if not isinstance(column.expression.this, exp.Star) 457 else column.sql(): column 458 for column in other_columns + self_columns + join_columns 459 } 460 all_columns = [ 461 column_value_mapping[name] 462 for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns} 463 ] 464 new_df = self.copy( 465 expression=self.expression.join( 466 other_df.latest_cte_name, on=join_clause.expression, join_type=join_type 467 ) 468 ) 469 new_df.expression = new_df._add_ctes_to_expression( 470 new_df.expression, other_df.expression.ctes 471 ) 472 new_df.pending_hints.extend(other_df.pending_hints) 473 new_df = new_df.select.__wrapped__(new_df, *all_columns) 474 return new_df
476 @operation(Operation.ORDER_BY) 477 def orderBy( 478 self, 479 *cols: t.Union[str, Column], 480 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 481 ) -> DataFrame: 482 """ 483 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 484 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 485 is unlikely to come up. 486 """ 487 columns = self._ensure_and_normalize_cols(cols) 488 pre_ordered_col_indexes = [ 489 x 490 for x in [ 491 i if isinstance(col.expression, exp.Ordered) else None 492 for i, col in enumerate(columns) 493 ] 494 if x is not None 495 ] 496 if ascending is None: 497 ascending = [True] * len(columns) 498 elif not isinstance(ascending, list): 499 ascending = [ascending] * len(columns) 500 ascending = [bool(x) for i, x in enumerate(ascending)] 501 assert len(columns) == len( 502 ascending 503 ), "The length of items in ascending must equal the number of columns provided" 504 col_and_ascending = list(zip(columns, ascending)) 505 order_by_columns = [ 506 exp.Ordered(this=col.expression, desc=not asc) 507 if i not in pre_ordered_col_indexes 508 else columns[i].column_expression 509 for i, (col, asc) in enumerate(col_and_ascending) 510 ] 511 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in ascending
. Spark
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
is unlikely to come up.
476 @operation(Operation.ORDER_BY) 477 def orderBy( 478 self, 479 *cols: t.Union[str, Column], 480 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 481 ) -> DataFrame: 482 """ 483 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 484 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 485 is unlikely to come up. 486 """ 487 columns = self._ensure_and_normalize_cols(cols) 488 pre_ordered_col_indexes = [ 489 x 490 for x in [ 491 i if isinstance(col.expression, exp.Ordered) else None 492 for i, col in enumerate(columns) 493 ] 494 if x is not None 495 ] 496 if ascending is None: 497 ascending = [True] * len(columns) 498 elif not isinstance(ascending, list): 499 ascending = [ascending] * len(columns) 500 ascending = [bool(x) for i, x in enumerate(ascending)] 501 assert len(columns) == len( 502 ascending 503 ), "The length of items in ascending must equal the number of columns provided" 504 col_and_ascending = list(zip(columns, ascending)) 505 order_by_columns = [ 506 exp.Ordered(this=col.expression, desc=not asc) 507 if i not in pre_ordered_col_indexes 508 else columns[i].column_expression 509 for i, (col, asc) in enumerate(col_and_ascending) 510 ] 511 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in ascending
. Spark
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
is unlikely to come up.
521 @operation(Operation.FROM) 522 def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): 523 l_columns = self.columns 524 r_columns = other.columns 525 if not allowMissingColumns: 526 l_expressions = l_columns 527 r_expressions = l_columns 528 else: 529 l_expressions = [] 530 r_expressions = [] 531 r_columns_unused = copy(r_columns) 532 for l_column in l_columns: 533 l_expressions.append(l_column) 534 if l_column in r_columns: 535 r_expressions.append(l_column) 536 r_columns_unused.remove(l_column) 537 else: 538 r_expressions.append(exp.alias_(exp.Null(), l_column)) 539 for r_column in r_columns_unused: 540 l_expressions.append(exp.alias_(exp.Null(), r_column)) 541 r_expressions.append(r_column) 542 r_df = ( 543 other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) 544 ) 545 l_df = self.copy() 546 if allowMissingColumns: 547 l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) 548 return l_df._set_operation(exp.Union, r_df, False)
566 @operation(Operation.SELECT) 567 def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): 568 if not subset: 569 return self.distinct() 570 column_names = ensure_list(subset) 571 window = Window.partitionBy(*column_names).orderBy(*column_names) 572 return ( 573 self.copy() 574 .withColumn("row_num", F.row_number().over(window)) 575 .where(F.col("row_num") == F.lit(1)) 576 .drop("row_num") 577 )
579 @operation(Operation.FROM) 580 def dropna( 581 self, 582 how: str = "any", 583 thresh: t.Optional[int] = None, 584 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 585 ) -> DataFrame: 586 minimum_non_null = thresh or 0 # will be determined later if thresh is null 587 new_df = self.copy() 588 all_columns = self._get_outer_select_columns(new_df.expression) 589 if subset: 590 null_check_columns = self._ensure_and_normalize_cols(subset) 591 else: 592 null_check_columns = all_columns 593 if thresh is None: 594 minimum_num_nulls = 1 if how == "any" else len(null_check_columns) 595 else: 596 minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 597 if minimum_num_nulls > len(null_check_columns): 598 raise RuntimeError( 599 f"The minimum num nulls for dropna must be less than or equal to the number of columns. " 600 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" 601 ) 602 if_null_checks = [ 603 F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns 604 ] 605 nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) 606 num_nulls = nulls_added_together.alias("num_nulls") 607 new_df = new_df.select(num_nulls, append=True) 608 filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) 609 final_df = filtered_df.select(*all_columns) 610 return final_df
612 @operation(Operation.FROM) 613 def fillna( 614 self, 615 value: t.Union[ColumnLiterals], 616 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 617 ) -> DataFrame: 618 """ 619 Functionality Difference: If you provide a value to replace a null and that type conflicts 620 with the type of the column then PySpark will just ignore your replacement. 621 This will try to cast them to be the same in some cases. So they won't always match. 622 Best to not mix types so make sure replacement is the same type as the column 623 624 Possibility for improvement: Use `typeof` function to get the type of the column 625 and check if it matches the type of the value provided. If not then make it null. 626 """ 627 from sqlglot.dataframe.sql.functions import lit 628 629 values = None 630 columns = None 631 new_df = self.copy() 632 all_columns = self._get_outer_select_columns(new_df.expression) 633 all_column_mapping = {column.alias_or_name: column for column in all_columns} 634 if isinstance(value, dict): 635 values = list(value.values()) 636 columns = self._ensure_and_normalize_cols(list(value)) 637 if not columns: 638 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 639 if not values: 640 values = [value] * len(columns) 641 value_columns = [lit(value) for value in values] 642 643 null_replacement_mapping = { 644 column.alias_or_name: ( 645 F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) 646 ) 647 for column, value in zip(columns, value_columns) 648 } 649 null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} 650 null_replacement_columns = [ 651 null_replacement_mapping[column.alias_or_name] for column in all_columns 652 ] 653 new_df = new_df.select(*null_replacement_columns) 654 return new_df
Functionality Difference: If you provide a value to replace a null and that type conflicts with the type of the column then PySpark will just ignore your replacement. This will try to cast them to be the same in some cases. So they won't always match. Best to not mix types so make sure replacement is the same type as the column
Possibility for improvement: Use typeof
function to get the type of the column
and check if it matches the type of the value provided. If not then make it null.
656 @operation(Operation.FROM) 657 def replace( 658 self, 659 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 660 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 661 subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, 662 ) -> DataFrame: 663 from sqlglot.dataframe.sql.functions import lit 664 665 old_values = None 666 new_df = self.copy() 667 all_columns = self._get_outer_select_columns(new_df.expression) 668 all_column_mapping = {column.alias_or_name: column for column in all_columns} 669 670 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 671 if isinstance(to_replace, dict): 672 old_values = list(to_replace) 673 new_values = list(to_replace.values()) 674 elif not old_values and isinstance(to_replace, list): 675 assert isinstance(value, list), "value must be a list since the replacements are a list" 676 assert len(to_replace) == len( 677 value 678 ), "the replacements and values must be the same length" 679 old_values = to_replace 680 new_values = value 681 else: 682 old_values = [to_replace] * len(columns) 683 new_values = [value] * len(columns) 684 old_values = [lit(value) for value in old_values] 685 new_values = [lit(value) for value in new_values] 686 687 replacement_mapping = {} 688 for column in columns: 689 expression = Column(None) 690 for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): 691 if i == 0: 692 expression = F.when(column == old_value, new_value) 693 else: 694 expression = expression.when(column == old_value, new_value) # type: ignore 695 replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( 696 column.expression.alias_or_name 697 ) 698 699 replacement_mapping = {**all_column_mapping, **replacement_mapping} 700 replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] 701 new_df = new_df.select(*replacement_columns) 702 return new_df
704 @operation(Operation.SELECT) 705 def withColumn(self, colName: str, col: Column) -> DataFrame: 706 col = self._ensure_and_normalize_col(col) 707 existing_col_names = self.expression.named_selects 708 existing_col_index = ( 709 existing_col_names.index(colName) if colName in existing_col_names else None 710 ) 711 if existing_col_index: 712 expression = self.expression.copy() 713 expression.expressions[existing_col_index] = col.expression 714 return self.copy(expression=expression) 715 return self.copy().select(col.alias(colName), append=True)
717 @operation(Operation.SELECT) 718 def withColumnRenamed(self, existing: str, new: str): 719 expression = self.expression.copy() 720 existing_columns = [ 721 expression 722 for expression in expression.expressions 723 if expression.alias_or_name == existing 724 ] 725 if not existing_columns: 726 raise ValueError("Tried to rename a column that doesn't exist") 727 for existing_column in existing_columns: 728 if isinstance(existing_column, exp.Column): 729 existing_column.replace(exp.alias_(existing_column.copy(), new)) 730 else: 731 existing_column.set("alias", exp.to_identifier(new)) 732 return self.copy(expression=expression)
734 @operation(Operation.SELECT) 735 def drop(self, *cols: t.Union[str, Column]) -> DataFrame: 736 all_columns = self._get_outer_select_columns(self.expression) 737 drop_cols = self._ensure_and_normalize_cols(cols) 738 new_columns = [ 739 col 740 for col in all_columns 741 if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] 742 ] 743 return self.copy().select(*new_columns, append=False)
749 @operation(Operation.NO_OP) 750 def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: 751 parameter_list = ensure_list(parameters) 752 parameter_columns = ( 753 self._ensure_list_of_columns(parameter_list) 754 if parameters 755 else Column.ensure_cols([self.sequence_id]) 756 ) 757 return self._hint(name, parameter_columns)
759 @operation(Operation.NO_OP) 760 def repartition( 761 self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName 762 ) -> DataFrame: 763 num_partition_cols = self._ensure_list_of_columns(numPartitions) 764 columns = self._ensure_and_normalize_cols(cols) 765 args = num_partition_cols + columns 766 return self._hint("repartition", args)
777 @operation(Operation.NO_OP) 778 def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame: 779 """ 780 Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html 781 """ 782 return self._cache(storageLevel)
Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
14class GroupedData: 15 def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation): 16 self._df = df.copy() 17 self.spark = df.spark 18 self.last_op = last_op 19 self.group_by_cols = group_by_cols 20 21 def _get_function_applied_columns( 22 self, func_name: str, cols: t.Tuple[str, ...] 23 ) -> t.List[Column]: 24 func_name = func_name.lower() 25 return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] 26 27 @operation(Operation.SELECT) 28 def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: 29 columns = ( 30 [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] 31 if isinstance(exprs[0], dict) 32 else exprs 33 ) 34 cols = self._df._ensure_and_normalize_cols(columns) 35 36 expression = self._df.expression.group_by( 37 *[x.expression for x in self.group_by_cols] 38 ).select(*[x.expression for x in self.group_by_cols + cols], append=False) 39 return self._df.copy(expression=expression) 40 41 def count(self) -> DataFrame: 42 return self.agg(F.count("*").alias("count")) 43 44 def mean(self, *cols: str) -> DataFrame: 45 return self.avg(*cols) 46 47 def avg(self, *cols: str) -> DataFrame: 48 return self.agg(*self._get_function_applied_columns("avg", cols)) 49 50 def max(self, *cols: str) -> DataFrame: 51 return self.agg(*self._get_function_applied_columns("max", cols)) 52 53 def min(self, *cols: str) -> DataFrame: 54 return self.agg(*self._get_function_applied_columns("min", cols)) 55 56 def sum(self, *cols: str) -> DataFrame: 57 return self.agg(*self._get_function_applied_columns("sum", cols)) 58 59 def pivot(self, *cols: str) -> DataFrame: 60 raise NotImplementedError("Sum distinct is not currently implemented")
class Column:
    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
        if isinstance(expression, Column):
            expression = expression.expression  # type: ignore
        elif expression is None or not isinstance(expression, (str, exp.Expression)):
            expression = self._lit(expression).expression  # type: ignore

        # Keep the original input around so the error message is meaningful on failure
        parsed = sqlglot.maybe_parse(expression, dialect="spark")
        if parsed is None:
            raise ValueError(f"Could not parse {expression}")
        self.expression: exp.Expression = parsed

    def __repr__(self):
        return repr(self.expression)

    def __hash__(self):
        return hash(self.expression)

    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.EQ, other)

    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.NEQ, other)

    def __gt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GT, other)

    def __ge__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GTE, other)

    def __lt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LT, other)

    def __le__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LTE, other)

    def __and__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.And, other)

    def __or__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Or, other)

    def __mod__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mod, other)

    def __add__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Add, other)

    def __sub__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Sub, other)

    def __mul__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mul, other)

    def __truediv__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __div__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __neg__(self) -> Column:
        return self.unary_op(exp.Neg)

    def __radd__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Add, other)

    def __rsub__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Sub, other)

    def __rmul__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mul, other)

    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rmod__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mod, other)

    def __pow__(self, power: ColumnOrLiteral, modulo=None):
        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))

    def __rpow__(self, power: ColumnOrLiteral):
        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))

    def __invert__(self):
        return self.unary_op(exp.Not)

    def __rand__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.And, other)

    def __ror__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Or, other)

    @classmethod
    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
        return cls(value)

    @classmethod
    def ensure_cols(
        cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]
    ) -> t.List[Column]:
        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]

    @classmethod
    def _lit(cls, value: ColumnOrLiteral) -> Column:
        if isinstance(value, dict):
            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
            return cls(exp.Struct(expressions=columns))
        return cls(exp.convert(value))

    @classmethod
    def invoke_anonymous_function(
        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
    ) -> Column:
        columns = [] if column is None else [cls.ensure_col(column)]
        column_args = [cls.ensure_col(arg) for arg in args]
        expressions = [x.expression for x in columns + column_args]
        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
        return Column(new_expression)

    @classmethod
    def invoke_expression_over_column(
        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
    ) -> Column:
        ensured_column = None if column is None else cls.ensure_col(column)
        ensure_expression_values = {
            k: [Column.ensure_col(x).expression for x in v]
            if is_iterable(v)
            else Column.ensure_col(v).expression
            for k, v in kwargs.items()
        }
        new_expression = (
            callable_expression(**ensure_expression_values)
            if ensured_column is None
            else callable_expression(
                this=ensured_column.column_expression, **ensure_expression_values
            )
        )
        return Column(new_expression)

    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
        )

    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
        )

    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, **kwargs))

    @property
    def is_alias(self):
        return isinstance(self.expression, exp.Alias)

    @property
    def is_column(self):
        return isinstance(self.expression, exp.Column)

    @property
    def column_expression(self) -> exp.Expression:
        return self.expression.unalias()

    @property
    def alias_or_name(self) -> str:
        return self.expression.alias_or_name

    @classmethod
    def ensure_literal(cls, value) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        if isinstance(value, cls):
            value = value.expression
        if not isinstance(value, exp.Literal):
            return lit(value)
        return Column(value)

    def copy(self) -> Column:
        return Column(self.expression.copy())

    def set_table_name(self, table_name: str, copy=False) -> Column:
        expression = self.expression.copy() if copy else self.expression
        expression.set("table", exp.to_identifier(table_name))
        return Column(expression)

    def sql(self, **kwargs) -> str:
        return self.expression.sql(**{"dialect": "spark", **kwargs})

    def alias(self, name: str) -> Column:
        new_expression = exp.alias_(self.column_expression, name)
        return Column(new_expression)

    def asc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
        return Column(new_expression)

    def desc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
        return Column(new_expression)

    asc_nulls_first = asc

    def asc_nulls_last(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
        return Column(new_expression)

    def desc_nulls_first(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
        return Column(new_expression)

    desc_nulls_last = desc

    def when(self, condition: Column, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import when

        column_with_if = when(condition, value)
        if not isinstance(self.expression, exp.Case):
            return column_with_if
        new_column = self.copy()
        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
        return new_column

    def otherwise(self, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        true_value = value if isinstance(value, Column) else lit(value)
        new_column = self.copy()
        new_column.expression.set("default", true_value.column_expression)
        return new_column

    def isNull(self) -> Column:
        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
        return Column(new_expression)

    def isNotNull(self) -> Column:
        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
        return Column(new_expression)

    def cast(self, dataType: t.Union[str, DataType]):
        """
        Functionality Difference: PySpark's cast accepts an instance of PySpark's DataType
        class. sqlglot doesn't replicate that class, so this accepts a type string
        (a sqlglot DataType is converted to its string form via simpleString()).
        """
        if isinstance(dataType, DataType):
            dataType = dataType.simpleString()
        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))

    def startswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "STARTSWITH", value)

    def endswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "ENDSWITH", value)

    def rlike(self, regexp: str) -> Column:
        return self.invoke_expression_over_column(
            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
        )

    def like(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.Like, expression=self._lit(other).expression
        )

    def ilike(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.ILike, expression=self._lit(other).expression
        )

    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
        length = self._lit(length) if not isinstance(length, Column) else length
        return Column.invoke_expression_over_column(
            self, exp.Substring, start=startPos.expression, length=length.expression
        )

    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [self._lit(x).expression for x in columns]
        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore

    def between(
        self,
        lowerBound: ColumnOrLiteral,
        upperBound: ColumnOrLiteral,
    ) -> Column:
        lower_bound_exp = (
            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
        )
        upper_bound_exp = (
            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
        )
        return Column(
            exp.Between(
                this=self.column_expression,
                low=lower_bound_exp.expression,
                high=upper_bound_exp.expression,
            )
        )

    def over(self, window: WindowSpec) -> Column:
        window_expression = window.expression.copy()
        window_expression.set("this", self.column_expression)
        return Column(window_expression)
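Because every operator and method returns a new Column wrapping a sqlglot expression, arbitrary expression trees can be built with plain Python operators and rendered with .sql(). A minimal sketch (column names are illustrative; exact output formatting may vary slightly):

from sqlglot.dataframe.sql import functions as F

expr = (F.col("age") + 1).alias("age_plus_one")
print(expr.sql())                             # roughly: age + 1 AS age_plus_one
print(F.col("age").between(21, 65).sql())     # roughly: age BETWEEN 21 AND 65
print(F.col("state").isin("CA", "NY").sql())  # roughly: state IN ('CA', 'NY')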
class DataFrameNaFunctions:
    def __init__(self, df: DataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
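DataFrameNaFunctions only mirrors PySpark's null-handling accessor; each method delegates to the corresponding DataFrame method. A sketch, assuming (as in PySpark) that the DataFrame exposes this class via an na property:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"age": "INT", "fname": "STRING"})
df = SparkSession().table("employee")

# Each pair should generate the same SQL
df.na.fill(0, subset=["age"])  # same as df.fillna(0, subset=["age"])
df.na.drop(how="all")          # same as df.dropna(how="all")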
class Window:
    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    unboundedPreceding: int = _JAVA_MIN_LONG
    unboundedFollowing: int = _JAVA_MAX_LONG
    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rangeBetween(start, end)
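The Window classmethods are conveniences that start from an empty WindowSpec, so Window.partitionBy(...).orderBy(...) and WindowSpec().partitionBy(...).orderBy(...) are interchangeable. For example (column names illustrative, output formatting approximate):

from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window

spec = Window.partitionBy("dept").orderBy("salary")
print(F.sum("salary").over(spec).sql())
# roughly: SUM(salary) OVER (PARTITION BY dept ORDER BY salary)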
class WindowSpec:
    def __init__(self, expression: t.Optional[exp.Expression] = None):
        # Avoid a shared mutable default: build a fresh Window expression per instance
        self.expression = expression if expression is not None else exp.Window()

    def copy(self):
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        return self.expression.sql(dialect="spark", **kwargs)

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
        partition_by_expressions.extend(expressions)
        window_spec.expression.set("partition_by", partition_by_expressions)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order_by = window_spec.expression.args["order"].expressions
        order_by.extend(expressions)
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
            "start_side": None,
            "end_side": None,
        }
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": "UNBOUNDED"
                    if start <= Window.unboundedPreceding
                    else F.lit(start).expression,
                },
            }
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": "UNBOUNDED"
                    if end >= Window.unboundedFollowing
                    else F.lit(end).expression,
                },
            }
        return kwargs

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "ROWS"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "RANGE"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec
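rowsBetween and rangeBetween merge a frame spec into any existing one, and _calc_start_end clamps any bound at or beyond unboundedPreceding/unboundedFollowing to UNBOUNDED. A sketch:

from sqlglot.dataframe.sql.window import Window

spec = (
    Window.partitionBy("dept")
    .orderBy("salary")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
print(spec.sql())
# roughly: PARTITION BY dept ORDER BY salary ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW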
class DataFrameReader:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        sqlglot.schema.add_table(tableName)
        return DataFrame(
            self.spark,
            exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
        )
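table registers the name in sqlglot's global schema if it isn't there already and selects whatever columns the schema knows about, so registering the table with its columns up front is what lets the SELECT expand to real column names. For example:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

# Illustrative table; without this registration the column list would be empty
sqlglot.schema.add_table("db.employee", {"employee_id": "INT", "fname": "STRING"})

df = SparkSession().read.table("db.employee")
print(df.sql()[0])  # projects employee_id and fname from db.employee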
class DataFrameWriter:
    def __init__(
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        return DataFrameWriter(
            **{
                k[1:] if k.startswith("_") else k: v
                for k, v in object_to_dict(self, **kwargs).items()
            }
        )

    def sql(self, **kwargs) -> t.List[str]:
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        output_expression_container = exp.Insert(
            **{
                "this": exp.to_table(tableName),
                "overwrite": overwrite,
            }
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            columns = sqlglot.schema.column_names(tableName, only_visible=True)
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        if format is not None:
            raise NotImplementedError("Providing a format in saveAsTable is not supported")
        exists, replace, mode = None, None, mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)
        if mode == "ignore":
            exists = True
        if mode == "overwrite":
            replace = True
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
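The writer never executes anything: insertInto and saveAsTable only wrap the DataFrame's SELECT in an Insert or Create output container, which sql() then renders. A sketch (table names illustrative, output formatting approximate):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"employee_id": "INT"})
df = SparkSession().table("employee")

print(df.write.insertInto("employee_copy").sql()[0])
# roughly: INSERT INTO employee_copy SELECT employee_id FROM employee

print(df.write.mode("overwrite").saveAsTable("employee_copy").sql()[0])
# roughly: CREATE OR REPLACE TABLE employee_copy AS SELECT ...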