summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py50
1 files changed, 35 insertions, 15 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 242e66c..264b8e9 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -878,11 +878,11 @@ class DerivedTable(Expression):
return [c.name for c in table_alias.args.get("columns") or []]
@property
- def selects(self):
+ def selects(self) -> t.List[Expression]:
return self.this.selects if isinstance(self.this, Subqueryable) else []
@property
- def named_selects(self):
+ def named_selects(self) -> t.List[str]:
return [select.output_name for select in self.selects]
@@ -959,7 +959,7 @@ class Unionable(Expression):
class UDTF(DerivedTable, Unionable):
@property
- def selects(self):
+ def selects(self) -> t.List[Expression]:
alias = self.args.get("alias")
return alias.columns if alias else []
@@ -1576,7 +1576,7 @@ class OnConflict(Expression):
class Returning(Expression):
- arg_types = {"expressions": True}
+ arg_types = {"expressions": True, "into": False}
# https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html
@@ -2194,11 +2194,11 @@ class Subqueryable(Unionable):
return with_.expressions
@property
- def selects(self):
+ def selects(self) -> t.List[Expression]:
raise NotImplementedError("Subqueryable objects must implement `selects`")
@property
- def named_selects(self):
+ def named_selects(self) -> t.List[str]:
raise NotImplementedError("Subqueryable objects must implement `named_selects`")
def with_(
@@ -2282,7 +2282,6 @@ class Table(Expression):
"pivots": False,
"hints": False,
"system_time": False,
- "wrapped": False,
}
@property
@@ -2300,13 +2299,27 @@ class Table(Expression):
return self.text("catalog")
@property
+ def selects(self) -> t.List[Expression]:
+ return []
+
+ @property
+ def named_selects(self) -> t.List[str]:
+ return []
+
+ @property
def parts(self) -> t.List[Identifier]:
"""Return the parts of a table in order catalog, db, table."""
- return [
- t.cast(Identifier, self.args[part])
- for part in ("catalog", "db", "this")
- if self.args.get(part)
- ]
+ parts: t.List[Identifier] = []
+
+ for arg in ("catalog", "db", "this"):
+ part = self.args.get(arg)
+
+ if isinstance(part, Identifier):
+ parts.append(part)
+ elif isinstance(part, Dot):
+ parts.extend(part.flatten())
+
+ return parts
# See the TSQL "Querying data in a system-versioned temporal table" page
@@ -2390,7 +2403,7 @@ class Union(Subqueryable):
return this
@property
- def named_selects(self):
+ def named_selects(self) -> t.List[str]:
return self.this.unnest().named_selects
@property
@@ -2398,7 +2411,7 @@ class Union(Subqueryable):
return self.this.is_star or self.expression.is_star
@property
- def selects(self):
+ def selects(self) -> t.List[Expression]:
return self.this.unnest().selects
@property
@@ -3517,6 +3530,10 @@ class Or(Connector):
pass
+class Xor(Connector):
+ pass
+
+
class BitwiseAnd(Binary):
pass
@@ -4409,6 +4426,7 @@ class RegexpExtract(Func):
"expression": True,
"position": False,
"occurrence": False,
+ "parameters": False,
"group": False,
}
@@ -5756,7 +5774,9 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str:
raise ValueError(f"Cannot parse {table}")
return ".".join(
- part.sql(dialect=dialect) if not SAFE_IDENTIFIER_RE.match(part.name) else part.name
+ part.sql(dialect=dialect, identify=True)
+ if not SAFE_IDENTIFIER_RE.match(part.name)
+ else part.name
for part in table.parts
)