sqlglot.dataframe.sql
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
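The package mirrors the PySpark DataFrame API but only builds sqlglot expressions: nothing executes, and calling DataFrame.sql() returns the equivalent SQL strings. A minimal usage sketch, assuming a hypothetical employee table registered in sqlglot's schema so the optimizer can resolve and type columns:

import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

# Hypothetical table; registering it lets the optimizer qualify and type columns.
sqlglot.schema.add_table("employee", {"id": "INT", "name": "STRING"}, dialect="spark")

spark = SparkSession()
df = spark.table("employee").where(F.col("id") > F.lit(1)).select("name")

# sql() returns a list of SQL strings, one per output statement.
print(df.sql(dialect="spark")[0])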
class SparkSession:
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        return self

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError(
                "Schema must be a StructType, a string, or a list of strings"
            )
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self) -> str:
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)
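A short sketch of the two public entry points above, createDataFrame and sql. The generated CTE names are content hashes, so exact output will vary; the employee table is hypothetical:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()

# Rows given as dicts supply the column names; the values become a VALUES clause.
df = spark.createDataFrame([{"id": 1, "name": "Jack"}, {"id": 2, "name": "Jill"}])
print(df.sql(optimize=False)[0])

# Raw SQL is parsed with the Spark dialect; SELECT, CREATE, and INSERT are accepted.
df2 = spark.sql("SELECT id, name FROM employee")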
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
        expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        value = expression.sql(dialect="spark").encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = t.cast(exp.Select, optimize_func(select_expression))
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql("spark")
                        for expression in select_expression.expressions
                    },
                    dialect="spark",
                )
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)
                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [
            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
        ]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If a column resolves to multiple CTE expressions, use each CTE left-to-right
                # and allow multiple columns with the same name in the result. This matches
                # Spark's behavior.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before.
                    # If so, use it. Otherwise, use the same CTE we used before.
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine the actual "join on" expression later, so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determine the join clause and select columns based on what type of columns were
        # provided for the join. The columns returned change based on how the "on"
        # expression is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of a join on column names only:
            * The join column names are put at the front of the select list
            * The join column names are deduplicated across the entire select list;
              other duplicates are allowed
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each
            # of the left-side tables to see if it has the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match Spark's behavior, only the join clause columns get deduplicated, and
            # they are put at the front of the column list
            select_column_names = [
                column.alias_or_name
                if not isinstance(column.expression.this, exp.Star)
                else column.sql()
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of a join on expressions:
            * There is no deduplication of the results.
            * The left DataFrame's columns go first and the right's come after. No sort
              preference is given to join columns.
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any pre-ordered columns take priority over whatever is
        provided in `ascending`. Spark has irregular behavior here and can raise runtime
        errors. Users shouldn't mix the two anyway, so this is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            x
            for x in [
                i if isinstance(col.expression, exp.Ordered) else None
                for i, col in enumerate(columns)
            ]
            if x is not None
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            exp.Ordered(this=col.expression, desc=not asc)
            if i not in pre_ordered_col_indexes
            else columns[i].column_expression
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
        if minimum_num_nulls > len(null_check_columns):
            raise RuntimeError(
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality difference: if you provide a replacement value whose type conflicts with
        the column's type, PySpark will simply ignore the replacement, while this implementation
        will try to cast the two to the same type in some cases, so the results won't always
        match. It is best not to mix types: make sure the replacement is the same type as the
        column.

        Possibility for improvement: use the `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not, make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage level options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
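A sketch that exercises several of the methods above together. The table names and columns are hypothetical, and the default optimize=True requires them to be registered in sqlglot.schema first:

import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table(
    "employee", {"id": "INT", "name": "STRING", "store_id": "INT"}, dialect="spark"
)
sqlglot.schema.add_table("store", {"store_id": "INT", "city": "STRING"}, dialect="spark")

spark = SparkSession()
employees = spark.table("employee")
stores = spark.table("store")

query = (
    employees.join(stores, on=["store_id"], how="inner")  # store_id is deduplicated in the output
    .where(F.col("city") == F.lit("Seattle"))
    .orderBy("name", ascending=True)
)

# Each transformation layers a new CTE; sql() emits the final statement(s).
for statement in query.sql(dialect="spark"):
    print(statement)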
46 def __init__( 47 self, 48 spark: SparkSession, 49 expression: exp.Select, 50 branch_id: t.Optional[str] = None, 51 sequence_id: t.Optional[str] = None, 52 last_op: Operation = Operation.INIT, 53 pending_hints: t.Optional[t.List[exp.Expression]] = None, 54 output_expression_container: t.Optional[OutputExpressionContainer] = None, 55 **kwargs, 56 ): 57 self.spark = spark 58 self.expression = expression 59 self.branch_id = branch_id or self.spark._random_branch_id 60 self.sequence_id = sequence_id or self.spark._random_sequence_id 61 self.last_op = last_op 62 self.pending_hints = pending_hints or [] 63 self.output_expression_container = output_expression_container or exp.Select()
295 def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: 296 df = self._resolve_pending_hints() 297 select_expressions = df._get_select_expressions() 298 output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] 299 replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} 300 for expression_type, select_expression in select_expressions: 301 select_expression = select_expression.transform(replace_id_value, replacement_mapping) 302 if optimize: 303 select_expression = t.cast(exp.Select, optimize_func(select_expression)) 304 select_expression = df._replace_cte_names_with_hashes(select_expression) 305 expression: t.Union[exp.Select, exp.Cache, exp.Drop] 306 if expression_type == exp.Cache: 307 cache_table_name = df._create_hash_from_expression(select_expression) 308 cache_table = exp.to_table(cache_table_name) 309 original_alias_name = select_expression.args["cte_alias_name"] 310 311 replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore 312 cache_table_name 313 ) 314 sqlglot.schema.add_table( 315 cache_table_name, 316 { 317 expression.alias_or_name: expression.type.sql("spark") 318 for expression in select_expression.expressions 319 }, 320 dialect="spark", 321 ) 322 cache_storage_level = select_expression.args["cache_storage_level"] 323 options = [ 324 exp.Literal.string("storageLevel"), 325 exp.Literal.string(cache_storage_level), 326 ] 327 expression = exp.Cache( 328 this=cache_table, expression=select_expression, lazy=True, options=options 329 ) 330 # We will drop the "view" if it exists before running the cache table 331 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) 332 elif expression_type == exp.Create: 333 expression = df.output_expression_container.copy() 334 expression.set("expression", select_expression) 335 elif expression_type == exp.Insert: 336 expression = df.output_expression_container.copy() 337 select_without_ctes = select_expression.copy() 338 select_without_ctes.set("with", None) 339 expression.set("expression", select_without_ctes) 340 if select_expression.ctes: 341 expression.set("with", exp.With(expressions=select_expression.ctes)) 342 elif expression_type == exp.Select: 343 expression = select_expression 344 else: 345 raise ValueError(f"Invalid expression type: {expression_type}") 346 output_expressions.append(expression) 347 348 return [ 349 expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions 350 ]
355 @operation(Operation.SELECT) 356 def select(self, *cols, **kwargs) -> DataFrame: 357 cols = self._ensure_and_normalize_cols(cols) 358 kwargs["append"] = kwargs.get("append", False) 359 if self.expression.args.get("joins"): 360 ambiguous_cols = [ 361 col 362 for col in cols 363 if isinstance(col.column_expression, exp.Column) and not col.column_expression.table 364 ] 365 if ambiguous_cols: 366 join_table_identifiers = [ 367 x.this for x in get_tables_from_expression_with_join(self.expression) 368 ] 369 cte_names_in_join = [x.this for x in join_table_identifiers] 370 # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right 371 # and therefore we allow multiple columns with the same name in the result. This matches the behavior 372 # of Spark. 373 resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols} 374 for ambiguous_col in ambiguous_cols: 375 ctes_with_column = [ 376 cte 377 for cte in self.expression.ctes 378 if cte.alias_or_name in cte_names_in_join 379 and ambiguous_col.alias_or_name in cte.this.named_selects 380 ] 381 # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise, 382 # use the same CTE we used before 383 cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1) 384 if cte: 385 resolved_column_position[ambiguous_col] += 1 386 else: 387 cte = ctes_with_column[resolved_column_position[ambiguous_col]] 388 ambiguous_col.expression.set("table", cte.alias_or_name) 389 return self.copy( 390 expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs 391 )
393 @operation(Operation.NO_OP) 394 def alias(self, name: str, **kwargs) -> DataFrame: 395 new_sequence_id = self.spark._random_sequence_id 396 df = self.copy() 397 for join_hint in df.pending_join_hints: 398 for expression in join_hint.expressions: 399 if expression.alias_or_name == self.sequence_id: 400 expression.set("this", Column.ensure_col(new_sequence_id).expression) 401 df.spark._add_alias_to_mapping(name, new_sequence_id) 402 return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
421 @operation(Operation.FROM) 422 def join( 423 self, 424 other_df: DataFrame, 425 on: t.Union[str, t.List[str], Column, t.List[Column]], 426 how: str = "inner", 427 **kwargs, 428 ) -> DataFrame: 429 other_df = other_df._convert_leaf_to_cte() 430 join_columns = self._ensure_list_of_columns(on) 431 # We will determine actual "join on" expression later so we don't provide it at first 432 join_expression = self.expression.join( 433 other_df.latest_cte_name, join_type=how.replace("_", " ") 434 ) 435 join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes) 436 self_columns = self._get_outer_select_columns(join_expression) 437 other_columns = self._get_outer_select_columns(other_df) 438 # Determines the join clause and select columns to be used passed on what type of columns were provided for 439 # the join. The columns returned changes based on how the on expression is provided. 440 if isinstance(join_columns[0].expression, exp.Column): 441 """ 442 Unique characteristics of join on column names only: 443 * The column names are put at the front of the select list 444 * The column names are deduplicated across the entire select list and only the column names (other dups are allowed) 445 """ 446 table_names = [ 447 table.alias_or_name 448 for table in get_tables_from_expression_with_join(join_expression) 449 ] 450 potential_ctes = [ 451 cte 452 for cte in join_expression.ctes 453 if cte.alias_or_name in table_names 454 and cte.alias_or_name != other_df.latest_cte_name 455 ] 456 # Determine the table to reference for the left side of the join by checking each of the left side 457 # tables and see if they have the column being referenced. 458 join_column_pairs = [] 459 for join_column in join_columns: 460 num_matching_ctes = 0 461 for cte in potential_ctes: 462 if join_column.alias_or_name in cte.this.named_selects: 463 left_column = join_column.copy().set_table_name(cte.alias_or_name) 464 right_column = join_column.copy().set_table_name(other_df.latest_cte_name) 465 join_column_pairs.append((left_column, right_column)) 466 num_matching_ctes += 1 467 if num_matching_ctes > 1: 468 raise ValueError( 469 f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name." 470 ) 471 elif num_matching_ctes == 0: 472 raise ValueError( 473 f"Column {join_column.alias_or_name} does not exist in any of the tables." 474 ) 475 join_clause = functools.reduce( 476 lambda x, y: x & y, 477 [left_column == right_column for left_column, right_column in join_column_pairs], 478 ) 479 join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] 480 # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list 481 select_column_names = [ 482 column.alias_or_name 483 if not isinstance(column.expression.this, exp.Star) 484 else column.sql() 485 for column in self_columns + other_columns 486 ] 487 select_column_names = [ 488 column_name 489 for column_name in select_column_names 490 if column_name not in join_column_names 491 ] 492 select_column_names = join_column_names + select_column_names 493 else: 494 """ 495 Unique characteristics of join on expressions: 496 * There is no deduplication of the results. 497 * The left join dataframe columns go first and right come after. 
No sort preference is given to join columns 498 """ 499 join_columns = self._ensure_and_normalize_cols(join_columns, join_expression) 500 if len(join_columns) > 1: 501 join_columns = [functools.reduce(lambda x, y: x & y, join_columns)] 502 join_clause = join_columns[0] 503 select_column_names = [column.alias_or_name for column in self_columns + other_columns] 504 505 # Update the on expression with the actual join clause to replace the dummy one from before 506 join_expression.args["joins"][-1].set("on", join_clause.expression) 507 new_df = self.copy(expression=join_expression) 508 new_df.pending_join_hints.extend(self.pending_join_hints) 509 new_df.pending_hints.extend(other_df.pending_hints) 510 new_df = new_df.select.__wrapped__(new_df, *select_column_names) 511 return new_df
513 @operation(Operation.ORDER_BY) 514 def orderBy( 515 self, 516 *cols: t.Union[str, Column], 517 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 518 ) -> DataFrame: 519 """ 520 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 521 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 522 is unlikely to come up. 523 """ 524 columns = self._ensure_and_normalize_cols(cols) 525 pre_ordered_col_indexes = [ 526 x 527 for x in [ 528 i if isinstance(col.expression, exp.Ordered) else None 529 for i, col in enumerate(columns) 530 ] 531 if x is not None 532 ] 533 if ascending is None: 534 ascending = [True] * len(columns) 535 elif not isinstance(ascending, list): 536 ascending = [ascending] * len(columns) 537 ascending = [bool(x) for i, x in enumerate(ascending)] 538 assert len(columns) == len( 539 ascending 540 ), "The length of items in ascending must equal the number of columns provided" 541 col_and_ascending = list(zip(columns, ascending)) 542 order_by_columns = [ 543 exp.Ordered(this=col.expression, desc=not asc) 544 if i not in pre_ordered_col_indexes 545 else columns[i].column_expression 546 for i, (col, asc) in enumerate(col_and_ascending) 547 ] 548 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in ascending
. Spark
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
is unlikely to come up.
513 @operation(Operation.ORDER_BY) 514 def orderBy( 515 self, 516 *cols: t.Union[str, Column], 517 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 518 ) -> DataFrame: 519 """ 520 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 521 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 522 is unlikely to come up. 523 """ 524 columns = self._ensure_and_normalize_cols(cols) 525 pre_ordered_col_indexes = [ 526 x 527 for x in [ 528 i if isinstance(col.expression, exp.Ordered) else None 529 for i, col in enumerate(columns) 530 ] 531 if x is not None 532 ] 533 if ascending is None: 534 ascending = [True] * len(columns) 535 elif not isinstance(ascending, list): 536 ascending = [ascending] * len(columns) 537 ascending = [bool(x) for i, x in enumerate(ascending)] 538 assert len(columns) == len( 539 ascending 540 ), "The length of items in ascending must equal the number of columns provided" 541 col_and_ascending = list(zip(columns, ascending)) 542 order_by_columns = [ 543 exp.Ordered(this=col.expression, desc=not asc) 544 if i not in pre_ordered_col_indexes 545 else columns[i].column_expression 546 for i, (col, asc) in enumerate(col_and_ascending) 547 ] 548 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in ascending
. Spark
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
is unlikely to come up.
558 @operation(Operation.FROM) 559 def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): 560 l_columns = self.columns 561 r_columns = other.columns 562 if not allowMissingColumns: 563 l_expressions = l_columns 564 r_expressions = l_columns 565 else: 566 l_expressions = [] 567 r_expressions = [] 568 r_columns_unused = copy(r_columns) 569 for l_column in l_columns: 570 l_expressions.append(l_column) 571 if l_column in r_columns: 572 r_expressions.append(l_column) 573 r_columns_unused.remove(l_column) 574 else: 575 r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False)) 576 for r_column in r_columns_unused: 577 l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False)) 578 r_expressions.append(r_column) 579 r_df = ( 580 other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) 581 ) 582 l_df = self.copy() 583 if allowMissingColumns: 584 l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) 585 return l_df._set_operation(exp.Union, r_df, False)
603 @operation(Operation.SELECT) 604 def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): 605 if not subset: 606 return self.distinct() 607 column_names = ensure_list(subset) 608 window = Window.partitionBy(*column_names).orderBy(*column_names) 609 return ( 610 self.copy() 611 .withColumn("row_num", F.row_number().over(window)) 612 .where(F.col("row_num") == F.lit(1)) 613 .drop("row_num") 614 )
616 @operation(Operation.FROM) 617 def dropna( 618 self, 619 how: str = "any", 620 thresh: t.Optional[int] = None, 621 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 622 ) -> DataFrame: 623 minimum_non_null = thresh or 0 # will be determined later if thresh is null 624 new_df = self.copy() 625 all_columns = self._get_outer_select_columns(new_df.expression) 626 if subset: 627 null_check_columns = self._ensure_and_normalize_cols(subset) 628 else: 629 null_check_columns = all_columns 630 if thresh is None: 631 minimum_num_nulls = 1 if how == "any" else len(null_check_columns) 632 else: 633 minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 634 if minimum_num_nulls > len(null_check_columns): 635 raise RuntimeError( 636 f"The minimum num nulls for dropna must be less than or equal to the number of columns. " 637 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" 638 ) 639 if_null_checks = [ 640 F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns 641 ] 642 nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) 643 num_nulls = nulls_added_together.alias("num_nulls") 644 new_df = new_df.select(num_nulls, append=True) 645 filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) 646 final_df = filtered_df.select(*all_columns) 647 return final_df
649 @operation(Operation.FROM) 650 def fillna( 651 self, 652 value: t.Union[ColumnLiterals], 653 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 654 ) -> DataFrame: 655 """ 656 Functionality Difference: If you provide a value to replace a null and that type conflicts 657 with the type of the column then PySpark will just ignore your replacement. 658 This will try to cast them to be the same in some cases. So they won't always match. 659 Best to not mix types so make sure replacement is the same type as the column 660 661 Possibility for improvement: Use `typeof` function to get the type of the column 662 and check if it matches the type of the value provided. If not then make it null. 663 """ 664 from sqlglot.dataframe.sql.functions import lit 665 666 values = None 667 columns = None 668 new_df = self.copy() 669 all_columns = self._get_outer_select_columns(new_df.expression) 670 all_column_mapping = {column.alias_or_name: column for column in all_columns} 671 if isinstance(value, dict): 672 values = list(value.values()) 673 columns = self._ensure_and_normalize_cols(list(value)) 674 if not columns: 675 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 676 if not values: 677 values = [value] * len(columns) 678 value_columns = [lit(value) for value in values] 679 680 null_replacement_mapping = { 681 column.alias_or_name: ( 682 F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) 683 ) 684 for column, value in zip(columns, value_columns) 685 } 686 null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} 687 null_replacement_columns = [ 688 null_replacement_mapping[column.alias_or_name] for column in all_columns 689 ] 690 new_df = new_df.select(*null_replacement_columns) 691 return new_df
Functionality Difference: If you provide a value to replace a null and that type conflicts with the type of the column then PySpark will just ignore your replacement. This will try to cast them to be the same in some cases. So they won't always match. Best to not mix types so make sure replacement is the same type as the column
Possibility for improvement: Use typeof
function to get the type of the column
and check if it matches the type of the value provided. If not then make it null.
693 @operation(Operation.FROM) 694 def replace( 695 self, 696 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 697 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 698 subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, 699 ) -> DataFrame: 700 from sqlglot.dataframe.sql.functions import lit 701 702 old_values = None 703 new_df = self.copy() 704 all_columns = self._get_outer_select_columns(new_df.expression) 705 all_column_mapping = {column.alias_or_name: column for column in all_columns} 706 707 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 708 if isinstance(to_replace, dict): 709 old_values = list(to_replace) 710 new_values = list(to_replace.values()) 711 elif not old_values and isinstance(to_replace, list): 712 assert isinstance(value, list), "value must be a list since the replacements are a list" 713 assert len(to_replace) == len( 714 value 715 ), "the replacements and values must be the same length" 716 old_values = to_replace 717 new_values = value 718 else: 719 old_values = [to_replace] * len(columns) 720 new_values = [value] * len(columns) 721 old_values = [lit(value) for value in old_values] 722 new_values = [lit(value) for value in new_values] 723 724 replacement_mapping = {} 725 for column in columns: 726 expression = Column(None) 727 for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): 728 if i == 0: 729 expression = F.when(column == old_value, new_value) 730 else: 731 expression = expression.when(column == old_value, new_value) # type: ignore 732 replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( 733 column.expression.alias_or_name 734 ) 735 736 replacement_mapping = {**all_column_mapping, **replacement_mapping} 737 replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] 738 new_df = new_df.select(*replacement_columns) 739 return new_df
@operation(Operation.SELECT)
def withColumn(self, colName: str, col: Column) -> DataFrame:
    col = self._ensure_and_normalize_col(col)
    existing_col_names = self.expression.named_selects
    existing_col_index = (
        existing_col_names.index(colName) if colName in existing_col_names else None
    )
    # Compare against None explicitly: an index of 0 (the first column) is falsy,
    # so a bare truthiness check would wrongly append instead of replace.
    if existing_col_index is not None:
        expression = self.expression.copy()
        expression.expressions[existing_col_index] = col.expression
        return self.copy(expression=expression)
    return self.copy().select(col.alias(colName), append=True)
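A short sketch of both branches (invented column names):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1,), (2,)], ["id"])

df = df.withColumn("id_doubled", F.col("id") * 2)  # new name: appended to the select
df = df.withColumn("id_doubled", F.col("id") * 3)  # existing name: projection replaced in place
print(df.sql()[0])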
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str) -> DataFrame:
    expression = self.expression.copy()
    existing_columns = [
        column_expression
        for column_expression in expression.expressions
        if column_expression.alias_or_name == existing
    ]
    if not existing_columns:
        raise ValueError("Tried to rename a column that doesn't exist")
    for existing_column in existing_columns:
        if isinstance(existing_column, exp.Column):
            existing_column.replace(exp.alias_(existing_column, new))
        else:
            existing_column.set("alias", exp.to_identifier(new))
    return self.copy(expression=expression)
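For example (invented names; renaming a column that doesn't exist raises ValueError):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, "a")], ["id", "name"])
print(df.withColumnRenamed("id", "employee_id").sql()[0])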
@operation(Operation.SELECT)
def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
    all_columns = self._get_outer_select_columns(self.expression)
    drop_cols = self._ensure_and_normalize_cols(cols)
    drop_col_names = {drop_col.alias_or_name for drop_col in drop_cols}
    new_columns = [col for col in all_columns if col.alias_or_name not in drop_col_names]
    return self.copy().select(*new_columns, append=False)
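For example (invented names):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, "a", True)], ["id", "name", "active"])

# Only "id" survives in the projection.
print(df.drop("name", "active").sql()[0])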
@operation(Operation.NO_OP)
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
    parameter_list = ensure_list(parameters)
    parameter_columns = (
        self._ensure_list_of_columns(parameter_list)
        if parameters
        else Column.ensure_cols([self.sequence_id])
    )
    return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition(
    self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
) -> DataFrame:
    num_partition_cols = self._ensure_list_of_columns(numPartitions)
    columns = self._ensure_and_normalize_cols(cols)
    args = num_partition_cols + columns
    return self._hint("repartition", args)
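repartition (like hint above) ends up as a Spark SQL hint comment on the generated SELECT. A rough sketch (invented names; the exact rendered hint syntax may vary by version):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1,), (2,)], ["id"])

# Roughly: SELECT /*+ REPARTITION(4, id) */ ... in the generated SQL.
print(df.repartition(4, "id").sql()[0])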
@operation(Operation.NO_OP)
def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
    """
    Storage level options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
    """
    return self._cache(storageLevel)
class GroupedData:
    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
        self._df = df.copy()
        self.spark = df.spark
        self.last_op = last_op
        self.group_by_cols = group_by_cols

    def _get_function_applied_columns(
        self, func_name: str, cols: t.Tuple[str, ...]
    ) -> t.List[Column]:
        func_name = func_name.lower()
        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]

    @operation(Operation.SELECT)
    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
        columns = (
            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
            if isinstance(exprs[0], dict)
            else exprs
        )
        cols = self._df._ensure_and_normalize_cols(columns)

        expression = self._df.expression.group_by(
            *[x.expression for x in self.group_by_cols]
        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
        return self._df.copy(expression=expression)

    def count(self) -> DataFrame:
        return self.agg(F.count("*").alias("count"))

    def mean(self, *cols: str) -> DataFrame:
        return self.avg(*cols)

    def avg(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("avg", cols))

    def max(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("max", cols))

    def min(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("min", cols))

    def sum(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("sum", cols))

    def pivot(self, *cols: str) -> DataFrame:
        raise NotImplementedError("Pivot is not currently implemented")
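A minimal sketch of both agg forms (invented data and column names):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 10.0), (1, 20.0), (2, 5.0)], ["group_id", "value"])

# Column-based aggregations.
print(df.groupBy(F.col("group_id")).agg(F.avg("value"), F.max("value")).sql()[0])

# Dict form: keys are column names, values are aggregate function names.
print(df.groupBy("group_id").agg({"value": "sum"}).sql()[0])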
class Column:
    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
        if isinstance(expression, Column):
            expression = expression.expression  # type: ignore
        elif expression is None or not isinstance(expression, (str, exp.Expression)):
            expression = self._lit(expression).expression  # type: ignore

        # maybe_parse returns None on failure, so raise while `expression`
        # still holds the unparseable input (not None).
        parsed = sqlglot.maybe_parse(expression, dialect="spark")
        if parsed is None:
            raise ValueError(f"Could not parse {expression}")
        expression = parsed

        if isinstance(expression, exp.Column):
            expression.transform(Spark.normalize_identifier, copy=False)

        self.expression: exp.Expression = expression

    def __repr__(self):
        return repr(self.expression)

    def __hash__(self):
        return hash(self.expression)

    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.EQ, other)

    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.NEQ, other)

    def __gt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GT, other)

    def __ge__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GTE, other)

    def __lt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LT, other)

    def __le__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LTE, other)

    def __and__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.And, other)

    def __or__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Or, other)

    def __mod__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mod, other)

    def __add__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Add, other)

    def __sub__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Sub, other)

    def __mul__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mul, other)

    def __truediv__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __div__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __neg__(self) -> Column:
        return self.unary_op(exp.Neg)

    def __radd__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Add, other)

    def __rsub__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Sub, other)

    def __rmul__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mul, other)

    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rmod__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mod, other)

    def __pow__(self, power: ColumnOrLiteral, modulo=None) -> Column:
        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))

    def __rpow__(self, power: ColumnOrLiteral) -> Column:
        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))

    def __invert__(self) -> Column:
        return self.unary_op(exp.Not)

    def __rand__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.And, other)

    def __ror__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Or, other)

    @classmethod
    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
        return cls(value)

    @classmethod
    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]

    @classmethod
    def _lit(cls, value: ColumnOrLiteral) -> Column:
        if isinstance(value, dict):
            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
            return cls(exp.Struct(expressions=columns))
        return cls(exp.convert(value))

    @classmethod
    def invoke_anonymous_function(
        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
    ) -> Column:
        columns = [] if column is None else [cls.ensure_col(column)]
        column_args = [cls.ensure_col(arg) for arg in args]
        expressions = [x.expression for x in columns + column_args]
        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
        return Column(new_expression)

    @classmethod
    def invoke_expression_over_column(
        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
    ) -> Column:
        ensured_column = None if column is None else cls.ensure_col(column)
        ensure_expression_values = {
            k: [Column.ensure_col(x).expression for x in v]
            if is_iterable(v)
            else Column.ensure_col(v).expression
            for k, v in kwargs.items()
            if v is not None
        }
        new_expression = (
            callable_expression(**ensure_expression_values)
            if ensured_column is None
            else callable_expression(
                this=ensured_column.column_expression, **ensure_expression_values
            )
        )
        return Column(new_expression)

    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
        )

    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
        )

    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, **kwargs))

    @property
    def is_alias(self) -> bool:
        return isinstance(self.expression, exp.Alias)

    @property
    def is_column(self) -> bool:
        return isinstance(self.expression, exp.Column)

    @property
    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
        return self.expression.unalias()

    @property
    def alias_or_name(self) -> str:
        return self.expression.alias_or_name

    @classmethod
    def ensure_literal(cls, value) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        if isinstance(value, cls):
            value = value.expression
        if not isinstance(value, exp.Literal):
            return lit(value)
        return Column(value)

    def copy(self) -> Column:
        return Column(self.expression.copy())

    def set_table_name(self, table_name: str, copy=False) -> Column:
        expression = self.expression.copy() if copy else self.expression
        expression.set("table", exp.to_identifier(table_name))
        return Column(expression)

    def sql(self, **kwargs) -> str:
        return self.expression.sql(**{"dialect": "spark", **kwargs})

    def alias(self, name: str) -> Column:
        new_expression = exp.alias_(self.column_expression, name)
        return Column(new_expression)

    def asc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
        return Column(new_expression)

    def desc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
        return Column(new_expression)

    asc_nulls_first = asc

    def asc_nulls_last(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
        return Column(new_expression)

    def desc_nulls_first(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
        return Column(new_expression)

    desc_nulls_last = desc

    def when(self, condition: Column, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import when

        column_with_if = when(condition, value)
        if not isinstance(self.expression, exp.Case):
            return column_with_if
        new_column = self.copy()
        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
        return new_column

    def otherwise(self, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        true_value = value if isinstance(value, Column) else lit(value)
        new_column = self.copy()
        new_column.expression.set("default", true_value.column_expression)
        return new_column

    def isNull(self) -> Column:
        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
        return Column(new_expression)

    def isNotNull(self) -> Column:
        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
        return Column(new_expression)

    def cast(self, dataType: t.Union[str, DataType]) -> Column:
        """
        Functionality difference: PySpark's cast accepts instances of PySpark's
        DataType classes. sqlglot doesn't replicate those classes, so this accepts
        a type string (or this package's DataType, converted via simpleString()).
        """
        if isinstance(dataType, DataType):
            dataType = dataType.simpleString()
        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))

    def startswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "STARTSWITH", value)

    def endswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "ENDSWITH", value)

    def rlike(self, regexp: str) -> Column:
        return self.invoke_expression_over_column(
            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
        )

    def like(self, other: str) -> Column:
        return self.invoke_expression_over_column(
            self, exp.Like, expression=self._lit(other).expression
        )

    def ilike(self, other: str) -> Column:
        return self.invoke_expression_over_column(
            self, exp.ILike, expression=self._lit(other).expression
        )

    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
        length = self._lit(length) if not isinstance(length, Column) else length
        return Column.invoke_expression_over_column(
            self, exp.Substring, start=startPos.expression, length=length.expression
        )

    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]) -> Column:
        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [self._lit(x).expression for x in columns]
        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore

    def between(
        self,
        lowerBound: ColumnOrLiteral,
        upperBound: ColumnOrLiteral,
    ) -> Column:
        lower_bound_exp = (
            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
        )
        upper_bound_exp = (
            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
        )
        return Column(
            exp.Between(
                this=self.column_expression,
                low=lower_bound_exp.expression,
                high=upper_bound_exp.expression,
            )
        )

    def over(self, window: WindowSpec) -> Column:
        window_expression = window.expression.copy()
        window_expression.set("this", self.column_expression)
        return Column(window_expression)
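Because the operators simply wrap sqlglot expression nodes, columns can be composed and rendered standalone. A small sketch (invented column names):

from sqlglot.dataframe.sql import functions as F

# Comparison and boolean operators build an expression tree.
working_age = ((F.col("age") >= 18) & (F.col("age") < 65)).alias("working_age")
print(working_age.sql())

# when/otherwise chain into a single CASE expression.
label = F.when(F.col("age") < 18, "minor").when(F.col("age") >= 65, "senior").otherwise("adult")
print(label.sql())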
class DataFrameNaFunctions:
    def __init__(self, df: DataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
class Window:
    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    unboundedPreceding: int = _JAVA_MIN_LONG

    unboundedFollowing: int = _JAVA_MAX_LONG

    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rangeBetween(start, end)
class WindowSpec:
    def __init__(self, expression: t.Optional[exp.Expression] = None):
        # Avoid a mutable default argument: a shared exp.Window() instance
        # could otherwise leak state between WindowSpec objects.
        self.expression = expression if expression is not None else exp.Window()

    def copy(self):
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        return self.expression.sql(dialect="spark", **kwargs)

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
        partition_by_expressions.extend(expressions)
        window_spec.expression.set("partition_by", partition_by_expressions)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order_by = window_spec.expression.args["order"].expressions
        order_by.extend(expressions)
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
            "start_side": None,
            "end_side": None,
        }
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": "UNBOUNDED"
                    if start <= Window.unboundedPreceding
                    else F.lit(start).expression,
                },
            }
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": "UNBOUNDED"
                    if end >= Window.unboundedFollowing
                    else F.lit(end).expression,
                },
            }
        return kwargs

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "ROWS"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "RANGE"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec
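Putting the two classes together, a windowed aggregate might look like this sketch (invented names; this assumes the module's sum function mirrors pyspark.sql.functions.sum, as the rest of the functions module does):

from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window

# Running total per department, ordered by salary.
window = (
    Window.partitionBy("department")
    .orderBy("salary")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
print(F.sum("salary").over(window).alias("running_total").sql())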
class DataFrameReader:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        sqlglot.schema.add_table(tableName, dialect="spark")

        return DataFrame(
            self.spark,
            exp.Select()
            .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
            .select(*sqlglot.schema.column_names(tableName, dialect="spark")),
        )
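table() relies on the global sqlglot schema for column information, so registering the table's columns first lets the reader expand them into an explicit select list. A sketch (invented table and column names):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING"}, dialect="spark")

df = SparkSession().read.table("employee")
print(df.sql()[0])  # roughly: SELECT employee_id, fname FROM employee ...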
class DataFrameWriter:
    def __init__(
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        return DataFrameWriter(
            **{
                k[1:] if k.startswith("_") else k: v
                for k, v in object_to_dict(self, **kwargs).items()
            }
        )

    def sql(self, **kwargs) -> t.List[str]:
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        return self.copy(_mode=saveMode)

    @property
    def byName(self) -> DataFrameWriter:
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        output_expression_container = exp.Insert(
            **{
                "this": exp.to_table(tableName),
                "overwrite": overwrite,
            }
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        if format is not None:
            raise NotImplementedError("Providing a format in saveAsTable is not supported")
        exists, replace, mode = None, None, mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)
        if mode == "ignore":
            exists = True
        if mode == "overwrite":
            replace = True
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
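A sketch of the two write paths (invented table and column names, reusing a registered schema):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING"}, dialect="spark")
spark = SparkSession()
df = spark.table("employee")

# "overwrite" wraps the select in roughly CREATE OR REPLACE TABLE ... ;
# "append" would route through insertInto instead.
print(df.write.mode("overwrite").saveAsTable("employee_copy").sql()[0])

# byName reorders the insert's select to match the target table's visible columns.
print(df.write.byName.insertInto("employee").sql()[0])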