sqlglot.dataframe.sql
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
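For orientation (a minimal sketch, not part of the module source): these exported names are usually paired with the functions submodule, which mirrors pyspark.sql.functions.

from sqlglot.dataframe.sql import SparkSession, DataFrame, Window
from sqlglot.dataframe.sql import functions as F  # PySpark-style function helpers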
class SparkSession:
    DEFAULT_DIALECT = "spark"
    _instance = None

    def __init__(self):
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)

    def __new__(cls, *args, **kwargs) -> SparkSession:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.tuple_(
                *map(
                    lambda x: F.lit(x).expression,
                    row if not isinstance(row, dict) else row.values(),
                )
            )
            for row in data
        ]

        sel_columns = [
            (
                F.col(name).cast(data_type).alias(name).expression
                if data_type is not None
                else F.col(name).expression
            )
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)
            return spark

    @classproperty
    def builder(cls) -> Builder:
        return cls.Builder()
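A brief usage sketch (adapted from sqlglot's documented example; the employee table and its column types are illustrative, not part of this module): register the schema so the optimizer can qualify and type columns, build the query lazily, then render SQL with DataFrame.sql.

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Illustrative schema registration so the optimizer can resolve column types.
sqlglot.schema.add_table(
    "employee",
    {"employee_id": "INT", "fname": "VARCHAR", "lname": "VARCHAR", "age": "INT"},
    dialect="spark",
)

df = (
    SparkSession()
    .table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

print(df.sql(dialect="spark")[0])  # sql() returns a list of SQL strings

Note that Builder.config only inspects the sqlframe.dialect key, so something like SparkSession.builder.config("sqlframe.dialect", "duckdb").getOrCreate() is how a different output dialect would be selected; all other builder calls are accepted and ignored.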
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
        expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)

        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}

        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                quote_identifiers(select_expression, dialect=dialect)
                select_expression = t.cast(
                    exp.Select, optimize_func(select_expression, dialect=dialect)
                )

            select_expression = df._replace_cte_names_with_hashes(select_expression)

            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql(dialect=dialect)
                        for expression in select_expression.expressions
                    },
                    dialect=dialect,
                )

                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )

                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)

                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")

            output_expressions.append(expression)

        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
                # of Spark.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
                    # use the same CTE we used before.
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine the actual "join on" expression later, so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and select columns to be used based on what type of columns were provided for
        # the join. The columns returned change based on how the "on" expression is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of join on column names only:
            * The column names are put at the front of the select list
            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each of the left side
            # tables and seeing if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match Spark behavior, only the join clause gets deduplicated and it gets put at the front of the column list
            select_column_names = [
                (
                    column.alias_or_name
                    if not isinstance(column.expression.this, exp.Star)
                    else column.sql()
                )
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of join on expressions:
            * There is no deduplication of the results.
            * The left join dataframe columns go first and right come after. No sort preference is given to join columns.
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
        has irregular behavior here and can result in runtime errors. Users shouldn't be mixing the two anyway, so
        this is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            (
                exp.Ordered(this=col.expression, desc=not asc)
                if i not in pre_ordered_col_indexes
                else columns[i].column_expression
            )
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
        if minimum_num_nulls > len(null_check_columns):
            raise RuntimeError(
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality Difference: If you provide a value to replace a null and that type conflicts
        with the type of the column, then PySpark will just ignore your replacement.
        This will try to cast them to be the same in some cases, so the results won't always match.
        Best not to mix types; make sure the replacement is the same type as the column.

        Possibility for improvement: Use the `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not, then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
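As a small, hedged end-to-end sketch (the column names and data are illustrative): each operation above copies the DataFrame and extends its underlying exp.Select, and sql() renders the accumulated expression as one or more statements.

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
# VALUES-backed frame; the _ensure_and_normalize_* helpers resolve the column references.
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])

query = (
    df.where(F.col("id") > F.lit(1))    # WHERE id > 1
    .withColumnRenamed("label", "tag")  # rename in the select list
    .orderBy("id", ascending=False)     # ORDER BY id DESC
    .limit(10)
)

print(query.sql(dialect="spark")[0])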
50 def __init__( 51 self, 52 spark: SparkSession, 53 expression: exp.Select, 54 branch_id: t.Optional[str] = None, 55 sequence_id: t.Optional[str] = None, 56 last_op: Operation = Operation.INIT, 57 pending_hints: t.Optional[t.List[exp.Expression]] = None, 58 output_expression_container: t.Optional[OutputExpressionContainer] = None, 59 **kwargs, 60 ): 61 self.spark = spark 62 self.expression = expression 63 self.branch_id = branch_id or self.spark._random_branch_id 64 self.sequence_id = sequence_id or self.spark._random_sequence_id 65 self.last_op = last_op 66 self.pending_hints = pending_hints or [] 67 self.output_expression_container = output_expression_container or exp.Select()
87 @property 88 def latest_cte_name(self) -> str: 89 if not self.expression.ctes: 90 from_exp = self.expression.args["from"] 91 if from_exp.alias_or_name: 92 return from_exp.alias_or_name 93 table_alias = from_exp.find(exp.TableAlias) 94 if not table_alias: 95 raise RuntimeError( 96 f"Could not find an alias name for this expression: {self.expression}" 97 ) 98 return table_alias.alias_or_name 99 return self.expression.ctes[-1].alias
301 def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]: 302 from sqlglot.dataframe.sql.session import SparkSession 303 304 dialect = Dialect.get_or_raise(dialect or SparkSession().dialect) 305 306 df = self._resolve_pending_hints() 307 select_expressions = df._get_select_expressions() 308 output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] 309 replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} 310 311 for expression_type, select_expression in select_expressions: 312 select_expression = select_expression.transform(replace_id_value, replacement_mapping) 313 if optimize: 314 quote_identifiers(select_expression, dialect=dialect) 315 select_expression = t.cast( 316 exp.Select, optimize_func(select_expression, dialect=dialect) 317 ) 318 319 select_expression = df._replace_cte_names_with_hashes(select_expression) 320 321 expression: t.Union[exp.Select, exp.Cache, exp.Drop] 322 if expression_type == exp.Cache: 323 cache_table_name = df._create_hash_from_expression(select_expression) 324 cache_table = exp.to_table(cache_table_name) 325 original_alias_name = select_expression.args["cte_alias_name"] 326 327 replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore 328 cache_table_name 329 ) 330 sqlglot.schema.add_table( 331 cache_table_name, 332 { 333 expression.alias_or_name: expression.type.sql(dialect=dialect) 334 for expression in select_expression.expressions 335 }, 336 dialect=dialect, 337 ) 338 339 cache_storage_level = select_expression.args["cache_storage_level"] 340 options = [ 341 exp.Literal.string("storageLevel"), 342 exp.Literal.string(cache_storage_level), 343 ] 344 expression = exp.Cache( 345 this=cache_table, expression=select_expression, lazy=True, options=options 346 ) 347 348 # We will drop the "view" if it exists before running the cache table 349 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) 350 elif expression_type == exp.Create: 351 expression = df.output_expression_container.copy() 352 expression.set("expression", select_expression) 353 elif expression_type == exp.Insert: 354 expression = df.output_expression_container.copy() 355 select_without_ctes = select_expression.copy() 356 select_without_ctes.set("with", None) 357 expression.set("expression", select_without_ctes) 358 359 if select_expression.ctes: 360 expression.set("with", exp.With(expressions=select_expression.ctes)) 361 elif expression_type == exp.Select: 362 expression = select_expression 363 else: 364 raise ValueError(f"Invalid expression type: {expression_type}") 365 366 output_expressions.append(expression) 367 368 return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]
373 @operation(Operation.SELECT) 374 def select(self, *cols, **kwargs) -> DataFrame: 375 cols = self._ensure_and_normalize_cols(cols) 376 kwargs["append"] = kwargs.get("append", False) 377 if self.expression.args.get("joins"): 378 ambiguous_cols = [ 379 col 380 for col in cols 381 if isinstance(col.column_expression, exp.Column) and not col.column_expression.table 382 ] 383 if ambiguous_cols: 384 join_table_identifiers = [ 385 x.this for x in get_tables_from_expression_with_join(self.expression) 386 ] 387 cte_names_in_join = [x.this for x in join_table_identifiers] 388 # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right 389 # and therefore we allow multiple columns with the same name in the result. This matches the behavior 390 # of Spark. 391 resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols} 392 for ambiguous_col in ambiguous_cols: 393 ctes_with_column = [ 394 cte 395 for cte in self.expression.ctes 396 if cte.alias_or_name in cte_names_in_join 397 and ambiguous_col.alias_or_name in cte.this.named_selects 398 ] 399 # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise, 400 # use the same CTE we used before 401 cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1) 402 if cte: 403 resolved_column_position[ambiguous_col] += 1 404 else: 405 cte = ctes_with_column[resolved_column_position[ambiguous_col]] 406 ambiguous_col.expression.set("table", cte.alias_or_name) 407 return self.copy( 408 expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs 409 )
411 @operation(Operation.NO_OP) 412 def alias(self, name: str, **kwargs) -> DataFrame: 413 new_sequence_id = self.spark._random_sequence_id 414 df = self.copy() 415 for join_hint in df.pending_join_hints: 416 for expression in join_hint.expressions: 417 if expression.alias_or_name == self.sequence_id: 418 expression.set("this", Column.ensure_col(new_sequence_id).expression) 419 df.spark._add_alias_to_mapping(name, new_sequence_id) 420 return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
439 @operation(Operation.FROM) 440 def join( 441 self, 442 other_df: DataFrame, 443 on: t.Union[str, t.List[str], Column, t.List[Column]], 444 how: str = "inner", 445 **kwargs, 446 ) -> DataFrame: 447 other_df = other_df._convert_leaf_to_cte() 448 join_columns = self._ensure_list_of_columns(on) 449 # We will determine actual "join on" expression later so we don't provide it at first 450 join_expression = self.expression.join( 451 other_df.latest_cte_name, join_type=how.replace("_", " ") 452 ) 453 join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes) 454 self_columns = self._get_outer_select_columns(join_expression) 455 other_columns = self._get_outer_select_columns(other_df) 456 # Determines the join clause and select columns to be used passed on what type of columns were provided for 457 # the join. The columns returned changes based on how the on expression is provided. 458 if isinstance(join_columns[0].expression, exp.Column): 459 """ 460 Unique characteristics of join on column names only: 461 * The column names are put at the front of the select list 462 * The column names are deduplicated across the entire select list and only the column names (other dups are allowed) 463 """ 464 table_names = [ 465 table.alias_or_name 466 for table in get_tables_from_expression_with_join(join_expression) 467 ] 468 potential_ctes = [ 469 cte 470 for cte in join_expression.ctes 471 if cte.alias_or_name in table_names 472 and cte.alias_or_name != other_df.latest_cte_name 473 ] 474 # Determine the table to reference for the left side of the join by checking each of the left side 475 # tables and see if they have the column being referenced. 476 join_column_pairs = [] 477 for join_column in join_columns: 478 num_matching_ctes = 0 479 for cte in potential_ctes: 480 if join_column.alias_or_name in cte.this.named_selects: 481 left_column = join_column.copy().set_table_name(cte.alias_or_name) 482 right_column = join_column.copy().set_table_name(other_df.latest_cte_name) 483 join_column_pairs.append((left_column, right_column)) 484 num_matching_ctes += 1 485 if num_matching_ctes > 1: 486 raise ValueError( 487 f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name." 488 ) 489 elif num_matching_ctes == 0: 490 raise ValueError( 491 f"Column {join_column.alias_or_name} does not exist in any of the tables." 492 ) 493 join_clause = functools.reduce( 494 lambda x, y: x & y, 495 [left_column == right_column for left_column, right_column in join_column_pairs], 496 ) 497 join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] 498 # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list 499 select_column_names = [ 500 ( 501 column.alias_or_name 502 if not isinstance(column.expression.this, exp.Star) 503 else column.sql() 504 ) 505 for column in self_columns + other_columns 506 ] 507 select_column_names = [ 508 column_name 509 for column_name in select_column_names 510 if column_name not in join_column_names 511 ] 512 select_column_names = join_column_names + select_column_names 513 else: 514 """ 515 Unique characteristics of join on expressions: 516 * There is no deduplication of the results. 517 * The left join dataframe columns go first and right come after. 
No sort preference is given to join columns 518 """ 519 join_columns = self._ensure_and_normalize_cols(join_columns, join_expression) 520 if len(join_columns) > 1: 521 join_columns = [functools.reduce(lambda x, y: x & y, join_columns)] 522 join_clause = join_columns[0] 523 select_column_names = [column.alias_or_name for column in self_columns + other_columns] 524 525 # Update the on expression with the actual join clause to replace the dummy one from before 526 join_expression.args["joins"][-1].set("on", join_clause.expression) 527 new_df = self.copy(expression=join_expression) 528 new_df.pending_join_hints.extend(self.pending_join_hints) 529 new_df.pending_hints.extend(other_df.pending_hints) 530 new_df = new_df.select.__wrapped__(new_df, *select_column_names) 531 return new_df
533 @operation(Operation.ORDER_BY) 534 def orderBy( 535 self, 536 *cols: t.Union[str, Column], 537 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 538 ) -> DataFrame: 539 """ 540 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 541 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 542 is unlikely to come up. 543 """ 544 columns = self._ensure_and_normalize_cols(cols) 545 pre_ordered_col_indexes = [ 546 i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered) 547 ] 548 if ascending is None: 549 ascending = [True] * len(columns) 550 elif not isinstance(ascending, list): 551 ascending = [ascending] * len(columns) 552 ascending = [bool(x) for i, x in enumerate(ascending)] 553 assert len(columns) == len( 554 ascending 555 ), "The length of items in ascending must equal the number of columns provided" 556 col_and_ascending = list(zip(columns, ascending)) 557 order_by_columns = [ 558 ( 559 exp.Ordered(this=col.expression, desc=not asc) 560 if i not in pre_ordered_col_indexes 561 else columns[i].column_expression 562 ) 563 for i, (col, asc) in enumerate(col_and_ascending) 564 ] 565 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in ascending
. Spark
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
is unlikely to come up.
533 @operation(Operation.ORDER_BY) 534 def orderBy( 535 self, 536 *cols: t.Union[str, Column], 537 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 538 ) -> DataFrame: 539 """ 540 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 541 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 542 is unlikely to come up. 543 """ 544 columns = self._ensure_and_normalize_cols(cols) 545 pre_ordered_col_indexes = [ 546 i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered) 547 ] 548 if ascending is None: 549 ascending = [True] * len(columns) 550 elif not isinstance(ascending, list): 551 ascending = [ascending] * len(columns) 552 ascending = [bool(x) for i, x in enumerate(ascending)] 553 assert len(columns) == len( 554 ascending 555 ), "The length of items in ascending must equal the number of columns provided" 556 col_and_ascending = list(zip(columns, ascending)) 557 order_by_columns = [ 558 ( 559 exp.Ordered(this=col.expression, desc=not asc) 560 if i not in pre_ordered_col_indexes 561 else columns[i].column_expression 562 ) 563 for i, (col, asc) in enumerate(col_and_ascending) 564 ] 565 return self.copy(expression=self.expression.order_by(*order_by_columns))
This implementation lets any ordered columns take priority over whatever is provided in ascending
. Spark
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
is unlikely to come up.
575 @operation(Operation.FROM) 576 def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): 577 l_columns = self.columns 578 r_columns = other.columns 579 if not allowMissingColumns: 580 l_expressions = l_columns 581 r_expressions = l_columns 582 else: 583 l_expressions = [] 584 r_expressions = [] 585 r_columns_unused = copy(r_columns) 586 for l_column in l_columns: 587 l_expressions.append(l_column) 588 if l_column in r_columns: 589 r_expressions.append(l_column) 590 r_columns_unused.remove(l_column) 591 else: 592 r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False)) 593 for r_column in r_columns_unused: 594 l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False)) 595 r_expressions.append(r_column) 596 r_df = ( 597 other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) 598 ) 599 l_df = self.copy() 600 if allowMissingColumns: 601 l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) 602 return l_df._set_operation(exp.Union, r_df, False)
620 @operation(Operation.SELECT) 621 def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): 622 if not subset: 623 return self.distinct() 624 column_names = ensure_list(subset) 625 window = Window.partitionBy(*column_names).orderBy(*column_names) 626 return ( 627 self.copy() 628 .withColumn("row_num", F.row_number().over(window)) 629 .where(F.col("row_num") == F.lit(1)) 630 .drop("row_num") 631 )
633 @operation(Operation.FROM) 634 def dropna( 635 self, 636 how: str = "any", 637 thresh: t.Optional[int] = None, 638 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 639 ) -> DataFrame: 640 minimum_non_null = thresh or 0 # will be determined later if thresh is null 641 new_df = self.copy() 642 all_columns = self._get_outer_select_columns(new_df.expression) 643 if subset: 644 null_check_columns = self._ensure_and_normalize_cols(subset) 645 else: 646 null_check_columns = all_columns 647 if thresh is None: 648 minimum_num_nulls = 1 if how == "any" else len(null_check_columns) 649 else: 650 minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 651 if minimum_num_nulls > len(null_check_columns): 652 raise RuntimeError( 653 f"The minimum num nulls for dropna must be less than or equal to the number of columns. " 654 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" 655 ) 656 if_null_checks = [ 657 F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns 658 ] 659 nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) 660 num_nulls = nulls_added_together.alias("num_nulls") 661 new_df = new_df.select(num_nulls, append=True) 662 filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) 663 final_df = filtered_df.select(*all_columns) 664 return final_df
666 @operation(Operation.FROM) 667 def fillna( 668 self, 669 value: t.Union[ColumnLiterals], 670 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 671 ) -> DataFrame: 672 """ 673 Functionality Difference: If you provide a value to replace a null and that type conflicts 674 with the type of the column then PySpark will just ignore your replacement. 675 This will try to cast them to be the same in some cases. So they won't always match. 676 Best to not mix types so make sure replacement is the same type as the column 677 678 Possibility for improvement: Use `typeof` function to get the type of the column 679 and check if it matches the type of the value provided. If not then make it null. 680 """ 681 from sqlglot.dataframe.sql.functions import lit 682 683 values = None 684 columns = None 685 new_df = self.copy() 686 all_columns = self._get_outer_select_columns(new_df.expression) 687 all_column_mapping = {column.alias_or_name: column for column in all_columns} 688 if isinstance(value, dict): 689 values = list(value.values()) 690 columns = self._ensure_and_normalize_cols(list(value)) 691 if not columns: 692 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 693 if not values: 694 values = [value] * len(columns) 695 value_columns = [lit(value) for value in values] 696 697 null_replacement_mapping = { 698 column.alias_or_name: ( 699 F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) 700 ) 701 for column, value in zip(columns, value_columns) 702 } 703 null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} 704 null_replacement_columns = [ 705 null_replacement_mapping[column.alias_or_name] for column in all_columns 706 ] 707 new_df = new_df.select(*null_replacement_columns) 708 return new_df
    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df
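Each targeted column becomes one chained CASE expression, with the remaining columns passed through untouched. A sketch (table and values are hypothetical):

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession

    sqlglot.schema.add_table("users", {"state": "string", "name": "string"}, dialect="spark")

    # Only the state column is rewritten; name is selected as-is.
    df = SparkSession().table("users").replace({"CA": "California", "NY": "New York"}, subset=["state"])
    print(df.sql(pretty=True)[0])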
    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        # Compare against None explicitly: an index of 0 is falsy but still means the column exists.
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)
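So withColumn replaces the projection in place when the name already exists and otherwise appends an aliased expression. A sketch with hypothetical columns:

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession, functions as F

    sqlglot.schema.add_table("orders", {"price": "double", "qty": "int"}, dialect="spark")

    # Appends price * qty AS total to the select list.
    df = SparkSession().table("orders").withColumn("total", F.col("price") * F.col("qty"))
    print(df.sql(pretty=True)[0])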
    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)
    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)
    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)
    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)
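Since nothing executes here, repartition (like hint) can only annotate the generated SQL. A sketch; the exact hint comment formatting depends on sqlglot's generator:

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession

    sqlglot.schema.add_table("logs", {"day": "date", "msg": "string"}, dialect="spark")

    # Expect something like SELECT /*+ REPARTITION(8, day) */ ... in the output.
    df = SparkSession().table("logs").repartition(8, "day")
    print(df.sql(pretty=True)[0])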
    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
class GroupedData:
    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
        self._df = df.copy()
        self.spark = df.spark
        self.last_op = last_op
        self.group_by_cols = group_by_cols

    def _get_function_applied_columns(
        self, func_name: str, cols: t.Tuple[str, ...]
    ) -> t.List[Column]:
        func_name = func_name.lower()
        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]

    @operation(Operation.SELECT)
    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
        columns = (
            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
            if isinstance(exprs[0], dict)
            else exprs
        )
        cols = self._df._ensure_and_normalize_cols(columns)

        expression = self._df.expression.group_by(
            *[x.expression for x in self.group_by_cols]
        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
        return self._df.copy(expression=expression)

    def count(self) -> DataFrame:
        return self.agg(F.count("*").alias("count"))

    def mean(self, *cols: str) -> DataFrame:
        return self.avg(*cols)

    def avg(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("avg", cols))

    def max(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("max", cols))

    def min(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("min", cols))

    def sum(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("sum", cols))

    def pivot(self, *cols: str) -> DataFrame:
        raise NotImplementedError("Pivot is not currently implemented")
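agg accepts either Column expressions or the PySpark-style dict of column name to aggregate function name; both normalize to the same GROUP BY select. A sketch with a hypothetical table:

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession, functions as F

    sqlglot.schema.add_table("sales", {"region": "string", "amount": "double"}, dialect="spark")

    df = (
        SparkSession()
        .table("sales")
        .groupBy(F.col("region"))
        .agg(F.sum(F.col("amount")).alias("total_amount"))
    )
    print(df.sql(pretty=True)[0])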
class Column:
    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
        from sqlglot.dataframe.sql.session import SparkSession

        if isinstance(expression, Column):
            expression = expression.expression  # type: ignore
        elif expression is None or not isinstance(expression, (str, exp.Expression)):
            expression = self._lit(expression).expression  # type: ignore
        elif not isinstance(expression, exp.Column):
            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
                SparkSession().dialect.normalize_identifier, copy=False
            )
        if expression is None:
            raise ValueError(f"Could not parse {expression}")

        self.expression: exp.Expression = expression  # type: ignore

    def __repr__(self):
        return repr(self.expression)

    def __hash__(self):
        return hash(self.expression)

    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.EQ, other)

    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.NEQ, other)

    def __gt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GT, other)

    def __ge__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GTE, other)

    def __lt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LT, other)

    def __le__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LTE, other)

    def __and__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.And, other)

    def __or__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Or, other)

    def __mod__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mod, other)

    def __add__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Add, other)

    def __sub__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Sub, other)

    def __mul__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mul, other)

    def __truediv__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __div__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __neg__(self) -> Column:
        return self.unary_op(exp.Neg)

    def __radd__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Add, other)

    def __rsub__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Sub, other)

    def __rmul__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mul, other)

    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rmod__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mod, other)

    def __pow__(self, power: ColumnOrLiteral, modulo=None):
        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))

    def __rpow__(self, power: ColumnOrLiteral):
        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))

    def __invert__(self):
        return self.unary_op(exp.Not)

    def __rand__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.And, other)

    def __ror__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Or, other)

    @classmethod
    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
        return cls(value)

    @classmethod
    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]

    @classmethod
    def _lit(cls, value: ColumnOrLiteral) -> Column:
        if isinstance(value, dict):
            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
            return cls(exp.Struct(expressions=columns))
        return cls(exp.convert(value))

    @classmethod
    def invoke_anonymous_function(
        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
    ) -> Column:
        columns = [] if column is None else [cls.ensure_col(column)]
        column_args = [cls.ensure_col(arg) for arg in args]
        expressions = [x.expression for x in columns + column_args]
        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
        return Column(new_expression)

    @classmethod
    def invoke_expression_over_column(
        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
    ) -> Column:
        ensured_column = None if column is None else cls.ensure_col(column)
        ensure_expression_values = {
            k: (
                [Column.ensure_col(x).expression for x in v]
                if is_iterable(v)
                else Column.ensure_col(v).expression
            )
            for k, v in kwargs.items()
            if v is not None
        }
        new_expression = (
            callable_expression(**ensure_expression_values)
            if ensured_column is None
            else callable_expression(
                this=ensured_column.column_expression, **ensure_expression_values
            )
        )
        return Column(new_expression)

    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
        )

    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
        )

    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, **kwargs))

    @property
    def is_alias(self):
        return isinstance(self.expression, exp.Alias)

    @property
    def is_column(self):
        return isinstance(self.expression, exp.Column)

    @property
    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
        return self.expression.unalias()

    @property
    def alias_or_name(self) -> str:
        return self.expression.alias_or_name

    @classmethod
    def ensure_literal(cls, value) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        if isinstance(value, cls):
            value = value.expression
        if not isinstance(value, exp.Literal):
            return lit(value)
        return Column(value)

    def copy(self) -> Column:
        return Column(self.expression.copy())

    def set_table_name(self, table_name: str, copy=False) -> Column:
        expression = self.expression.copy() if copy else self.expression
        expression.set("table", exp.to_identifier(table_name))
        return Column(expression)

    def sql(self, **kwargs) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})

    def alias(self, name: str) -> Column:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = SparkSession().dialect
        alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
        new_expression = exp.alias_(
            self.column_expression,
            alias.this if isinstance(alias, exp.Column) else name,
            dialect=dialect,
        )
        return Column(new_expression)

    def asc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
        return Column(new_expression)

    def desc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
        return Column(new_expression)

    asc_nulls_first = asc

    def asc_nulls_last(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
        return Column(new_expression)

    def desc_nulls_first(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
        return Column(new_expression)

    desc_nulls_last = desc

    def when(self, condition: Column, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import when

        column_with_if = when(condition, value)
        if not isinstance(self.expression, exp.Case):
            return column_with_if
        new_column = self.copy()
        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
        return new_column

    def otherwise(self, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        true_value = value if isinstance(value, Column) else lit(value)
        new_column = self.copy()
        new_column.expression.set("default", true_value.column_expression)
        return new_column

    def isNull(self) -> Column:
        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
        return Column(new_expression)

    def isNotNull(self) -> Column:
        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
        return Column(new_expression)

    def cast(self, dataType: t.Union[str, DataType]) -> Column:
        """
        Functionality Difference: PySpark's cast accepts an instance of a DataType class.
        sqlglot does not currently replicate those classes, so this only accepts a string.
        """
        from sqlglot.dataframe.sql.session import SparkSession

        if isinstance(dataType, DataType):
            dataType = dataType.simpleString()
        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))

    def startswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "STARTSWITH", value)

    def endswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "ENDSWITH", value)

    def rlike(self, regexp: str) -> Column:
        return self.invoke_expression_over_column(
            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
        )

    def like(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.Like, expression=self._lit(other).expression
        )

    def ilike(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.ILike, expression=self._lit(other).expression
        )

    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
        length = self._lit(length) if not isinstance(length, Column) else length
        return Column.invoke_expression_over_column(
            self, exp.Substring, start=startPos.expression, length=length.expression
        )

    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [self._lit(x).expression for x in columns]
        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore

    def between(
        self,
        lowerBound: t.Union[ColumnOrLiteral],
        upperBound: t.Union[ColumnOrLiteral],
    ) -> Column:
        lower_bound_exp = (
            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
        )
        upper_bound_exp = (
            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
        )
        return Column(
            exp.Between(
                this=self.column_expression,
                low=lower_bound_exp.expression,
                high=upper_bound_exp.expression,
            )
        )

    def over(self, window: WindowSpec) -> Column:
        window_expression = window.expression.copy()
        window_expression.set("this", self.column_expression)
        return Column(window_expression)
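Because every operator simply wraps a sqlglot expression node, a Column can be built and rendered without any DataFrame at all, which is handy for inspecting what an expression compiles to. A sketch:

    from sqlglot.dataframe.sql import functions as F

    col = (F.col("price") * F.lit(1.1)).alias("price_with_tax")
    # Renders in the session's dialect, roughly: price * 1.1 AS price_with_tax
    print(col.sql())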
class DataFrameNaFunctions:
    def __init__(self, df: DataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
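DataFrameNaFunctions is just the thin namespace behind df.na, so df.na.fill delegates straight to DataFrame.fillna above. A sketch with a hypothetical table:

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession

    sqlglot.schema.add_table("users", {"age": "int", "name": "string"}, dialect="spark")

    df = SparkSession().table("users")
    # Equivalent to df.fillna(0, subset=["age"]).
    print(df.na.fill(0, subset=["age"]).sql(pretty=True)[0])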
class Window:
    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    unboundedPreceding: int = _JAVA_MIN_LONG

    unboundedFollowing: int = _JAVA_MAX_LONG

    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rangeBetween(start, end)
class WindowSpec:
    def __init__(self, expression: exp.Expression = exp.Window()):
        self.expression = expression

    def copy(self):
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
        partition_by_expressions.extend(expressions)
        window_spec.expression.set("partition_by", partition_by_expressions)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order_by = window_spec.expression.args["order"].expressions
        order_by.extend(expressions)
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
            "start_side": None,
            "end_side": None,
        }
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": (
                        "UNBOUNDED"
                        if start <= Window.unboundedPreceding
                        else F.lit(start).expression
                    ),
                },
            }
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": (
                        "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
                    ),
                },
            }
        return kwargs

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "ROWS"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "RANGE"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec
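Window and WindowSpec only build the OVER clause; it takes Column.over to attach it to a function. A sketch of a running average over a hypothetical table:

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession, Window, functions as F

    sqlglot.schema.add_table(
        "trades", {"symbol": "string", "ts": "timestamp", "price": "double"}, dialect="spark"
    )

    window = (
        Window.partitionBy("symbol")
        .orderBy("ts")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    df = SparkSession().table("trades").select(
        F.col("symbol"), F.col("ts"), F.avg(F.col("price")).over(window).alias("running_avg")
    )
    print(df.sql(pretty=True)[0])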
class DataFrameReader:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame
        from sqlglot.dataframe.sql.session import SparkSession

        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)

        return DataFrame(
            self.spark,
            exp.Select()
            .from_(
                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
                    SparkSession().dialect.normalize_identifier
                )
            )
            .select(*sqlglot.schema.column_names(tableName, dialect=SparkSession().dialect)),
        )
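Note that table registers the name and then selects the columns already known to sqlglot's schema, so reads come out as explicit column lists rather than SELECT *. A sketch (mirroring the project README's usage, with a hypothetical schema):

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession

    sqlglot.schema.add_table("employee", {"employee_id": "int", "fname": "string"}, dialect="spark")

    df = SparkSession().read.table("employee")
    # Expect roughly: SELECT employee_id, fname FROM employee
    print(df.sql(pretty=True)[0])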
class DataFrameWriter:
    def __init__(
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        return DataFrameWriter(
            **{
                k[1:] if k.startswith("_") else k: v
                for k, v in object_to_dict(self, **kwargs).items()
            }
        )

    def sql(self, **kwargs) -> t.List[str]:
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        from sqlglot.dataframe.sql.session import SparkSession

        output_expression_container = exp.Insert(
            **{
                "this": exp.to_table(tableName),
                "overwrite": overwrite,
            }
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            columns = sqlglot.schema.column_names(
                tableName, only_visible=True, dialect=SparkSession().dialect
            )
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        if format is not None:
            raise NotImplementedError("Providing a format in saveAsTable is not supported")
        exists, replace, mode = None, None, mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)
        if mode == "ignore":
            exists = True
        if mode == "overwrite":
            replace = True
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
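mode picks the output container: append routes through insertInto, overwrite should come out as CREATE OR REPLACE TABLE ... AS, and ignore as CREATE TABLE IF NOT EXISTS. A sketch with hypothetical tables:

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession

    sqlglot.schema.add_table("src", {"id": "int"}, dialect="spark")

    writer = SparkSession().table("src").write.mode("overwrite").saveAsTable("dst")
    # Expect roughly: CREATE OR REPLACE TABLE dst AS SELECT id FROM src
    print(writer.sql(pretty=True)[0])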