summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/util.py
blob: 575d18a490a733726730d58385c9ff2e7cbd0c22 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.dataframe.sql import types

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import SchemaInput


def _split_top_level_commas(text: str) -> t.List[str]:
    """Split ``text`` on commas that are not nested inside ``()`` or ``<>``."""
    pieces: t.List[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<":
            depth += 1
        elif char in ")>":
            # Clamp at zero so a stray closing bracket cannot hide later commas.
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    # Like ``str.split``, always emit the final (possibly empty) piece.
    pieces.append(text[start:])
    return pieces


def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]:
    """Normalize a schema input into a mapping of column name to type string.

    Args:
        schema: One of the following:
            * a ``dict``, which is returned as-is (the same object, not a copy);
            * a ``str`` of comma-separated ``name: type`` entries, e.g.
              ``"a: int, b: decimal(10, 2)"``. Commas inside parentheses or
              angle brackets belong to the type and do not separate columns,
              and only the first colon of an entry separates name from type;
            * a ``types.StructType``, whose fields are mapped from ``name`` to
              ``dataType.simpleString()``;
            * any other iterable of column-name strings, each mapped to ``None``.

    Returns:
        A dictionary keyed by the stripped column name. Values are the type
        strings, or ``None`` when only column names were supplied.

    Raises:
        IndexError: If ``schema`` is a string and one of its entries contains
            no ``:`` separator.
    """
    if isinstance(schema, dict):
        return schema
    if isinstance(schema, str):
        column_mapping: t.Dict[str, t.Optional[str]] = {}
        for name_type_str in _split_top_level_commas(schema):
            # Split on the first colon only, so nested types such as
            # ``struct<x:int>`` keep their inner colons.
            name_and_type = name_type_str.split(":", 1)
            column_mapping[name_and_type[0].strip()] = name_and_type[1].strip()
        return column_mapping
    if isinstance(schema, types.StructType):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
    return {x.strip(): None for x in schema}  # type: ignore


def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
    """Collect the tables referenced by a select expression that has joins.

    Args:
        expression: The select expression whose ``args`` are inspected.

    Returns:
        An empty list when the expression has no (or empty) ``"joins"`` arg.
        Otherwise, the first entry of the ``"from"`` clause's ``"expressions"``
        followed by the ``this`` of every join, in join order.
    """
    joins = expression.args.get("joins")
    if not joins:
        return []

    # Only the first expression of the FROM clause is taken as the left table.
    first_table = expression.args["from"].args["expressions"][0]
    return [first_table, *(join.this for join in joins)]