summaryrefslogtreecommitdiffstats
path: root/sqlglot/helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/helper.py')
-rw-r--r--sqlglot/helper.py20
1 files changed, 15 insertions, 5 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index ed37e6c..5a0f2ac 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -131,7 +131,7 @@ def subclasses(
]
-def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
+def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
"""
Applies an offset to a given integer literal expression.
@@ -148,10 +148,10 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
expression = expressions[0]
- if expression.is_int:
+ if expression and expression.is_int:
expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset)
- expression.args["this"] = str(int(expression.this) + offset)
+ expression.args["this"] = str(int(expression.this) + offset) # type: ignore
return [expression]
return expressions
@@ -225,7 +225,7 @@ def open_file(file_name: str) -> t.TextIO:
return gzip.open(file_name, "rt", newline="")
- return open(file_name, "rt", encoding="utf-8", newline="")
+ return open(file_name, encoding="utf-8", newline="")
@contextmanager
@@ -256,7 +256,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
file.close()
-def find_new_name(taken: t.Sequence[str], base: str) -> str:
+def find_new_name(taken: t.Collection[str], base: str) -> str:
"""
Searches for a new name.
@@ -356,6 +356,15 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
yield value
+def count_params(function: t.Callable) -> int:
+ """
+ Returns the number of formal parameters expected by a function, without counting "self"
+ and "cls", in case of instance and class methods, respectively.
+ """
+ count = function.__code__.co_argcount
+ return count - 1 if inspect.ismethod(function) else count
+
+
def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
@@ -374,6 +383,7 @@ def dict_depth(d: t.Dict) -> int:
Args:
d (dict): dictionary
+
Returns:
int: depth
"""