summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r--sqlglot/dataframe/__init__.py3
-rw-r--r--sqlglot/dataframe/sql/_typing.pyi20
-rw-r--r--sqlglot/dataframe/sql/dataframe.py2
3 files changed, 9 insertions, 16 deletions
diff --git a/sqlglot/dataframe/__init__.py b/sqlglot/dataframe/__init__.py
index e69de29..a57e990 100644
--- a/sqlglot/dataframe/__init__.py
+++ b/sqlglot/dataframe/__init__.py
@@ -0,0 +1,3 @@
+"""
+.. include:: ./README.md
+"""
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
index 67c8c09..1682ec1 100644
--- a/sqlglot/dataframe/sql/_typing.pyi
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -9,18 +9,8 @@ if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.types import StructType
-ColumnLiterals = t.TypeVar(
- "ColumnLiterals",
- bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
-)
-ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
-ColumnOrLiteral = t.TypeVar(
- "ColumnOrLiteral",
- bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
-)
-SchemaInput = t.TypeVar(
- "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
-)
-OutputExpressionContainer = t.TypeVar(
- "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
-)
+ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ColumnOrName = t.Union[Column, str]
+ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
+OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 3c45741..a17bb9d 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -634,7 +634,7 @@ class DataFrame:
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
if isinstance(value, dict):
- values = value.values()
+ values = list(value.values())
columns = self._ensure_and_normalize_cols(list(value))
if not columns:
columns = self._ensure_and_normalize_cols(subset) if subset else all_columns