sqlglot.dataframe.sql
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
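The module re-exports the public PySpark-compatible classes. A minimal sketch of the intended flow: register the schema of any table you reference with sqlglot, build a lazy DataFrame, then call `.sql()` to get SQL strings back (nothing is ever executed). The `employee` table below is a hypothetical example.

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Hypothetical table; registering column types lets the optimizer qualify and annotate the query.
sqlglot.schema.add_table(
    "employee",
    {"employee_id": "INT", "fname": "STRING", "lname": "STRING", "age": "INT"},
    dialect="spark",
)

df = (
    SparkSession()
    .table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

print(df.sql(pretty=True)[0])  # .sql() returns a list of SQL strings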
class SparkSession:
    DEFAULT_DIALECT = "spark"
    _instance = None

    def __init__(self):
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()

    def __new__(cls, *args, **kwargs) -> SparkSession:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Schema must be a StructType, a string, or a list of strings")
        if not data:
            raise ValueError("Must provide data to create a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)()
            return spark

    @classproperty
    def builder(cls) -> Builder:
        return cls.Builder()
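`createDataFrame` compiles rows into a SELECT over an inline VALUES clause. Rows may be dicts, lists, or tuples; `schema` may be a StructType, a schema string, or a list of column names, while `samplingRatio` and `verifySchema` raise NotImplementedError. A minimal sketch:

spark = SparkSession()

# List-of-names schema; without a schema, tuple columns default to _1, _2, ...
df = spark.createDataFrame([(1, "Jack"), (2, "Jill")], ["id", "name"])
print(df.sql()[0])  # one SELECT over the inline VALUES

# Assumption based on the schema handling above: a string schema adds a CAST per column.
typed = spark.createDataFrame([(1, "Jack")], "id INT, name STRING")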
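`SparkSession.sql` parses a SQL string back into a DataFrame, so handwritten SQL and DataFrame operations can be mixed; SELECT, CREATE, and INSERT statements are accepted. A sketch, reusing the `employee` schema and `F` import from the earlier example:

df = spark.sql("SELECT age, COUNT(*) AS cnt FROM employee GROUP BY age")
df = df.where(F.col("cnt") > F.lit(1))  # keep composing on top of the parsed query
print(df.sql()[0])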
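`Builder` exists mostly for PySpark API compatibility: `__getattr__` and `__call__` both return `self`, so chained calls such as `.master(...)` or `.appName(...)` are silently swallowed. The one setting it honors is the `sqlframe.dialect` config key, and since `SparkSession` is a singleton the chosen dialect applies globally:

spark = (
    SparkSession.builder
    .appName("ignored")                      # swallowed by __getattr__/__call__
    .config("sqlframe.dialect", "bigquery")  # the only key Builder.config reads
    .getOrCreate()
)
# spark.dialect is now the BigQuery dialect, used for both parsing and generation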
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
        expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(
        self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
    ) -> t.List[str]:
        from sqlglot.dataframe.sql.session import SparkSession

        if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
            logger.warning(
                f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is then no longer needed when calling `sql`. If you run into issues, try updating your query to use this pattern."
            )
        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                quote_identifiers(select_expression)
                select_expression = t.cast(
                    exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
                )
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql(
                            dialect=SparkSession().dialect
                        )
                        for expression in select_expression.expressions
                    },
                    dialect=SparkSession().dialect,
                )
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)
                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [
            expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
            for expression in output_expressions
        ]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
                # of Spark.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
                    # use the same CTE we used before.
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine the actual "join on" expression later, so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and select columns to be used based on what type of columns were
        # provided for the join. The columns returned change based on how the on expression is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of join on column names only:
            * The column names are put at the front of the select list
            * The join column names are deduplicated across the entire select list
              (other duplicates are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each of the left side
            # tables to see if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match Spark behavior, only the join columns get deduplicated and they are put at the front of the column list
            select_column_names = [
                column.alias_or_name
                if not isinstance(column.expression.this, exp.Star)
                else column.sql()
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of join on expressions:
            * There is no deduplication of the results.
            * The left join dataframe columns go first and the right come after.
              No sort preference is given to join columns.
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyway, so this
        is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            x
            for x in [
                i if isinstance(col.expression, exp.Ordered) else None
                for i, col in enumerate(columns)
            ]
            if x is not None
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            exp.Ordered(this=col.expression, desc=not asc)
            if i not in pre_ordered_col_indexes
            else columns[i].column_expression
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
        if minimum_num_nulls > len(null_check_columns):
            raise RuntimeError(
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality Difference: If you provide a value to replace a null and that type conflicts
        with the type of the column, then PySpark will just ignore your replacement.
        This will try to cast them to be the same in some cases, so they won't always match.
        Best not to mix types, so make sure the replacement is the same type as the column.

        Possibility for improvement: Use the `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not, then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
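The usage sketches below reuse the `spark` session, registered `employee` schema, and `F` import from the earlier examples. `DataFrame.sql` is the terminal call that renders the lazy expression tree; it returns a list because cached DataFrames expand into extra statements ahead of the final SELECT, and `optimize=False` skips sqlglot's optimizer and identifier quoting:

df = spark.table("employee").select("employee_id", "age")
for statement in df.sql(pretty=True):
    print(statement)

for statement in spark.table("employee").cache().select("age").sql():
    print(statement)  # DROP VIEW IF EXISTS ..., CACHE LAZY TABLE ..., then the SELECT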
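`select` accepts column names or Column expressions. The extra bookkeeping in its implementation only applies after a join, where unqualified names are assigned to the joined CTEs left to right, matching Spark's tolerance for duplicate column names. A simple sketch:

df = spark.table("employee")
df = df.select("fname", "lname", (F.col("age") + F.lit(1)).alias("age_next_year"))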
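`alias` names a DataFrame so the name can be targeted later (for example by join hints); internally the name is mapped to a fresh sequence id rather than emitted directly in the SQL. A sketch of a self-join disambiguated by aliases:

e1 = spark.table("employee").alias("e1")
e2 = spark.table("employee").alias("e2")
pairs = e1.join(e2, on="employee_id", how="inner")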
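As the inline docstrings in `join` note, the two forms of `on` behave differently: joining on names deduplicates the keys and moves them to the front of the select list, while joining on a Column expression keeps every column from both sides. A sketch, assuming a hypothetical `store` table:

sqlglot.schema.add_table("store", {"store_id": "INT", "employee_id": "INT"}, dialect="spark")

employees = spark.table("employee")
stores = spark.table("store")

by_name = employees.join(stores, on="employee_id", how="inner")  # employee_id emitted once, first
by_expr = employees.join(stores, on=employees.employee_id == stores.employee_id, how="left")  # all columns kept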
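Per the `orderBy` docstring, a column that is already ordered (via `.asc()`/`.desc()`) keeps its direction and wins over the `ascending` flag, which may be a single value or a per-column list. A sketch:

df.orderBy("age", ascending=False)                   # ORDER BY age DESC
df.orderBy(F.col("age").desc(), F.col("fname"))      # the pre-ordered column keeps DESC
df.orderBy("age", "fname", ascending=[False, True])  # per-column flags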
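`unionByName` matches columns by name rather than by position; with `allowMissingColumns=True`, columns missing from one side are filled with NULL. A sketch:

a = spark.createDataFrame([(1, "x")], ["id", "name"])
b = spark.createDataFrame([("y", 2)], ["name", "id"])
a.unionByName(b)  # aligns on names, unlike union/unionAll

c = spark.createDataFrame([(3,)], ["id"])
a.unionByName(c, allowMissingColumns=True)  # c's missing name column becomes NULL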
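Without a subset, `dropDuplicates` is a plain DISTINCT; with one, it is rewritten as a ROW_NUMBER() window over the subset plus a filter on row 1, so which row survives per group is arbitrary. A sketch:

df.dropDuplicates()                # SELECT DISTINCT ...
df.dropDuplicates(subset=["age"])  # one arbitrary row per age, via ROW_NUMBER()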
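In `dropna`, `how` and `thresh` are folded into a computed `num_nulls` column that is filtered on and then dropped again; `thresh` means "keep rows with at least this many non-null values among the checked columns". A sketch:

df.dropna(how="any")                             # drop rows with any NULL
df.dropna(how="all", subset=["fname", "lname"])  # drop rows where both are NULL
df.dropna(thresh=1, subset=["fname", "lname"])   # keep rows with at least one non-NULL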
@operation(Operation.FROM)
def fillna(
    self,
    value: t.Union[ColumnLiterals],
    subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> DataFrame:
    """
    Functionality difference: if the replacement value's type conflicts with the column's
    type, PySpark silently ignores the replacement, whereas this implementation tries to
    cast them to match in some cases, so the results won't always agree. It is best not to
    mix types: make sure the replacement is the same type as the column.

    Possible improvement: use the `typeof` function to get the column's type and check
    whether it matches the type of the value provided; if not, make the value null.
    """
    from sqlglot.dataframe.sql.functions import lit

    values = None
    columns = None
    new_df = self.copy()
    all_columns = self._get_outer_select_columns(new_df.expression)
    all_column_mapping = {column.alias_or_name: column for column in all_columns}
    if isinstance(value, dict):
        values = list(value.values())
        columns = self._ensure_and_normalize_cols(list(value))
    if not columns:
        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
    if not values:
        values = [value] * len(columns)
    value_columns = [lit(value) for value in values]

    null_replacement_mapping = {
        column.alias_or_name: (
            F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
        )
        for column, value in zip(columns, value_columns)
    }
    null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
    null_replacement_columns = [
        null_replacement_mapping[column.alias_or_name] for column in all_columns
    ]
    new_df = new_df.select(*null_replacement_columns)
    return new_df
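A sketch of the dict form, which maps column names to per-column replacements (hypothetical table):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("people", {"name": "STRING", "age": "INT"}, dialect="spark")

# Each mapped column becomes CASE WHEN col IS NULL THEN value ELSE col END.
df = SparkSession().table("people").fillna({"name": "unknown", "age": 0})
print(df.sql(pretty=True)[0])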
@operation(Operation.FROM)
def replace(
    self,
    to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
    value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
    subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
) -> DataFrame:
    from sqlglot.dataframe.sql.functions import lit

    new_df = self.copy()
    all_columns = self._get_outer_select_columns(new_df.expression)
    all_column_mapping = {column.alias_or_name: column for column in all_columns}

    columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
    if isinstance(to_replace, dict):
        old_values = list(to_replace)
        new_values = list(to_replace.values())
    elif isinstance(to_replace, list):
        assert isinstance(value, list), "value must be a list since the replacements are a list"
        assert len(to_replace) == len(
            value
        ), "the replacements and values must be the same length"
        old_values = to_replace
        new_values = value
    else:
        old_values = [to_replace] * len(columns)
        new_values = [value] * len(columns)
    old_values = [lit(value) for value in old_values]
    new_values = [lit(value) for value in new_values]

    replacement_mapping = {}
    for column in columns:
        expression = Column(None)
        for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
            if i == 0:
                expression = F.when(column == old_value, new_value)
            else:
                expression = expression.when(column == old_value, new_value)  # type: ignore
        replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
            column.expression.alias_or_name
        )

    replacement_mapping = {**all_column_mapping, **replacement_mapping}
    replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
    new_df = new_df.select(*replacement_columns)
    return new_df
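A sketch of scalar replacement limited to one column (hypothetical table); each affected column compiles to a CASE WHEN ... THEN ... ELSE ... END:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("people", {"age": "INT", "height": "INT"}, dialect="spark")

# Only age is rewritten: CASE WHEN age = 0 THEN 18 ELSE age END AS age, roughly.
df = SparkSession().table("people").replace(0, 18, subset="age")
print(df.sql(pretty=True)[0])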
@operation(Operation.SELECT)
def withColumn(self, colName: str, col: Column) -> DataFrame:
    col = self._ensure_and_normalize_col(col)
    existing_col_names = self.expression.named_selects
    existing_col_index = (
        existing_col_names.index(colName) if colName in existing_col_names else None
    )
    if existing_col_index is not None:
        expression = self.expression.copy()
        expression.expressions[existing_col_index] = col.expression
        return self.copy(expression=expression)
    return self.copy().select(col.alias(colName), append=True)
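A sketch (hypothetical table); an existing name is replaced in place, a new name is appended to the select list:

import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("people", {"age": "INT"}, dialect="spark")

df = SparkSession().table("people").withColumn("age_plus_one", F.col("age") + 1)
print(df.sql(pretty=True)[0])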
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
    expression = self.expression.copy()
    existing_columns = [
        expression
        for expression in expression.expressions
        if expression.alias_or_name == existing
    ]
    if not existing_columns:
        raise ValueError("Tried to rename a column that doesn't exist")
    for existing_column in existing_columns:
        if isinstance(existing_column, exp.Column):
            existing_column.replace(exp.alias_(existing_column, new))
        else:
            existing_column.set("alias", exp.to_identifier(new))
    return self.copy(expression=expression)
@operation(Operation.SELECT)
def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
    all_columns = self._get_outer_select_columns(self.expression)
    drop_cols = self._ensure_and_normalize_cols(cols)
    new_columns = [
        col
        for col in all_columns
        if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
    ]
    return self.copy().select(*new_columns, append=False)
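Both methods operate on the outer select list only; a sketch chaining them (hypothetical table):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("people", {"fname": "STRING", "age": "INT"}, dialect="spark")

df = SparkSession().table("people").withColumnRenamed("fname", "first_name").drop("age")
print(df.sql(pretty=True)[0])  # SELECT fname AS first_name FROM ..., roughly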
@operation(Operation.NO_OP)
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
    parameter_list = ensure_list(parameters)
    parameter_columns = (
        self._ensure_list_of_columns(parameter_list)
        if parameters
        else Column.ensure_cols([self.sequence_id])
    )
    return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition(
    self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
) -> DataFrame:
    num_partition_cols = self._ensure_list_of_columns(numPartitions)
    columns = self._ensure_and_normalize_cols(cols)
    args = num_partition_cols + columns
    return self._hint("repartition", args)
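Both are carried through as Spark SQL hint comments on the generated SELECT; a sketch (hypothetical table, and the exact hint rendering depends on the generator):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("people", {"age": "INT"}, dialect="spark")

# Should render a hint along the lines of /*+ REPARTITION(8, age) */.
df = SparkSession().table("people").repartition(8, "age")
print(df.sql(pretty=True)[0])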
@operation(Operation.NO_OP)
def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
    """
    Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
    """
    return self._cache(storageLevel)
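Because `sql()` returns a list of statements, persisted intermediate results surface as separate CACHE TABLE statements ahead of the final SELECT (a sketch; the exact statement shape depends on the generator):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("people", {"age": "INT"}, dialect="spark")

df = SparkSession().table("people").persist("MEMORY_ONLY")
for statement in df.sql():
    print(statement)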
class GroupedData:
    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
        self._df = df.copy()
        self.spark = df.spark
        self.last_op = last_op
        self.group_by_cols = group_by_cols

    def _get_function_applied_columns(
        self, func_name: str, cols: t.Tuple[str, ...]
    ) -> t.List[Column]:
        func_name = func_name.lower()
        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]

    @operation(Operation.SELECT)
    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
        columns = (
            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
            if isinstance(exprs[0], dict)
            else exprs
        )
        cols = self._df._ensure_and_normalize_cols(columns)

        expression = self._df.expression.group_by(
            *[x.expression for x in self.group_by_cols]
        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
        return self._df.copy(expression=expression)

    def count(self) -> DataFrame:
        return self.agg(F.count("*").alias("count"))

    def mean(self, *cols: str) -> DataFrame:
        return self.avg(*cols)

    def avg(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("avg", cols))

    def max(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("max", cols))

    def min(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("min", cols))

    def sum(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("sum", cols))

    def pivot(self, *cols: str) -> DataFrame:
        raise NotImplementedError("Pivot is not currently implemented")
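`agg` accepts either Column expressions or the PySpark dict shorthand mapping column names to aggregate function names; a sketch (hypothetical table):

import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"age": "INT", "store_id": "INT"}, dialect="spark")

df = (
    SparkSession()
    .table("employee")
    .groupBy("store_id")
    .agg(F.avg("age").alias("avg_age"), F.count("*").alias("num_employees"))
)
# Equivalent shorthand for the average alone: .agg({"age": "avg"})
print(df.sql(pretty=True)[0])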
class Column:
    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
        from sqlglot.dataframe.sql.session import SparkSession

        if isinstance(expression, Column):
            expression = expression.expression  # type: ignore
        elif expression is None or not isinstance(expression, (str, exp.Expression)):
            expression = self._lit(expression).expression  # type: ignore
        elif not isinstance(expression, exp.Column):
            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
                SparkSession().dialect.normalize_identifier, copy=False
            )
        if expression is None:
            raise ValueError(f"Could not parse {expression}")

        self.expression: exp.Expression = expression  # type: ignore

    def __repr__(self):
        return repr(self.expression)

    def __hash__(self):
        return hash(self.expression)

    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.EQ, other)

    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.NEQ, other)

    def __gt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GT, other)

    def __ge__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GTE, other)

    def __lt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LT, other)

    def __le__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LTE, other)

    def __and__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.And, other)

    def __or__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Or, other)

    def __mod__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mod, other)

    def __add__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Add, other)

    def __sub__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Sub, other)

    def __mul__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mul, other)

    def __truediv__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __div__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __neg__(self) -> Column:
        return self.unary_op(exp.Neg)

    def __radd__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Add, other)

    def __rsub__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Sub, other)

    def __rmul__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mul, other)

    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rmod__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mod, other)

    def __pow__(self, power: ColumnOrLiteral, modulo=None):
        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))

    def __rpow__(self, power: ColumnOrLiteral):
        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))

    def __invert__(self):
        return self.unary_op(exp.Not)

    def __rand__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.And, other)

    def __ror__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Or, other)

    @classmethod
    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
        return cls(value)

    @classmethod
    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]

    @classmethod
    def _lit(cls, value: ColumnOrLiteral) -> Column:
        if isinstance(value, dict):
            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
            return cls(exp.Struct(expressions=columns))
        return cls(exp.convert(value))

    @classmethod
    def invoke_anonymous_function(
        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
    ) -> Column:
        columns = [] if column is None else [cls.ensure_col(column)]
        column_args = [cls.ensure_col(arg) for arg in args]
        expressions = [x.expression for x in columns + column_args]
        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
        return Column(new_expression)

    @classmethod
    def invoke_expression_over_column(
        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
    ) -> Column:
        ensured_column = None if column is None else cls.ensure_col(column)
        ensure_expression_values = {
            k: [Column.ensure_col(x).expression for x in v]
            if is_iterable(v)
            else Column.ensure_col(v).expression
            for k, v in kwargs.items()
            if v is not None
        }
        new_expression = (
            callable_expression(**ensure_expression_values)
            if ensured_column is None
            else callable_expression(
                this=ensured_column.column_expression, **ensure_expression_values
            )
        )
        return Column(new_expression)

    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
        )

    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
        )

    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, **kwargs))

    @property
    def is_alias(self):
        return isinstance(self.expression, exp.Alias)

    @property
    def is_column(self):
        return isinstance(self.expression, exp.Column)

    @property
    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
        return self.expression.unalias()

    @property
    def alias_or_name(self) -> str:
        return self.expression.alias_or_name

    @classmethod
    def ensure_literal(cls, value) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        if isinstance(value, cls):
            value = value.expression
        if not isinstance(value, exp.Literal):
            return lit(value)
        return Column(value)

    def copy(self) -> Column:
        return Column(self.expression.copy())

    def set_table_name(self, table_name: str, copy=False) -> Column:
        expression = self.expression.copy() if copy else self.expression
        expression.set("table", exp.to_identifier(table_name))
        return Column(expression)

    def sql(self, **kwargs) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})

    def alias(self, name: str) -> Column:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = SparkSession().dialect
        alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
        new_expression = exp.alias_(
            self.column_expression,
            alias.this if isinstance(alias, exp.Column) else name,
            dialect=dialect,
        )
        return Column(new_expression)

    def asc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
        return Column(new_expression)

    def desc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
        return Column(new_expression)

    asc_nulls_first = asc

    def asc_nulls_last(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
        return Column(new_expression)

    def desc_nulls_first(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
        return Column(new_expression)

    desc_nulls_last = desc

    def when(self, condition: Column, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import when

        column_with_if = when(condition, value)
        if not isinstance(self.expression, exp.Case):
            return column_with_if
        new_column = self.copy()
        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
        return new_column

    def otherwise(self, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        true_value = value if isinstance(value, Column) else lit(value)
        new_column = self.copy()
        new_column.expression.set("default", true_value.column_expression)
        return new_column

    def isNull(self) -> Column:
        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
        return Column(new_expression)

    def isNotNull(self) -> Column:
        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
        return Column(new_expression)

    def cast(self, dataType: t.Union[str, DataType]) -> Column:
        """
        Functionality difference: PySpark's cast accepts an instance of a DataType class.
        Sqlglot doesn't currently replicate those classes, so this only accepts a string.
        """
        from sqlglot.dataframe.sql.session import SparkSession

        if isinstance(dataType, DataType):
            dataType = dataType.simpleString()
        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))

    def startswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "STARTSWITH", value)

    def endswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "ENDSWITH", value)

    def rlike(self, regexp: str) -> Column:
        return self.invoke_expression_over_column(
            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
        )

    def like(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.Like, expression=self._lit(other).expression
        )

    def ilike(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.ILike, expression=self._lit(other).expression
        )

    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
        length = self._lit(length) if not isinstance(length, Column) else length
        return Column.invoke_expression_over_column(
            self, exp.Substring, start=startPos.expression, length=length.expression
        )

    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [self._lit(x).expression for x in columns]
        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore

    def between(
        self,
        lowerBound: t.Union[ColumnOrLiteral],
        upperBound: t.Union[ColumnOrLiteral],
    ) -> Column:
        lower_bound_exp = (
            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
        )
        upper_bound_exp = (
            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
        )
        return Column(
            exp.Between(
                this=self.column_expression,
                low=lower_bound_exp.expression,
                high=upper_bound_exp.expression,
            )
        )

    def over(self, window: WindowSpec) -> Column:
        window_expression = window.expression.copy()
        window_expression.set("this", self.column_expression)
        return Column(window_expression)
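Columns overload Python's operators to build up sqlglot expression trees, so comparisons, arithmetic, and boolean logic compose directly; a small sketch:

from sqlglot.dataframe.sql import functions as F

expr = (F.col("age") + 1).alias("age_plus_one")
print(expr.sql())  # age + 1 AS age_plus_one, roughly

# & / | / ~ build AND / OR / NOT nodes; parentheses matter due to operator precedence.
cond = (F.col("age") > 21) & F.col("fname").ilike("%a%")
print(cond.sql())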
class DataFrameNaFunctions:
    def __init__(self, df: DataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
class Window:
    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    unboundedPreceding: int = _JAVA_MIN_LONG

    unboundedFollowing: int = _JAVA_MAX_LONG

    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rangeBetween(start, end)
class WindowSpec:
    def __init__(self, expression: exp.Expression = exp.Window()):
        self.expression = expression

    def copy(self):
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
        partition_by_expressions.extend(expressions)
        window_spec.expression.set("partition_by", partition_by_expressions)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order_by = window_spec.expression.args["order"].expressions
        order_by.extend(expressions)
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
            "start_side": None,
            "end_side": None,
        }
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": "UNBOUNDED"
                    if start <= Window.unboundedPreceding
                    else F.lit(start).expression,
                },
            }
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": "UNBOUNDED"
                    if end >= Window.unboundedFollowing
                    else F.lit(end).expression,
                },
            }
        return kwargs

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "ROWS"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "RANGE"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec
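A sketch of a running total built from the two classes (hypothetical table); `rowsBetween` fills in the window spec's frame:

import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window

sqlglot.schema.add_table("sales", {"store_id": "INT", "amount": "INT"}, dialect="spark")

window = (
    Window.partitionBy("store_id")
    .orderBy("amount")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
df = SparkSession().table("sales").select(
    F.col("store_id"),
    F.sum("amount").over(window).alias("running_total"),
)
print(df.sql(pretty=True)[0])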
class DataFrameReader:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame
        from sqlglot.dataframe.sql.session import SparkSession

        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)

        return DataFrame(
            self.spark,
            exp.Select()
            .from_(
                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
                    SparkSession().dialect.normalize_identifier
                )
            )
            .select(
                *sqlglot.schema.column_names(tableName, dialect=SparkSession().dialect)
            ),
        )
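`table` resolves its select list from the global `sqlglot.schema`, so register the table with its columns before reading (hypothetical schema):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table(
    "employee", {"employee_id": "INT", "fname": "STRING"}, dialect="spark"
)

df = SparkSession().read.table("employee")
print(df.sql(pretty=True)[0])  # SELECT employee_id, fname FROM employee, roughly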
class DataFrameWriter:
    def __init__(
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        return DataFrameWriter(
            **{
                k[1:] if k.startswith("_") else k: v
                for k, v in object_to_dict(self, **kwargs).items()
            }
        )

    def sql(self, **kwargs) -> t.List[str]:
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        from sqlglot.dataframe.sql.session import SparkSession

        output_expression_container = exp.Insert(
            **{
                "this": exp.to_table(tableName),
                "overwrite": overwrite,
            }
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            columns = sqlglot.schema.column_names(
                tableName, only_visible=True, dialect=SparkSession().dialect
            )
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        if format is not None:
            raise NotImplementedError("Providing Format in the save as table is not supported")
        exists, replace, mode = None, None, mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)
        if mode == "ignore":
            exists = True
        if mode == "overwrite":
            replace = True
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
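A sketch of the save modes (hypothetical tables, and assuming the DataFrame exposes its writer via `.write`): `append` routes through `insertInto`, `overwrite` becomes CREATE OR REPLACE TABLE, and `ignore` becomes CREATE TABLE IF NOT EXISTS:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"employee_id": "INT"}, dialect="spark")

df = SparkSession().table("employee")
for statement in df.write.mode("overwrite").saveAsTable("employee_copy").sql():
    print(statement)  # CREATE OR REPLACE TABLE employee_copy AS SELECT ..., roughly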