summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r--sqlglot/dialects/dialect.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 0c2beba..1b20e0a 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -122,9 +122,15 @@ class Dialect(metaclass=_Dialect):
def get_or_raise(cls, dialect):
if not dialect:
return cls
+ if isinstance(dialect, _Dialect):
+ return dialect
+ if isinstance(dialect, Dialect):
+ return dialect.__class__
+
result = cls.get(dialect)
if not result:
raise ValueError(f"Unknown dialect '{dialect}'")
+
return result
@classmethod
@@ -196,6 +202,10 @@ class Dialect(metaclass=_Dialect):
)
+if t.TYPE_CHECKING:
+ DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
+
+
def rename_func(name):
def _rename(self, expression):
args = flatten(expression.args.values())