sqlglot.dataframe.sql
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
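The module exposes a PySpark-style API for building SQL with sqlglot: you describe a query with DataFrame methods and then generate SQL strings from it. As a quick orientation, here is a minimal sketch of that flow; the `employee` table and its schema are illustrative assumptions, and registering the schema with `sqlglot.schema.add_table` is what lets the optimizer qualify and validate column references:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Assumed table for illustration; add_table registers its schema with sqlglot.
sqlglot.schema.add_table(
    "employee",
    {"employee_id": "INT", "fname": "STRING", "age": "INT"},
    dialect="spark",
)

spark = SparkSession()
df = (
    spark.table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

# sql() returns a list of statements, one string per output expression.
for statement in df.sql(dialect="spark"):
    print(statement)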
class SparkSession:
    DEFAULT_DIALECT = "spark"
    _instance = None

    def __init__(self):
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)

    def __new__(cls, *args, **kwargs) -> SparkSession:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)
            return spark

    @classproperty
    def builder(cls) -> Builder:
        return cls.Builder()
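`SparkSession` is a singleton, and its `Builder` accepts the usual PySpark builder chatter while only acting on the `sqlframe.dialect` config key; every other attribute access or call falls through `__getattr__`/`__call__` and returns the builder unchanged. A minimal sketch of both the builder and `createDataFrame` (the `master` call and the `duckdb` dialect are illustrative assumptions):

from sqlglot.dataframe.sql.session import SparkSession

# Because SparkSession is a singleton, getOrCreate() reconfigures the
# shared instance rather than creating a new one.
spark = (
    SparkSession.builder
    .master("local[*]")                    # no-op: swallowed by __getattr__/__call__
    .config("sqlframe.dialect", "duckdb")  # the only key the builder acts on
    .getOrCreate()
)

# createDataFrame builds a SELECT ... FROM (VALUES ...) expression; a list
# of column names is one of the accepted schema inputs.
df = spark.createDataFrame(
    [(1, "Jack"), (2, "Jill")],
    schema=["id", "name"],
)
print(df.sql(dialect="duckdb")[0])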
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
        expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)

        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}

        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                quote_identifiers(select_expression, dialect=dialect)
                select_expression = t.cast(
                    exp.Select, optimize_func(select_expression, dialect=dialect)
                )

            select_expression = df._replace_cte_names_with_hashes(select_expression)

            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql(dialect=dialect)
                        for expression in select_expression.expressions
                    },
                    dialect=dialect,
                )

                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )

                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)

                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")

            output_expressions.append(expression)

        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
                # of Spark.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
                    # use the same CTE we used before
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine the actual "join on" expression later, so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and select columns to be used based on what type of columns were provided for
        # the join. The columns returned change based on how the on expression is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of join on column names only:
            * The column names are put at the front of the select list
            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each of the left side
            # tables to see if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match Spark behavior, only the join clause gets deduplicated and it gets put at the front of the column list
            select_column_names = [
                column.alias_or_name
                if not isinstance(column.expression.this, exp.Star)
                else column.sql()
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of join on expressions:
            * There is no deduplication of the results.
            * The left join dataframe columns go first and right come after. No sort preference is given to join columns.
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
        is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for i, x in enumerate(ascending)]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            exp.Ordered(this=col.expression, desc=not asc)
            if i not in pre_ordered_col_indexes
            else columns[i].column_expression
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
            if minimum_num_nulls > len(null_check_columns):
                raise RuntimeError(
                    f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                    f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
                )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality Difference: If you provide a value to replace a null and that type conflicts
        with the type of the column then PySpark will just ignore your replacement.
        This will try to cast them to be the same in some cases. So they won't always match.
        Best to not mix types so make sure replacement is the same type as the column

        Possibility for improvement: Use `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
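To see several of these methods compose, here is a hedged sketch that joins two assumed tables and emits SQL for a different dialect. Each method returns a new `DataFrame` whose expression tree grows as leaves are converted to CTEs; no SQL is produced until `sql()` resolves hints, optimizes, and generates the output statements (the table names and schemas below are assumptions for illustration):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Assumed schemas so the optimizer can qualify and validate columns.
sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING", "store_id": "INT"}, dialect="spark")
sqlglot.schema.add_table("store", {"store_id": "INT", "store_name": "STRING"}, dialect="spark")

spark = SparkSession()
df = (
    spark.table("employee")
    .join(spark.table("store"), on="store_id", how="inner")  # join on a column name: it is deduplicated and put first
    .select("store_name", F.col("fname").alias("first_name"))
    .orderBy("store_name", ascending=False)
    .limit(10)
)

# sql() returns a list because caching can produce extra DROP VIEW / CACHE
# statements ahead of the final SELECT; here it yields a single statement.
for statement in df.sql(dialect="bigquery"):
    print(statement)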
Functionality Difference: If you provide a replacement value whose type conflicts with the column's type, PySpark will silently ignore the replacement, whereas this implementation will in some cases try to cast the two to a common type, so the results won't always match. It is best not to mix types: make sure the replacement value has the same type as the column.

Possibility for improvement: use the `typeof` function to get the column's type and check whether it matches the type of the provided value; if not, substitute null.
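Given the type caveat above, a safe usage pattern is to pass replacements that match each column's type, either as a single scalar or as a per-column dict (the table and columns here are illustrative assumptions):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("readings", {"a": "INT", "b": "STRING"}, dialect="spark")
df = SparkSession().table("readings")

# Scalar: applied to every column, or only to `subset` if given.
print(df.fillna(0, subset=["a"]).sql(dialect="spark"))

# Dict: per-column replacements; keys select the columns, values should match their types.
print(df.fillna({"a": 0, "b": "unknown"}).sql(dialect="spark"))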
706 @operation(Operation.FROM) 707 def replace( 708 self, 709 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 710 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 711 subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, 712 ) -> DataFrame: 713 from sqlglot.dataframe.sql.functions import lit 714 715 old_values = None 716 new_df = self.copy() 717 all_columns = self._get_outer_select_columns(new_df.expression) 718 all_column_mapping = {column.alias_or_name: column for column in all_columns} 719 720 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 721 if isinstance(to_replace, dict): 722 old_values = list(to_replace) 723 new_values = list(to_replace.values()) 724 elif not old_values and isinstance(to_replace, list): 725 assert isinstance(value, list), "value must be a list since the replacements are a list" 726 assert len(to_replace) == len( 727 value 728 ), "the replacements and values must be the same length" 729 old_values = to_replace 730 new_values = value 731 else: 732 old_values = [to_replace] * len(columns) 733 new_values = [value] * len(columns) 734 old_values = [lit(value) for value in old_values] 735 new_values = [lit(value) for value in new_values] 736 737 replacement_mapping = {} 738 for column in columns: 739 expression = Column(None) 740 for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): 741 if i == 0: 742 expression = F.when(column == old_value, new_value) 743 else: 744 expression = expression.when(column == old_value, new_value) # type: ignore 745 replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( 746 column.expression.alias_or_name 747 ) 748 749 replacement_mapping = {**all_column_mapping, **replacement_mapping} 750 replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] 751 new_df = new_df.select(*replacement_columns) 752 return new_df
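Each affected column becomes a chained CASE WHEN ... THEN ... ELSE column expression. A sketch of the three call shapes (dict, parallel lists, scalar plus subset), on an assumed table:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"age": "INT", "name": "STRING"}, dialect="spark")
df = SparkSession().table("employee")

# Dict: keys are old values, values are replacements.
print(df.replace({18: 21}, subset=["age"]).sql(dialect="spark"))

# Parallel lists: must be the same length.
print(df.replace([18, 19], [21, 22], subset=["age"]).sql(dialect="spark"))

# Scalar to scalar.
print(df.replace("n/a", "unknown", subset=["name"]).sql(dialect="spark"))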
754 @operation(Operation.SELECT) 755 def withColumn(self, colName: str, col: Column) -> DataFrame: 756 col = self._ensure_and_normalize_col(col) 757 existing_col_names = self.expression.named_selects 758 existing_col_index = ( 759 existing_col_names.index(colName) if colName in existing_col_names else None 760 ) 761 if existing_col_index is not None: # don't treat index 0 as falsy 762 expression = self.expression.copy() 763 expression.expressions[existing_col_index] = col.expression 764 return self.copy(expression=expression) 765 return self.copy().select(col.alias(colName), append=True)
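In short: if colName already exists in the select list, its expression is replaced in place; otherwise the new column is appended as an aliased selection. A sketch:

import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"id": "INT"}, dialect="spark")
df = SparkSession().table("employee")

# "id" already exists, so its select expression is overwritten in place.
print(df.withColumn("id", F.col("id") + 1).sql(dialect="spark"))

# "id_plus_one" is new, so it is appended to the select list.
print(df.withColumn("id_plus_one", F.col("id") + 1).sql(dialect="spark"))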
767 @operation(Operation.SELECT) 768 def withColumnRenamed(self, existing: str, new: str): 769 expression = self.expression.copy() 770 existing_columns = [ 771 expression 772 for expression in expression.expressions 773 if expression.alias_or_name == existing 774 ] 775 if not existing_columns: 776 raise ValueError("Tried to rename a column that doesn't exist") 777 for existing_column in existing_columns: 778 if isinstance(existing_column, exp.Column): 779 existing_column.replace(exp.alias_(existing_column, new)) 780 else: 781 existing_column.set("alias", exp.to_identifier(new)) 782 return self.copy(expression=expression)
784 @operation(Operation.SELECT) 785 def drop(self, *cols: t.Union[str, Column]) -> DataFrame: 786 all_columns = self._get_outer_select_columns(self.expression) 787 drop_cols = self._ensure_and_normalize_cols(cols) 788 new_columns = [ 789 col 790 for col in all_columns 791 if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] 792 ] 793 return self.copy().select(*new_columns, append=False)
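withColumnRenamed rewrites the alias of the matching select expression (raising if the name is absent), and drop re-selects everything except the dropped names. For example, on an assumed table:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"id": "INT", "name": "STRING"}, dialect="spark")
df = SparkSession().table("employee")

print(df.withColumnRenamed("name", "full_name").sql(dialect="spark"))
print(df.drop("id").sql(dialect="spark"))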
799 @operation(Operation.NO_OP) 800 def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: 801 parameter_list = ensure_list(parameters) 802 parameter_columns = ( 803 self._ensure_list_of_columns(parameter_list) 804 if parameters 805 else Column.ensure_cols([self.sequence_id]) 806 ) 807 return self._hint(name, parameter_columns)
809 @operation(Operation.NO_OP) 810 def repartition( 811 self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName 812 ) -> DataFrame: 813 num_partition_cols = self._ensure_list_of_columns(numPartitions) 814 columns = self._ensure_and_normalize_cols(cols) 815 args = num_partition_cols + columns 816 return self._hint("repartition", args)
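Both hint and repartition are no-ops with respect to the result set; they only attach optimizer hints, which should surface as Spark hint comments in the generated SQL. A sketch (the exact rendering depends on the generator, so the expected output below is indicative only):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"id": "INT"}, dialect="spark")
df = SparkSession().table("employee")

# Expect something like SELECT /*+ REPARTITION(4, id) */ ... in the output.
print(df.repartition(4, "id").sql(dialect="spark"))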
827 @operation(Operation.NO_OP) 828 def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame: 829 """ 830 Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html 831 """ 832 return self._cache(storageLevel)
Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
14class GroupedData: 15 def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation): 16 self._df = df.copy() 17 self.spark = df.spark 18 self.last_op = last_op 19 self.group_by_cols = group_by_cols 20 21 def _get_function_applied_columns( 22 self, func_name: str, cols: t.Tuple[str, ...] 23 ) -> t.List[Column]: 24 func_name = func_name.lower() 25 return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] 26 27 @operation(Operation.SELECT) 28 def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: 29 columns = ( 30 [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] 31 if isinstance(exprs[0], dict) 32 else exprs 33 ) 34 cols = self._df._ensure_and_normalize_cols(columns) 35 36 expression = self._df.expression.group_by( 37 *[x.expression for x in self.group_by_cols] 38 ).select(*[x.expression for x in self.group_by_cols + cols], append=False) 39 return self._df.copy(expression=expression) 40 41 def count(self) -> DataFrame: 42 return self.agg(F.count("*").alias("count")) 43 44 def mean(self, *cols: str) -> DataFrame: 45 return self.avg(*cols) 46 47 def avg(self, *cols: str) -> DataFrame: 48 return self.agg(*self._get_function_applied_columns("avg", cols)) 49 50 def max(self, *cols: str) -> DataFrame: 51 return self.agg(*self._get_function_applied_columns("max", cols)) 52 53 def min(self, *cols: str) -> DataFrame: 54 return self.agg(*self._get_function_applied_columns("min", cols)) 55 56 def sum(self, *cols: str) -> DataFrame: 57 return self.agg(*self._get_function_applied_columns("sum", cols)) 58 59 def pivot(self, *cols: str) -> DataFrame: 60 raise NotImplementedError("Pivot is not currently implemented")
27 @operation(Operation.SELECT) 28 def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: 29 columns = ( 30 [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] 31 if isinstance(exprs[0], dict) 32 else exprs 33 ) 34 cols = self._df._ensure_and_normalize_cols(columns) 35 36 expression = self._df.expression.group_by( 37 *[x.expression for x in self.group_by_cols] 38 ).select(*[x.expression for x in self.group_by_cols + cols], append=False) 39 return self._df.copy(expression=expression)
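agg accepts either Column aggregates or a PySpark-style {column: function} dict; the dict form is turned into strings like "min(salary)" and parsed back into column expressions. A sketch on an assumed table:

import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"dept": "STRING", "salary": "INT"}, dialect="spark")
df = SparkSession().table("employee")

grouped = df.groupBy("dept")
print(grouped.agg(F.max("salary"), F.avg("salary")).sql(dialect="spark"))

# Dict form: builds the expression MIN(salary) from the mapping.
print(grouped.agg({"salary": "min"}).sql(dialect="spark"))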
16class Column: 17 def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): 18 from sqlglot.dataframe.sql.session import SparkSession 19 20 if isinstance(expression, Column): 21 expression = expression.expression # type: ignore 22 elif expression is None or not isinstance(expression, (str, exp.Expression)): 23 expression = self._lit(expression).expression # type: ignore 24 elif not isinstance(expression, exp.Column): 25 expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( 26 SparkSession().dialect.normalize_identifier, copy=False 27 ) 28 if expression is None: 29 raise ValueError(f"Could not parse {expression}") 30 31 self.expression: exp.Expression = expression # type: ignore 32 33 def __repr__(self): 34 return repr(self.expression) 35 36 def __hash__(self): 37 return hash(self.expression) 38 39 def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore 40 return self.binary_op(exp.EQ, other) 41 42 def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore 43 return self.binary_op(exp.NEQ, other) 44 45 def __gt__(self, other: ColumnOrLiteral) -> Column: 46 return self.binary_op(exp.GT, other) 47 48 def __ge__(self, other: ColumnOrLiteral) -> Column: 49 return self.binary_op(exp.GTE, other) 50 51 def __lt__(self, other: ColumnOrLiteral) -> Column: 52 return self.binary_op(exp.LT, other) 53 54 def __le__(self, other: ColumnOrLiteral) -> Column: 55 return self.binary_op(exp.LTE, other) 56 57 def __and__(self, other: ColumnOrLiteral) -> Column: 58 return self.binary_op(exp.And, other) 59 60 def __or__(self, other: ColumnOrLiteral) -> Column: 61 return self.binary_op(exp.Or, other) 62 63 def __mod__(self, other: ColumnOrLiteral) -> Column: 64 return self.binary_op(exp.Mod, other) 65 66 def __add__(self, other: ColumnOrLiteral) -> Column: 67 return self.binary_op(exp.Add, other) 68 69 def __sub__(self, other: ColumnOrLiteral) -> Column: 70 return self.binary_op(exp.Sub, other) 71 72 def __mul__(self, other: ColumnOrLiteral) -> Column: 73 return self.binary_op(exp.Mul, other) 74 75 def __truediv__(self, other: ColumnOrLiteral) -> Column: 76 return self.binary_op(exp.Div, other) 77 78 def __div__(self, other: ColumnOrLiteral) -> Column: 79 return self.binary_op(exp.Div, other) 80 81 def __neg__(self) -> Column: 82 return self.unary_op(exp.Neg) 83 84 def __radd__(self, other: ColumnOrLiteral) -> Column: 85 return self.inverse_binary_op(exp.Add, other) 86 87 def __rsub__(self, other: ColumnOrLiteral) -> Column: 88 return self.inverse_binary_op(exp.Sub, other) 89 90 def __rmul__(self, other: ColumnOrLiteral) -> Column: 91 return self.inverse_binary_op(exp.Mul, other) 92 93 def __rdiv__(self, other: ColumnOrLiteral) -> Column: 94 return self.inverse_binary_op(exp.Div, other) 95 96 def __rtruediv__(self, other: ColumnOrLiteral) -> Column: 97 return self.inverse_binary_op(exp.Div, other) 98 99 def __rmod__(self, other: ColumnOrLiteral) -> Column: 100 return self.inverse_binary_op(exp.Mod, other) 101 102 def __pow__(self, power: ColumnOrLiteral, modulo=None): 103 return Column(exp.Pow(this=self.expression, expression=Column(power).expression)) 104 105 def __rpow__(self, power: ColumnOrLiteral): 106 return Column(exp.Pow(this=Column(power).expression, expression=self.expression)) 107 108 def __invert__(self): 109 return self.unary_op(exp.Not) 110 111 def __rand__(self, other: ColumnOrLiteral) -> Column: 112 return self.inverse_binary_op(exp.And, other) 113 114 def __ror__(self, other: ColumnOrLiteral) -> Column: 115 return self.inverse_binary_op(exp.Or, other) 116 117 @classmethod 118 def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column: 119 return cls(value) 120 121 @classmethod 122 def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]: 123 return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args] 124 125 @classmethod 126 def _lit(cls, value: ColumnOrLiteral) -> Column: 127 if isinstance(value, dict): 128 columns = [cls._lit(v).alias(k).expression for k, v in value.items()] 129 return cls(exp.Struct(expressions=columns)) 130 return cls(exp.convert(value)) 131 132 @classmethod 133 def invoke_anonymous_function( 134 cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] 135 ) -> Column: 136 columns = [] if column is None else [cls.ensure_col(column)] 137 column_args = [cls.ensure_col(arg) for arg in args] 138 expressions = [x.expression for x in columns + column_args] 139 new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) 140 return Column(new_expression) 141 142 @classmethod 143 def invoke_expression_over_column( 144 cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs 145 ) -> Column: 146 ensured_column = None if column is None else cls.ensure_col(column) 147 ensure_expression_values = { 148 k: [Column.ensure_col(x).expression for x in v] 149 if is_iterable(v) 150 else Column.ensure_col(v).expression 151 for k, v in kwargs.items() 152 if v is not None 153 } 154 new_expression = ( 155 callable_expression(**ensure_expression_values) 156 if ensured_column is None 157 else callable_expression( 158 this=ensured_column.column_expression, **ensure_expression_values 159 ) 160 ) 161 return Column(new_expression) 162 163 def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: 164 return Column( 165 klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs) 166 ) 167 168 def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: 169 return Column( 170 klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs) 171 ) 172 173 def unary_op(self, klass: t.Callable, **kwargs) -> Column: 174 return Column(klass(this=self.column_expression, **kwargs)) 175 176 @property 177 def is_alias(self): 178 return isinstance(self.expression, exp.Alias) 179 180 @property 181 def is_column(self): 182 return isinstance(self.expression, exp.Column) 183 184 @property 185 def column_expression(self) -> t.Union[exp.Column, exp.Literal]: 186 return self.expression.unalias() 187 188 @property 189 def alias_or_name(self) -> str: 190 return self.expression.alias_or_name 191 192 @classmethod 193 def ensure_literal(cls, value) -> Column: 194 from sqlglot.dataframe.sql.functions import lit 195 196 if isinstance(value, cls): 197 value = value.expression 198 if not isinstance(value, exp.Literal): 199 return lit(value) 200 return Column(value) 201 202 def copy(self) -> Column: 203 return Column(self.expression.copy()) 204 205 def set_table_name(self, table_name: str, copy=False) -> Column: 206 expression = self.expression.copy() if copy else self.expression 207 expression.set("table", exp.to_identifier(table_name)) 208 return Column(expression) 209 210 def sql(self, **kwargs) -> str: 211 from sqlglot.dataframe.sql.session import SparkSession 212 213 return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) 214 215 def alias(self, name: str) -> Column: 216 from sqlglot.dataframe.sql.session import SparkSession 217 218 dialect = SparkSession().dialect 219 alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect) 220 new_expression = exp.alias_( 221 self.column_expression, 222 alias.this if isinstance(alias, exp.Column) else name, 223 dialect=dialect, 224 ) 225 return Column(new_expression) 226 227 def asc(self) -> Column: 228 new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True) 229 return Column(new_expression) 230 231 def desc(self) -> Column: 232 new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False) 233 return Column(new_expression) 234 235 asc_nulls_first = asc 236 237 def asc_nulls_last(self) -> Column: 238 new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False) 239 return Column(new_expression) 240 241 def desc_nulls_first(self) -> Column: 242 new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True) 243 return Column(new_expression) 244 245 desc_nulls_last = desc 246 247 def when(self, condition: Column, value: t.Any) -> Column: 248 from sqlglot.dataframe.sql.functions import when 249 250 column_with_if = when(condition, value) 251 if not isinstance(self.expression, exp.Case): 252 return column_with_if 253 new_column = self.copy() 254 new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) 255 return new_column 256 257 def otherwise(self, value: t.Any) -> Column: 258 from sqlglot.dataframe.sql.functions import lit 259 260 true_value = value if isinstance(value, Column) else lit(value) 261 new_column = self.copy() 262 new_column.expression.set("default", true_value.column_expression) 263 return new_column 264 265 def isNull(self) -> Column: 266 new_expression = exp.Is(this=self.column_expression, expression=exp.Null()) 267 return Column(new_expression) 268 269 def isNotNull(self) -> Column: 270 new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) 271 return Column(new_expression) 272 273 def cast(self, dataType: t.Union[str, DataType]) -> Column: 274 """ 275 Functionality Difference: PySpark cast accepts a datatype instance of the datatype class 276 Sqlglot doesn't currently replicate this class so it only accepts a string 277 """ 278 from sqlglot.dataframe.sql.session import SparkSession 279 280 if isinstance(dataType, DataType): 281 dataType = dataType.simpleString() 282 return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect)) 283 284 def startswith(self, value: t.Union[str, Column]) -> Column: 285 value = self._lit(value) if not isinstance(value, Column) else value 286 return self.invoke_anonymous_function(self, "STARTSWITH", value) 287 288 def endswith(self, value: t.Union[str, Column]) -> Column: 289 value = self._lit(value) if not isinstance(value, Column) else value 290 return self.invoke_anonymous_function(self, "ENDSWITH", value) 291 292 def rlike(self, regexp: str) -> Column: 293 return self.invoke_expression_over_column( 294 column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression 295 ) 296 297 def like(self, other: str): 298 return self.invoke_expression_over_column( 299 self, exp.Like, expression=self._lit(other).expression 300 ) 301 302 def ilike(self, other: str): 303 return self.invoke_expression_over_column( 304 self, exp.ILike, expression=self._lit(other).expression 305 ) 306 307 def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: 308 startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos 309 length = self._lit(length) if not isinstance(length, Column) else length 310 return Column.invoke_expression_over_column( 311 self, exp.Substring, start=startPos.expression, length=length.expression 312 ) 313 314 def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): 315 columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 316 expressions = [self._lit(x).expression for x in columns] 317 return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore 318 319 def between( 320 self, 321 lowerBound: t.Union[ColumnOrLiteral], 322 upperBound: t.Union[ColumnOrLiteral], 323 ) -> Column: 324 lower_bound_exp = ( 325 self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound 326 ) 327 upper_bound_exp = ( 328 self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound 329 ) 330 return Column( 331 exp.Between( 332 this=self.column_expression, 333 low=lower_bound_exp.expression, 334 high=upper_bound_exp.expression, 335 ) 336 ) 337 338 def over(self, window: WindowSpec) -> Column: 339 window_expression = window.expression.copy() 340 window_expression.set("this", self.column_expression) 341 return Column(window_expression)
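Because Column overloads Python's comparison, arithmetic, and bitwise operators, expressions can be composed and rendered standalone, without a DataFrame. A minimal sketch (the column names are illustrative):

from sqlglot.dataframe.sql import functions as F

expr = ((F.col("price") * 0.9) + F.lit(1)).alias("discounted")
print(expr.sql(dialect="spark"))  # e.g. price * 0.9 + 1 AS discounted

# & and | build AND/OR; the parentheses matter because of Python operator precedence.
comparison = (F.col("age") >= 18) & (F.col("age") < 65)
print(comparison.sql(dialect="spark"))  # e.g. age >= 18 AND age < 65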
17 def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): 18 from sqlglot.dataframe.sql.session import SparkSession 19 20 if isinstance(expression, Column): 21 expression = expression.expression # type: ignore 22 elif expression is None or not isinstance(expression, (str, exp.Expression)): 23 expression = self._lit(expression).expression # type: ignore 24 elif not isinstance(expression, exp.Column): 25 expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( 26 SparkSession().dialect.normalize_identifier, copy=False 27 ) 28 if expression is None: 29 raise ValueError(f"Could not parse {expression}") 30 31 self.expression: exp.Expression = expression # type: ignore
132 @classmethod 133 def invoke_anonymous_function( 134 cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] 135 ) -> Column: 136 columns = [] if column is None else [cls.ensure_col(column)] 137 column_args = [cls.ensure_col(arg) for arg in args] 138 expressions = [x.expression for x in columns + column_args] 139 new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) 140 return Column(new_expression)
142 @classmethod 143 def invoke_expression_over_column( 144 cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs 145 ) -> Column: 146 ensured_column = None if column is None else cls.ensure_col(column) 147 ensure_expression_values = { 148 k: [Column.ensure_col(x).expression for x in v] 149 if is_iterable(v) 150 else Column.ensure_col(v).expression 151 for k, v in kwargs.items() 152 if v is not None 153 } 154 new_expression = ( 155 callable_expression(**ensure_expression_values) 156 if ensured_column is None 157 else callable_expression( 158 this=ensured_column.column_expression, **ensure_expression_values 159 ) 160 ) 161 return Column(new_expression)
215 def alias(self, name: str) -> Column: 216 from sqlglot.dataframe.sql.session import SparkSession 217 218 dialect = SparkSession().dialect 219 alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect) 220 new_expression = exp.alias_( 221 self.column_expression, 222 alias.this if isinstance(alias, exp.Column) else name, 223 dialect=dialect, 224 ) 225 return Column(new_expression)
247 def when(self, condition: Column, value: t.Any) -> Column: 248 from sqlglot.dataframe.sql.functions import when 249 250 column_with_if = when(condition, value) 251 if not isinstance(self.expression, exp.Case): 252 return column_with_if 253 new_column = self.copy() 254 new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) 255 return new_column
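Successive when calls extend the ifs list of the same CASE expression, and otherwise sets its default. For example:

from sqlglot.dataframe.sql import functions as F

bucket = (
    F.when(F.col("age") < 13, "child")
    .when(F.col("age") < 20, "teen")
    .otherwise("adult")
    .alias("age_bucket")
)
# e.g. CASE WHEN age < 13 THEN 'child' WHEN age < 20 THEN 'teen' ELSE 'adult' END AS age_bucket
print(bucket.sql(dialect="spark"))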
273 def cast(self, dataType: t.Union[str, DataType]) -> Column: 274 """ 275 Functionality Difference: PySpark cast accepts a datatype instance of the datatype class 276 Sqlglot doesn't currently replicate this class so it only accepts a string 277 """ 278 from sqlglot.dataframe.sql.session import SparkSession 279 280 if isinstance(dataType, DataType): 281 dataType = dataType.simpleString() 282 return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
Functionality Difference: PySpark's cast accepts an instance of the DataType class. Sqlglot doesn't currently replicate this class, so the cast is driven by a string (a DataType instance is converted via its simpleString()).
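So both spellings below should be accepted, with the DataType instance reduced to its simpleString() before casting. A sketch, assuming the PySpark-compatible types that sqlglot.dataframe.sql.types ships:

from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.types import StringType

print(F.col("id").cast("string").sql(dialect="spark"))      # CAST(id AS STRING)
print(F.col("id").cast(StringType()).sql(dialect="spark"))  # same: StringType().simpleString() == "string"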
307 def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: 308 startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos 309 length = self._lit(length) if not isinstance(length, Column) else length 310 return Column.invoke_expression_over_column( 311 self, exp.Substring, start=startPos.expression, length=length.expression 312 )
314 def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): 315 columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 316 expressions = [self._lit(x).expression for x in columns] 317 return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore
319 def between( 320 self, 321 lowerBound: t.Union[ColumnOrLiteral], 322 upperBound: t.Union[ColumnOrLiteral], 323 ) -> Column: 324 lower_bound_exp = ( 325 self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound 326 ) 327 upper_bound_exp = ( 328 self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound 329 ) 330 return Column( 331 exp.Between( 332 this=self.column_expression, 333 low=lower_bound_exp.expression, 334 high=upper_bound_exp.expression, 335 ) 336 )
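isin flattens an iterable argument and builds an IN expression, while between builds a BETWEEN with literal or column bounds; for example:

from sqlglot.dataframe.sql import functions as F

print(F.col("id").isin(1, 2, 3).sql(dialect="spark"))     # id IN (1, 2, 3)
print(F.col("id").isin([1, 2, 3]).sql(dialect="spark"))   # same, the list is flattened
print(F.col("age").between(18, 65).sql(dialect="spark"))  # age BETWEEN 18 AND 65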
835class DataFrameNaFunctions: 836 def __init__(self, df: DataFrame): 837 self.df = df 838 839 def drop( 840 self, 841 how: str = "any", 842 thresh: t.Optional[int] = None, 843 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 844 ) -> DataFrame: 845 return self.df.dropna(how=how, thresh=thresh, subset=subset) 846 847 def fill( 848 self, 849 value: t.Union[int, bool, float, str, t.Dict[str, t.Any]], 850 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 851 ) -> DataFrame: 852 return self.df.fillna(value=value, subset=subset) 853 854 def replace( 855 self, 856 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 857 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 858 subset: t.Optional[t.Union[str, t.List[str]]] = None, 859 ) -> DataFrame: 860 return self.df.replace(to_replace=to_replace, value=value, subset=subset)
854 def replace( 855 self, 856 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 857 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 858 subset: t.Optional[t.Union[str, t.List[str]]] = None, 859 ) -> DataFrame: 860 return self.df.replace(to_replace=to_replace, value=value, subset=subset)
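These are thin delegates to dropna, fillna, and replace, so the accessor form produces the same SQL as calling the DataFrame methods directly (assuming the DataFrame exposes them via a PySpark-style na property):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("readings", {"a": "INT", "b": "INT"}, dialect="spark")
df = SparkSession().table("readings")

# Equivalent to df.dropna(how="all") and df.fillna(0, subset=["a"]) respectively.
print(df.na.drop(how="all").sql(dialect="spark"))
print(df.na.fill(0, subset=["a"]).sql(dialect="spark"))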
15class Window: 16 _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 17 _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 18 _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG) 19 _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG) 20 21 unboundedPreceding: int = _JAVA_MIN_LONG 22 23 unboundedFollowing: int = _JAVA_MAX_LONG 24 25 currentRow: int = 0 26 27 @classmethod 28 def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 29 return WindowSpec().partitionBy(*cols) 30 31 @classmethod 32 def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 33 return WindowSpec().orderBy(*cols) 34 35 @classmethod 36 def rowsBetween(cls, start: int, end: int) -> WindowSpec: 37 return WindowSpec().rowsBetween(start, end) 38 39 @classmethod 40 def rangeBetween(cls, start: int, end: int) -> WindowSpec: 41 return WindowSpec().rangeBetween(start, end)
44class WindowSpec: 45 def __init__(self, expression: exp.Expression = exp.Window()): 46 self.expression = expression 47 48 def copy(self): 49 return WindowSpec(self.expression.copy()) 50 51 def sql(self, **kwargs) -> str: 52 from sqlglot.dataframe.sql.session import SparkSession 53 54 return self.expression.sql(dialect=SparkSession().dialect, **kwargs) 55 56 def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 57 from sqlglot.dataframe.sql.column import Column 58 59 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 60 expressions = [Column.ensure_col(x).expression for x in cols] 61 window_spec = self.copy() 62 partition_by_expressions = window_spec.expression.args.get("partition_by", []) 63 partition_by_expressions.extend(expressions) 64 window_spec.expression.set("partition_by", partition_by_expressions) 65 return window_spec 66 67 def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 68 from sqlglot.dataframe.sql.column import Column 69 70 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 71 expressions = [Column.ensure_col(x).expression for x in cols] 72 window_spec = self.copy() 73 if window_spec.expression.args.get("order") is None: 74 window_spec.expression.set("order", exp.Order(expressions=[])) 75 order_by = window_spec.expression.args["order"].expressions 76 order_by.extend(expressions) 77 window_spec.expression.args["order"].set("expressions", order_by) 78 return window_spec 79 80 def _calc_start_end( 81 self, start: int, end: int 82 ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]: 83 kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = { 84 "start_side": None, 85 "end_side": None, 86 } 87 if start == Window.currentRow: 88 kwargs["start"] = "CURRENT ROW" 89 else: 90 kwargs = { 91 **kwargs, 92 **{ 93 "start_side": "PRECEDING", 94 "start": "UNBOUNDED" 95 if start <= Window.unboundedPreceding 96 else F.lit(start).expression, 97 }, 98 } 99 if end == Window.currentRow: 100 kwargs["end"] = "CURRENT ROW" 101 else: 102 kwargs = { 103 **kwargs, 104 **{ 105 "end_side": "FOLLOWING", 106 "end": "UNBOUNDED" 107 if end >= Window.unboundedFollowing 108 else F.lit(end).expression, 109 }, 110 } 111 return kwargs 112 113 def rowsBetween(self, start: int, end: int) -> WindowSpec: 114 window_spec = self.copy() 115 spec = self._calc_start_end(start, end) 116 spec["kind"] = "ROWS" 117 window_spec.expression.set( 118 "spec", 119 exp.WindowSpec( 120 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 121 ), 122 ) 123 return window_spec 124 125 def rangeBetween(self, start: int, end: int) -> WindowSpec: 126 window_spec = self.copy() 127 spec = self._calc_start_end(start, end) 128 spec["kind"] = "RANGE" 129 window_spec.expression.set( 130 "spec", 131 exp.WindowSpec( 132 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 133 ), 134 ) 135 return window_spec
56 def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 57 from sqlglot.dataframe.sql.column import Column 58 59 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 60 expressions = [Column.ensure_col(x).expression for x in cols] 61 window_spec = self.copy() 62 partition_by_expressions = window_spec.expression.args.get("partition_by", []) 63 partition_by_expressions.extend(expressions) 64 window_spec.expression.set("partition_by", partition_by_expressions) 65 return window_spec
67 def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 68 from sqlglot.dataframe.sql.column import Column 69 70 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 71 expressions = [Column.ensure_col(x).expression for x in cols] 72 window_spec = self.copy() 73 if window_spec.expression.args.get("order") is None: 74 window_spec.expression.set("order", exp.Order(expressions=[])) 75 order_by = window_spec.expression.args["order"].expressions 76 order_by.extend(expressions) 77 window_spec.expression.args["order"].set("expressions", order_by) 78 return window_spec
113 def rowsBetween(self, start: int, end: int) -> WindowSpec: 114 window_spec = self.copy() 115 spec = self._calc_start_end(start, end) 116 spec["kind"] = "ROWS" 117 window_spec.expression.set( 118 "spec", 119 exp.WindowSpec( 120 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 121 ), 122 ) 123 return window_spec
125 def rangeBetween(self, start: int, end: int) -> WindowSpec: 126 window_spec = self.copy() 127 spec = self._calc_start_end(start, end) 128 spec["kind"] = "RANGE" 129 window_spec.expression.set( 130 "spec", 131 exp.WindowSpec( 132 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 133 ), 134 ) 135 return window_spec
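Putting the pieces together: partitionBy and orderBy populate the window expression, rowsBetween and rangeBetween attach a frame spec, and Window.unboundedPreceding / Window.unboundedFollowing map to the UNBOUNDED keywords. A sketch of a running total on an assumed table:

import sqlglot
from sqlglot.dataframe.sql import Window, functions as F
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"dept": "STRING", "salary": "INT"}, dialect="spark")
df = SparkSession().table("employee")

w = (
    Window.partitionBy("dept")
    .orderBy("salary")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
# e.g. SUM(salary) OVER (PARTITION BY dept ORDER BY salary ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
print(df.select(F.sum("salary").over(w).alias("running_total")).sql(dialect="spark"))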
15class DataFrameReader: 16 def __init__(self, spark: SparkSession): 17 self.spark = spark 18 19 def table(self, tableName: str) -> DataFrame: 20 from sqlglot.dataframe.sql.dataframe import DataFrame 21 from sqlglot.dataframe.sql.session import SparkSession 22 23 sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) 24 25 return DataFrame( 26 self.spark, 27 exp.Select() 28 .from_( 29 exp.to_table(tableName, dialect=SparkSession().dialect).transform( 30 SparkSession().dialect.normalize_identifier 31 ) 32 ) 33 .select( 34 *( 35 column 36 for column in sqlglot.schema.column_names( 37 tableName, dialect=SparkSession().dialect 38 ) 39 ) 40 ), 41 )
19 def table(self, tableName: str) -> DataFrame: 20 from sqlglot.dataframe.sql.dataframe import DataFrame 21 from sqlglot.dataframe.sql.session import SparkSession 22 23 sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) 24 25 return DataFrame( 26 self.spark, 27 exp.Select() 28 .from_( 29 exp.to_table(tableName, dialect=SparkSession().dialect).transform( 30 SparkSession().dialect.normalize_identifier 31 ) 32 ) 33 .select( 34 *( 35 column 36 for column in sqlglot.schema.column_names( 37 tableName, dialect=SparkSession().dialect 38 ) 39 ) 40 ), 41 )
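Because table immediately expands the select list from sqlglot's schema, the table and its columns must be registered before reading; otherwise there is nothing to select. A sketch:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

# The schema must know the table's columns up front.
sqlglot.schema.add_table("employee", {"id": "INT", "name": "STRING"}, dialect="spark")

df = SparkSession().read.table("employee")
print(df.sql(dialect="spark"))  # selects the registered columns from employee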
44class DataFrameWriter: 45 def __init__( 46 self, 47 df: DataFrame, 48 spark: t.Optional[SparkSession] = None, 49 mode: t.Optional[str] = None, 50 by_name: bool = False, 51 ): 52 self._df = df 53 self._spark = spark or df.spark 54 self._mode = mode 55 self._by_name = by_name 56 57 def copy(self, **kwargs) -> DataFrameWriter: 58 return DataFrameWriter( 59 **{ 60 k[1:] if k.startswith("_") else k: v 61 for k, v in object_to_dict(self, **kwargs).items() 62 } 63 ) 64 65 def sql(self, **kwargs) -> t.List[str]: 66 return self._df.sql(**kwargs) 67 68 def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter: 69 return self.copy(_mode=saveMode) 70 71 @property 72 def byName(self): 73 return self.copy(by_name=True) 74 75 def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: 76 from sqlglot.dataframe.sql.session import SparkSession 77 78 output_expression_container = exp.Insert( 79 **{ 80 "this": exp.to_table(tableName), 81 "overwrite": overwrite, 82 } 83 ) 84 df = self._df.copy(output_expression_container=output_expression_container) 85 if self._by_name: 86 columns = sqlglot.schema.column_names( 87 tableName, only_visible=True, dialect=SparkSession().dialect 88 ) 89 df = df._convert_leaf_to_cte().select(*columns) 90 91 return self.copy(_df=df) 92 93 def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): 94 if format is not None: 95 raise NotImplementedError("Providing Format in the save as table is not supported") 96 exists, replace, mode = None, None, mode or str(self._mode) 97 if mode == "append": 98 return self.insertInto(name) 99 if mode == "ignore": 100 exists = True 101 if mode == "overwrite": 102 replace = True 103 output_expression_container = exp.Create( 104 this=exp.to_table(name), 105 kind="TABLE", 106 exists=exists, 107 replace=replace, 108 ) 109 return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
75 def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: 76 from sqlglot.dataframe.sql.session import SparkSession 77 78 output_expression_container = exp.Insert( 79 **{ 80 "this": exp.to_table(tableName), 81 "overwrite": overwrite, 82 } 83 ) 84 df = self._df.copy(output_expression_container=output_expression_container) 85 if self._by_name: 86 columns = sqlglot.schema.column_names( 87 tableName, only_visible=True, dialect=SparkSession().dialect 88 ) 89 df = df._convert_leaf_to_cte().select(*columns) 90 91 return self.copy(_df=df)
93 def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): 94 if format is not None: 95 raise NotImplementedError("Providing Format in the save as table is not supported") 96 exists, replace, mode = None, None, mode or str(self._mode) 97 if mode == "append": 98 return self.insertInto(name) 99 if mode == "ignore": 100 exists = True 101 if mode == "overwrite": 102 replace = True 103 output_expression_container = exp.Create( 104 this=exp.to_table(name), 105 kind="TABLE", 106 exists=exists, 107 replace=replace, 108 ) 109 return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
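saveAsTable wraps the DataFrame's query in a CREATE TABLE container: "ignore" adds IF NOT EXISTS, "overwrite" makes it CREATE OR REPLACE, and "append" falls through to insertInto. A sketch, assuming the DataFrame exposes a PySpark-style write property returning a DataFrameWriter:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"id": "INT"}, dialect="spark")
df = SparkSession().table("employee")

# e.g. CREATE OR REPLACE TABLE new_table AS SELECT ...
print(df.write.mode("overwrite").saveAsTable("new_table").sql(dialect="spark"))

# e.g. INSERT INTO employee SELECT ... ("append" delegates to insertInto)
print(df.write.mode("append").saveAsTable("employee").sql(dialect="spark"))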