summaryrefslogtreecommitdiffstats
path: root/sqlglot/__init__.py
blob: 714897fb8f0f726779f12915e02041ee07070c66 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
.. include:: ../README.md
----
"""

from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects import Dialect, Dialects
from sqlglot.diff import diff
from sqlglot.errors import ErrorLevel, ParseError, TokenError, UnsupportedError
from sqlglot.expressions import Expression
from sqlglot.expressions import alias_ as alias
from sqlglot.expressions import (
    and_,
    column,
    condition,
    except_,
    from_,
    intersect,
    maybe_parse,
    not_,
    or_,
    select,
    subquery,
)
from sqlglot.expressions import table_ as table
from sqlglot.expressions import to_column, to_table, union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema, Schema
from sqlglot.tokens import Tokenizer, TokenType

# Names below are needed only inside annotations. They are defined solely for static
# type checkers; this works at runtime because `from __future__ import annotations`
# (top of this module) keeps annotations from being evaluated.
if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # Type variable bounded by Expression; used by the `parse_one` overloads so that
    # passing an Expression subclass as `into` yields that same subclass as the return type.
    T = t.TypeVar("T", bound=Expression)


# Version string of this package.
__version__ = "10.6.3"

# Module-level default; described by the attribute docstring that follows it.
pretty = False
"""Whether to format generated SQL by default."""

# A single MappingSchema instance created at import time and shared module-wide.
schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""


def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
    """
    Converts a SQL string into syntax trees, yielding one entry for each statement in the input.

    Args:
        sql: the SQL text that should be parsed.
        read: the dialect whose rules drive the parsing (eg. "spark", "hive", "presto", "mysql").
        **opts: additional `sqlglot.parser.Parser` options, forwarded unchanged.

    Returns:
        The collection of syntax trees produced by the dialect.
    """
    # Resolve the dialect class first, then build a fresh instance for this call.
    dialect_cls = Dialect.get_or_raise(read)
    source_dialect = dialect_cls()
    trees = source_dialect.parse(sql, **opts)
    return trees


# Typing-only overloads for `parse_one`. They carry no runtime logic (bodies are `...`);
# they describe to static type checkers how the return type depends on `into`.

# `read` omitted or None, `into` given as an Expression subclass: returns that subclass.
@t.overload
def parse_one(
    sql: str,
    read: None = None,
    into: t.Type[T] = ...,
    **opts,
) -> T:
    ...


# `read` given explicitly, `into` given as an Expression subclass: returns that subclass.
@t.overload
def parse_one(
    sql: str,
    read: DialectType,
    into: t.Type[T],
    **opts,
) -> T:
    ...


# `read` omitted or None, `into` given as a string or a collection of strings / Expression
# types: the concrete result type is not known statically, so plain Expression is returned.
@t.overload
def parse_one(
    sql: str,
    read: None = None,
    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ...,
    **opts,
) -> Expression:
    ...


# `read` given explicitly, `into` given as a string or a collection of strings / Expression
# types: returns plain Expression.
@t.overload
def parse_one(
    sql: str,
    read: DialectType,
    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]],
    **opts,
) -> Expression:
    ...


# Only `sql` (plus keyword options) supplied: returns plain Expression.
@t.overload
def parse_one(
    sql: str,
    **opts,
) -> Expression:
    ...


def parse_one(
    sql: str,
    read: DialectType = None,
    into: t.Optional[exp.IntoType] = None,
    **opts,
) -> Expression:
    """
    Parses the given SQL string and hands back the syntax tree of its first statement only.

    Args:
        sql: the SQL code string to parse.
        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
        into: the SQLGlot Expression to parse into; when falsy, regular parsing is used.
        **opts: other `sqlglot.parser.Parser` options.

    Returns:
        The syntax tree for the first parsed statement.

    Raises:
        ParseError: if parsing produced no statements, or the first statement is empty.
    """
    dialect = Dialect.get_or_raise(read)()

    parsed = dialect.parse_into(into, sql, **opts) if into else dialect.parse(sql, **opts)

    # Only the first statement is of interest. An absent first element and a falsy one
    # are both reported with the same error.
    first = next(iter(parsed), None)
    if not first:
        raise ParseError(f"No expression was parsed from '{sql}'")

    return first


def transpile(
    sql: str,
    read: DialectType = None,
    write: DialectType = None,
    identity: bool = True,
    error_level: t.Optional[ErrorLevel] = None,
    **opts,
) -> t.List[str]:
    """
    Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed
    to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement.

    Args:
        sql: the SQL code string to transpile.
        read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql").
        write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql").
        identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both:
            the source and the target dialect.
        error_level: the desired error level of the parser.
        **opts: other `sqlglot.generator.Generator` options.

    Returns:
        The list of transpiled SQL statements.
    """
    # Same result as `write or read if identity else write`, without relying on the
    # reader knowing that the conditional expression binds looser than `or`.
    if identity:
        write = write or read

    # Parse first so that parsing errors surface before any target-dialect lookup error.
    expressions = parse(sql, read, error_level=error_level)

    # The target dialect does not depend on the statement being generated, so resolve and
    # instantiate it once instead of once per parsed statement.
    target_dialect = Dialect.get_or_raise(write)()

    return [target_dialect.generate(expression, **opts) for expression in expressions]