summaryrefslogtreecommitdiffstats
path: root/sqlglot/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/__init__.py')
-rw-r--r--sqlglot/__init__.py66
1 files changed, 59 insertions, 7 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index bfcabb3..714897f 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -33,7 +33,13 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema, Schema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.6.0"
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import DialectType
+
+ T = t.TypeVar("T", bound=Expression)
+
+
+__version__ = "10.6.3"
pretty = False
"""Whether to format generated SQL by default."""
@@ -42,9 +48,7 @@ schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""
-def parse(
- sql: str, read: t.Optional[str | Dialect] = None, **opts
-) -> t.List[t.Optional[Expression]]:
+def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
"""
Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
@@ -60,9 +64,57 @@ def parse(
return dialect.parse(sql, **opts)
+@t.overload
+def parse_one(
+ sql: str,
+ read: None = None,
+ into: t.Type[T] = ...,
+ **opts,
+) -> T:
+ ...
+
+
+@t.overload
+def parse_one(
+ sql: str,
+ read: DialectType,
+ into: t.Type[T],
+ **opts,
+) -> T:
+ ...
+
+
+@t.overload
+def parse_one(
+ sql: str,
+ read: None = None,
+ into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ...,
+ **opts,
+) -> Expression:
+ ...
+
+
+@t.overload
+def parse_one(
+ sql: str,
+ read: DialectType,
+ into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]],
+ **opts,
+) -> Expression:
+ ...
+
+
+@t.overload
+def parse_one(
+ sql: str,
+ **opts,
+) -> Expression:
+ ...
+
+
def parse_one(
sql: str,
- read: t.Optional[str | Dialect] = None,
+ read: DialectType = None,
into: t.Optional[exp.IntoType] = None,
**opts,
) -> Expression:
@@ -96,8 +148,8 @@ def parse_one(
def transpile(
sql: str,
- read: t.Optional[str | Dialect] = None,
- write: t.Optional[str | Dialect] = None,
+ read: DialectType = None,
+ write: DialectType = None,
identity: bool = True,
error_level: t.Optional[ErrorLevel] = None,
**opts,