summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql/util.py')
-rw-r--r--sqlglot/dataframe/sql/util.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py
new file mode 100644
index 0000000..575d18a
--- /dev/null
+++ b/sqlglot/dataframe/sql/util.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import types
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import SchemaInput
+
+
+def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]:
+ if isinstance(schema, dict):
+ return schema
+ elif isinstance(schema, str):
+ col_name_type_strs = [x.strip() for x in schema.split(",")]
+ return {
+ name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
+ for name_type_str in col_name_type_strs
+ }
+ elif isinstance(schema, types.StructType):
+ return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
+ return {x.strip(): None for x in schema} # type: ignore
+
+
+def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
+ if not expression.args.get("joins"):
+ return []
+
+ left_table = expression.args["from"].args["expressions"][0]
+ other_tables = [join.this for join in expression.args["joins"]]
+ return [left_table] + other_tables