summaryrefslogtreecommitdiffstats
path: root/pgcli/packages/parseutils/meta.py
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/packages/parseutils/meta.py')
-rw-r--r--pgcli/packages/parseutils/meta.py170
1 files changed, 170 insertions, 0 deletions
diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py
new file mode 100644
index 0000000..108c01a
--- /dev/null
+++ b/pgcli/packages/parseutils/meta.py
@@ -0,0 +1,170 @@
+from collections import namedtuple
+
+_ColumnMetadata = namedtuple(
+ "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]
+)
+
+
+def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
+ return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default)
+
+
+ForeignKey = namedtuple(
+ "ForeignKey",
+ [
+ "parentschema",
+ "parenttable",
+ "parentcolumn",
+ "childschema",
+ "childtable",
+ "childcolumn",
+ ],
+)
+TableMetadata = namedtuple("TableMetadata", "name columns")
+
+
+def parse_defaults(defaults_string):
+ """Yields default values for a function, given the string provided by
+ pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
+ if not defaults_string:
+ return
+ current = ""
+ in_quote = None
+ for char in defaults_string:
+ if current == "" and char == " ":
+ # Skip space after comma separating default expressions
+ continue
+ if char == '"' or char == "'":
+ if in_quote and char == in_quote:
+ # End quote
+ in_quote = None
+ elif not in_quote:
+ # Begin quote
+ in_quote = char
+ elif char == "," and not in_quote:
+ # End of expression
+ yield current
+ current = ""
+ continue
+ current += char
+ yield current
+
+
+class FunctionMetadata(object):
+ def __init__(
+ self,
+ schema_name,
+ func_name,
+ arg_names,
+ arg_types,
+ arg_modes,
+ return_type,
+ is_aggregate,
+ is_window,
+ is_set_returning,
+ is_extension,
+ arg_defaults,
+ ):
+ """Class for describing a postgresql function"""
+
+ self.schema_name = schema_name
+ self.func_name = func_name
+
+ self.arg_modes = tuple(arg_modes) if arg_modes else None
+ self.arg_names = tuple(arg_names) if arg_names else None
+
+ # Be flexible in not requiring arg_types -- use None as a placeholder
+ # for each arg. (Used for compatibility with old versions of postgresql
+ # where such info is hard to get.
+ if arg_types:
+ self.arg_types = tuple(arg_types)
+ elif arg_modes:
+ self.arg_types = tuple([None] * len(arg_modes))
+ elif arg_names:
+ self.arg_types = tuple([None] * len(arg_names))
+ else:
+ self.arg_types = None
+
+ self.arg_defaults = tuple(parse_defaults(arg_defaults))
+
+ self.return_type = return_type.strip()
+ self.is_aggregate = is_aggregate
+ self.is_window = is_window
+ self.is_set_returning = is_set_returning
+ self.is_extension = bool(is_extension)
+ self.is_public = self.schema_name and self.schema_name == "public"
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def _signature(self):
+ return (
+ self.schema_name,
+ self.func_name,
+ self.arg_names,
+ self.arg_types,
+ self.arg_modes,
+ self.return_type,
+ self.is_aggregate,
+ self.is_window,
+ self.is_set_returning,
+ self.is_extension,
+ self.arg_defaults,
+ )
+
+ def __hash__(self):
+ return hash(self._signature())
+
+ def __repr__(self):
+ return (
+ "%s(schema_name=%r, func_name=%r, arg_names=%r, "
+ "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
+ "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)"
+ ) % ((self.__class__.__name__,) + self._signature())
+
+ def has_variadic(self):
+ return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes)
+
+ def args(self):
+ """Returns a list of input-parameter ColumnMetadata namedtuples."""
+ if not self.arg_names:
+ return []
+ modes = self.arg_modes or ["i"] * len(self.arg_names)
+ args = [
+ (name, typ)
+ for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
+ if mode in ("i", "b", "v") # IN, INOUT, VARIADIC
+ ]
+
+ def arg(name, typ, num):
+ num_args = len(args)
+ num_defaults = len(self.arg_defaults)
+ has_default = num + num_defaults >= num_args
+ default = (
+ self.arg_defaults[num - num_args + num_defaults]
+ if has_default
+ else None
+ )
+ return ColumnMetadata(name, typ, [], default, has_default)
+
+ return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]
+
+ def fields(self):
+ """Returns a list of output-field ColumnMetadata namedtuples"""
+
+ if self.return_type.lower() == "void":
+ return []
+ elif not self.arg_modes:
+ # For functions without output parameters, the function name
+ # is used as the name of the output column.
+ # E.g. 'SELECT unnest FROM unnest(...);'
+ return [ColumnMetadata(self.func_name, self.return_type, [])]
+
+ return [
+ ColumnMetadata(name, typ, [])
+ for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes)
+ if mode in ("o", "b", "t")
+ ] # OUT, INOUT, TABLE