from collections import namedtuple

_ColumnMetadata = namedtuple(
    "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]
)


def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
    """Build a ColumnMetadata namedtuple.

    A falsy ``foreignkeys`` (including the default ``None``) is replaced by a
    fresh empty list so that callers never share a mutable default.
    """
    return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default)


ForeignKey = namedtuple(
    "ForeignKey",
    [
        "parentschema",
        "parenttable",
        "parentcolumn",
        "childschema",
        "childtable",
        "childcolumn",
    ],
)
TableMetadata = namedtuple("TableMetadata", "name columns")


def parse_defaults(defaults_string):
    """Yields default values for a function, given the string provided by
    pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)

    The string is split on top-level commas only: a comma inside a single- or
    double-quoted section, or nested inside parentheses / square brackets
    (e.g. ``ARRAY[1, 2]`` or ``make_date(2000, 1, 1)``), belongs to the
    current expression and does not start a new one.

    Nothing is yielded for an empty or ``None`` input.
    """
    if not defaults_string:
        return

    current = []  # characters of the expression being accumulated
    in_quote = None  # the quote character we are inside of, if any
    depth = 0  # nesting level of () / [] outside of quotes

    for char in defaults_string:
        if not current and char == " ":
            # Skip space after comma separating default expressions
            continue

        if char == '"' or char == "'":
            if in_quote and char == in_quote:
                # End quote
                in_quote = None
            elif not in_quote:
                # Begin quote
                in_quote = char
            # Otherwise: the other kind of quote inside a quoted section,
            # which is just an ordinary character.
        elif not in_quote:
            if char in "([":
                depth += 1
            elif char in ")]":
                # Clamp so that a stray closer cannot hide later separators.
                depth = max(depth - 1, 0)
            elif char == "," and depth == 0:
                # End of expression
                yield "".join(current)
                current = []
                continue

        current.append(char)

    yield "".join(current)


class FunctionMetadata(object):
    """Class for describing a postgresql function"""

    def __init__(
        self,
        schema_name,
        func_name,
        arg_names,
        arg_types,
        arg_modes,
        return_type,
        is_aggregate,
        is_window,
        is_set_returning,
        is_extension,
        arg_defaults,
    ):
        """Store the function description.

        ``arg_names``, ``arg_types`` and ``arg_modes`` are stored as tuples,
        or as ``None`` when empty/missing.  ``arg_defaults`` is the raw
        defaults string understood by ``parse_defaults`` and is stored as a
        tuple of expression strings.  ``return_type`` is stored stripped of
        surrounding whitespace and ``is_extension`` is coerced to bool.
        """
        self.schema_name = schema_name
        self.func_name = func_name

        self.arg_modes = tuple(arg_modes) if arg_modes else None
        self.arg_names = tuple(arg_names) if arg_names else None

        # Be flexible in not requiring arg_types -- use None as a placeholder
        # for each arg. (Used for compatibility with old versions of postgresql
        # where such info is hard to get.)
        if arg_types:
            self.arg_types = tuple(arg_types)
        elif arg_modes:
            self.arg_types = tuple([None] * len(arg_modes))
        elif arg_names:
            self.arg_types = tuple([None] * len(arg_names))
        else:
            self.arg_types = None

        self.arg_defaults = tuple(parse_defaults(arg_defaults))

        self.return_type = return_type.strip()
        self.is_aggregate = is_aggregate
        self.is_window = is_window
        self.is_set_returning = is_set_returning
        self.is_extension = bool(is_extension)
        # Always a real bool (never None or "") so it reads cleanly as a flag.
        self.is_public = self.schema_name == "public"

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def _signature(self):
        """Return the constructor-level attributes as a tuple.

        The order matches the placeholders used by ``__repr__``.
        """
        return (
            self.schema_name,
            self.func_name,
            self.arg_names,
            self.arg_types,
            self.arg_modes,
            self.return_type,
            self.is_aggregate,
            self.is_window,
            self.is_set_returning,
            self.is_extension,
            self.arg_defaults,
        )

    def __hash__(self):
        return hash(self._signature())

    def __repr__(self):
        return (
            "%s(schema_name=%r, func_name=%r, arg_names=%r, "
            "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
            "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)"
        ) % ((self.__class__.__name__,) + self._signature())

    def has_variadic(self):
        """Return True if any argument mode is "v" (VARIADIC), else False."""
        return bool(self.arg_modes) and any(
            arg_mode == "v" for arg_mode in self.arg_modes
        )

    def args(self):
        """Returns a list of input-parameter ColumnMetadata namedtuples.

        Defaults are matched against the *trailing* input parameters: with N
        inputs and D defaults, the last D inputs get ``has_default=True`` and
        the corresponding default expression.
        """
        if not self.arg_names:
            return []

        # Without explicit modes every argument is treated as IN.
        modes = self.arg_modes or ["i"] * len(self.arg_names)
        args = [
            (name, typ)
            for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
            if mode in ("i", "b", "v")  # IN, INOUT, VARIADIC
        ]

        # Index of the first input parameter that carries a default.
        first_default = len(args) - len(self.arg_defaults)

        result = []
        for num, (name, typ) in enumerate(args):
            has_default = num >= first_default
            default = self.arg_defaults[num - first_default] if has_default else None
            result.append(ColumnMetadata(name, typ, [], default, has_default))
        return result

    def fields(self):
        """Returns a list of output-field ColumnMetadata namedtuples

        Returns an empty list for ``void`` functions, and also when argument
        modes are known but no argument names are available (there is then
        nothing to name the output fields with).
        """
        if self.return_type.lower() == "void":
            return []
        elif not self.arg_modes:
            # For functions without output parameters, the function name
            # is used as the name of the output column.
            # E.g. 'SELECT unnest FROM unnest(...);'
            return [ColumnMetadata(self.func_name, self.return_type, [])]

        # arg_names may be None even though arg_modes is set; treat that the
        # same way args() does (no names -> no entries) instead of letting
        # zip() raise TypeError.
        names = self.arg_names or ()
        return [
            ColumnMetadata(name, typ, [])
            for name, typ, mode in zip(names, self.arg_types, self.arg_modes)
            if mode in ("o", "b", "t")  # OUT, INOUT, TABLE
        ]