summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/readwriter.py
blob: 4830035fcfaab0c2a75ebc59b9758cd5e5d005fc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import annotations

import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.helper import object_to_dict

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.dataframe import DataFrame
    from sqlglot.dataframe.sql.session import SparkSession


class DataFrameReader:
    """Creates ``DataFrame`` objects from table names on behalf of a session."""

    def __init__(self, spark: SparkSession):
        # The session is passed as the first argument to every DataFrame built here.
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        """Build a ``DataFrame`` that selects from ``tableName``.

        The table name is first passed to ``sqlglot.schema.add_table``; the
        projected columns are whatever ``sqlglot.schema.column_names`` then
        reports for that name.

        Args:
            tableName: Name of the table to read from.

        Returns:
            A ``DataFrame`` bound to this reader's session and wrapping the
            resulting ``SELECT`` expression.
        """
        # Imported here rather than at module level, where this name is only
        # imported under ``t.TYPE_CHECKING``.
        from sqlglot.dataframe.sql.dataframe import DataFrame

        sqlglot.schema.add_table(tableName)
        source_query = exp.Select().from_(tableName)
        known_columns = sqlglot.schema.column_names(tableName)
        return DataFrame(self.spark, source_query.select(*known_columns))


class DataFrameWriter:
    """Attaches an output statement (``INSERT`` or ``CREATE``) to a ``DataFrame``.

    Every configuring method returns a new writer obtained through ``copy``
    instead of assigning to attributes of the current one.
    """

    def __init__(
        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
    ):
        self._df = df
        # A missing (or otherwise falsy) session falls back to the DataFrame's own.
        self._spark = spark if spark else df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        """Return a new writer built from this writer's attributes plus ``kwargs``.

        Args:
            **kwargs: Overrides forwarded to ``object_to_dict``. Callers in this
                class pass them both with a leading underscore (``_df``,
                ``_mode``) and without one (``by_name``).

        Returns:
            A new ``DataFrameWriter``.
        """
        constructor_args = {}
        for attr_name, attr_value in object_to_dict(self, **kwargs).items():
            # Attributes are stored as ``_name`` while ``__init__`` expects ``name``.
            # When both spellings are present, the one iterated last wins
            # (presumably the explicit override -- depends on ``object_to_dict`` ordering).
            param_name = attr_name[1:] if attr_name.startswith("_") else attr_name
            constructor_args[param_name] = attr_value
        return DataFrameWriter(**constructor_args)

    def sql(self, **kwargs) -> t.List[str]:
        """Delegate to the wrapped DataFrame's ``sql`` method, forwarding ``kwargs``."""
        generated = self._df.sql(**kwargs)
        return generated

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        """Return a copy of this writer whose save mode is ``saveMode``."""
        overrides = {"_mode": saveMode}
        return self.copy(**overrides)

    @property
    def byName(self):
        """A copy of this writer with by-name column selection enabled."""
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        """Return a writer whose DataFrame is wrapped in an ``INSERT`` into ``tableName``.

        Args:
            tableName: Target table; parsed with ``exp.to_table``.
            overwrite: Stored as-is in the ``overwrite`` arg of the ``exp.Insert`` node.

        Returns:
            A new ``DataFrameWriter`` holding the updated DataFrame.
        """
        insert_node = exp.Insert(this=exp.to_table(tableName), overwrite=overwrite)
        target_df = self._df.copy(output_expression_container=insert_node)
        if not self._by_name:
            return self.copy(_df=target_df)

        # By-name mode: re-select using the column names the schema reports for
        # the target table (``only_visible=True``).
        target_columns = sqlglot.schema.column_names(tableName, only_visible=True)
        reselected_df = target_df._convert_leaf_to_cte().select(*target_columns)
        return self.copy(_df=reselected_df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        """Return a writer whose DataFrame is wrapped in a ``CREATE TABLE`` for ``name``.

        The effective mode is ``mode`` when truthy, otherwise ``str(self._mode)``:

        * ``"append"`` delegates to ``insertInto(name)``.
        * ``"ignore"`` sets ``exists=True`` on the ``exp.Create`` node.
        * ``"overwrite"`` sets ``replace=True`` on the ``exp.Create`` node.
        * Anything else leaves both flags as ``None``.

        Args:
            name: Target table; parsed with ``exp.to_table``.
            format: Must be ``None``.
            mode: Optional save mode taking precedence over the writer's own.

        Returns:
            A new ``DataFrameWriter`` holding the updated DataFrame.

        Raises:
            NotImplementedError: If ``format`` is not ``None``.
        """
        if format is not None:
            raise NotImplementedError("Providing Format in the save as table is not supported")

        effective_mode = mode or str(self._mode)
        if effective_mode == "append":
            return self.insertInto(name)

        create_node = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=True if effective_mode == "ignore" else None,
            replace=True if effective_mode == "overwrite" else None,
        )
        created_df = self._df.copy(output_expression_container=create_node)
        return self.copy(_df=created_df)