diff options
Diffstat (limited to 'sphinx/domains/cpp/_ast.py')
-rw-r--r-- | sphinx/domains/cpp/_ast.py | 847 |
1 files changed, 822 insertions, 25 deletions
diff --git a/sphinx/domains/cpp/_ast.py b/sphinx/domains/cpp/_ast.py index ad57695..141d511 100644 --- a/sphinx/domains/cpp/_ast.py +++ b/sphinx/domains/cpp/_ast.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +import sys +import warnings +from typing import TYPE_CHECKING, Any, ClassVar, Literal from docutils import nodes @@ -44,60 +46,64 @@ class ASTBase(ASTBaseBase): ################################################################################ class ASTIdentifier(ASTBase): - def __init__(self, identifier: str) -> None: - assert identifier is not None - assert len(identifier) != 0 - self.identifier = identifier + def __init__(self, name: str) -> None: + if not isinstance(name, str) or len(name) == 0: + raise AssertionError + self.name = sys.intern(name) + self.is_anonymous = name[0] == '@' # ASTBaseBase already implements this method, # but specialising it here improves performance def __eq__(self, other: object) -> bool: - if type(other) is not ASTIdentifier: + if not isinstance(other, ASTIdentifier): return NotImplemented - return self.identifier == other.identifier + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) def _stringify(self, transform: StringifyTransform) -> str: - return transform(self.identifier) + return transform(self.name) def is_anon(self) -> bool: - return self.identifier[0] == '@' + return self.is_anonymous def get_id(self, version: int) -> str: - if self.is_anon() and version < 3: + if self.is_anonymous and version < 3: raise NoOldIdError if version == 1: - if self.identifier == 'size_t': + if self.name == 'size_t': return 's' else: - return self.identifier - if self.identifier == "std": + return self.name + if self.name == "std": return 'St' - elif self.identifier[0] == "~": + elif self.name[0] == "~": # a destructor, just use an arbitrary version of dtors return 'D0' else: - if self.is_anon(): - return 'Ut%d_%s' % (len(self.identifier) - 1, self.identifier[1:]) + if self.is_anonymous: + return 'Ut%d_%s' % (len(self.name) - 1, self.name[1:]) else: - return str(len(self.identifier)) + self.identifier + return str(len(self.name)) + self.name # and this is where we finally make a difference between __str__ and the display string def __str__(self) -> str: - return self.identifier + return self.name def get_display_string(self) -> str: - return "[anonymous]" if self.is_anon() else self.identifier + return "[anonymous]" if self.is_anonymous else self.name def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironment, prefix: str, templateArgs: str, symbol: Symbol) -> None: verify_description_mode(mode) - if self.is_anon(): + if self.is_anonymous: node = addnodes.desc_sig_name(text="[anonymous]") else: - node = addnodes.desc_sig_name(self.identifier, self.identifier) + node = addnodes.desc_sig_name(self.name, self.name) if mode == 'markType': - targetText = prefix + self.identifier + templateArgs + targetText = prefix + self.name + templateArgs pnode = addnodes.pending_xref('', refdomain='cpp', reftype='identifier', reftarget=targetText, modname=None, @@ -118,8 +124,8 @@ class ASTIdentifier(ASTBase): # the target is 'operator""id' instead of just 'id' assert len(prefix) == 0 assert len(templateArgs) == 0 - assert not self.is_anon() - targetText = 'operator""' + self.identifier + assert not self.is_anonymous + targetText = 'operator""' + self.name pnode = addnodes.pending_xref('', refdomain='cpp', reftype='identifier', reftarget=targetText, modname=None, @@ -130,6 +136,14 @@ class ASTIdentifier(ASTBase): else: raise Exception('Unknown description mode: %s' % mode) + @property + def identifier(self) -> str: + warnings.warn( + '`ASTIdentifier.identifier` is deprecated, use `ASTIdentifier.name` instead', + DeprecationWarning, stacklevel=2, + ) + return self.name + class ASTNestedNameElement(ASTBase): def __init__(self, identOrOp: ASTIdentifier | ASTOperator, @@ -137,6 +151,14 @@ class ASTNestedNameElement(ASTBase): self.identOrOp = identOrOp self.templateArgs = templateArgs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNestedNameElement): + return NotImplemented + return self.identOrOp == other.identOrOp and self.templateArgs == other.templateArgs + + def __hash__(self) -> int: + return hash((self.identOrOp, self.templateArgs)) + def is_operator(self) -> bool: return False @@ -169,6 +191,18 @@ class ASTNestedName(ASTBase): assert len(self.names) == len(self.templates) self.rooted = rooted + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNestedName): + return NotImplemented + return ( + self.names == other.names + and self.templates == other.templates + and self.rooted == other.rooted + ) + + def __hash__(self) -> int: + return hash((self.names, self.templates, self.rooted)) + @property def name(self) -> ASTNestedName: return self @@ -316,6 +350,12 @@ class ASTLiteral(ASTExpression): class ASTPointerLiteral(ASTLiteral): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTPointerLiteral) + + def __hash__(self) -> int: + return hash('nullptr') + def _stringify(self, transform: StringifyTransform) -> str: return 'nullptr' @@ -331,6 +371,14 @@ class ASTBooleanLiteral(ASTLiteral): def __init__(self, value: bool) -> None: self.value = value + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBooleanLiteral): + return NotImplemented + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + def _stringify(self, transform: StringifyTransform) -> str: if self.value: return 'true' @@ -352,6 +400,14 @@ class ASTNumberLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNumberLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -368,6 +424,14 @@ class ASTStringLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTStringLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -392,6 +456,17 @@ class ASTCharLiteral(ASTLiteral): else: raise UnsupportedMultiCharacterCharLiteral(decoded) + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCharLiteral): + return NotImplemented + return ( + self.prefix == other.prefix + and self.value == other.value + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.value)) + def _stringify(self, transform: StringifyTransform) -> str: if self.prefix is None: return "'" + self.data + "'" @@ -415,6 +490,14 @@ class ASTUserDefinedLiteral(ASTLiteral): self.literal = literal self.ident = ident + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUserDefinedLiteral): + return NotImplemented + return self.literal == other.literal and self.ident == other.ident + + def __hash__(self) -> int: + return hash((self.literal, self.ident)) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.literal) + transform(self.ident) @@ -431,6 +514,12 @@ class ASTUserDefinedLiteral(ASTLiteral): ################################################################################ class ASTThisLiteral(ASTExpression): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTThisLiteral) + + def __hash__(self) -> int: + return hash("this") + def _stringify(self, transform: StringifyTransform) -> str: return "this" @@ -450,6 +539,18 @@ class ASTFoldExpr(ASTExpression): self.op = op self.rightExpr = rightExpr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFoldExpr): + return NotImplemented + return ( + self.leftExpr == other.leftExpr + and self.op == other.op + and self.rightExpr == other.rightExpr + ) + + def __hash__(self) -> int: + return hash((self.leftExpr, self.op, self.rightExpr)) + def _stringify(self, transform: StringifyTransform) -> str: res = ['('] if self.leftExpr: @@ -508,6 +609,14 @@ class ASTParenExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '(' + transform(self.expr) + ')' @@ -526,6 +635,14 @@ class ASTIdExpression(ASTExpression): # note: this class is basically to cast a nested name as an expression self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTIdExpression): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.name) @@ -553,6 +670,14 @@ class ASTPostfixArray(ASTPostfixOp): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixArray): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '[' + transform(self.expr) + ']' @@ -570,6 +695,14 @@ class ASTPostfixMember(ASTPostfixOp): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixMember): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return '.' + transform(self.name) @@ -586,6 +719,14 @@ class ASTPostfixMemberOfPointer(ASTPostfixOp): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixMemberOfPointer): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return '->' + transform(self.name) @@ -599,6 +740,12 @@ class ASTPostfixMemberOfPointer(ASTPostfixOp): class ASTPostfixInc(ASTPostfixOp): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTPostfixInc) + + def __hash__(self) -> int: + return hash('++') + def _stringify(self, transform: StringifyTransform) -> str: return '++' @@ -611,6 +758,12 @@ class ASTPostfixInc(ASTPostfixOp): class ASTPostfixDec(ASTPostfixOp): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTPostfixDec) + + def __hash__(self) -> int: + return hash('--') + def _stringify(self, transform: StringifyTransform) -> str: return '--' @@ -626,6 +779,14 @@ class ASTPostfixCallExpr(ASTPostfixOp): def __init__(self, lst: ASTParenExprList | ASTBracedInitList) -> None: self.lst = lst + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixCallExpr): + return NotImplemented + return self.lst == other.lst + + def __hash__(self) -> int: + return hash(self.lst) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.lst) @@ -647,6 +808,14 @@ class ASTPostfixExpr(ASTExpression): self.prefix = prefix self.postFixes = postFixes + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixExpr): + return NotImplemented + return self.prefix == other.prefix and self.postFixes == other.postFixes + + def __hash__(self) -> int: + return hash((self.prefix, self.postFixes)) + def _stringify(self, transform: StringifyTransform) -> str: return ''.join([transform(self.prefix), *(transform(p) for p in self.postFixes)]) @@ -670,6 +839,14 @@ class ASTExplicitCast(ASTExpression): self.typ = typ self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTExplicitCast): + return NotImplemented + return self.cast == other.cast and self.typ == other.typ and self.expr == other.expr + + def __hash__(self) -> int: + return hash((self.cast, self.typ, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [self.cast] res.append('<') @@ -700,6 +877,14 @@ class ASTTypeId(ASTExpression): self.typeOrExpr = typeOrExpr self.isType = isType + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeId): + return NotImplemented + return self.typeOrExpr == other.typeOrExpr and self.isType == other.isType + + def __hash__(self) -> int: + return hash((self.typeOrExpr, self.isType)) + def _stringify(self, transform: StringifyTransform) -> str: return 'typeid(' + transform(self.typeOrExpr) + ')' @@ -723,6 +908,14 @@ class ASTUnaryOpExpr(ASTExpression): self.op = op self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnaryOpExpr): + return NotImplemented + return self.op == other.op and self.expr == other.expr + + def __hash__(self) -> int: + return hash((self.op, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: if self.op[0] in 'cn': return self.op + " " + transform(self.expr) @@ -746,6 +939,14 @@ class ASTSizeofParamPack(ASTExpression): def __init__(self, identifier: ASTIdentifier) -> None: self.identifier = identifier + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofParamPack): + return NotImplemented + return self.identifier == other.identifier + + def __hash__(self) -> int: + return hash(self.identifier) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof...(" + transform(self.identifier) + ")" @@ -766,6 +967,14 @@ class ASTSizeofType(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofType): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof(" + transform(self.typ) + ")" @@ -784,6 +993,14 @@ class ASTSizeofExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof " + transform(self.expr) @@ -801,6 +1018,14 @@ class ASTAlignofExpr(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAlignofExpr): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "alignof(" + transform(self.typ) + ")" @@ -819,6 +1044,14 @@ class ASTNoexceptExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNoexceptExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return 'noexcept(' + transform(self.expr) + ')' @@ -841,6 +1074,19 @@ class ASTNewExpr(ASTExpression): self.typ = typ self.initList = initList + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNewExpr): + return NotImplemented + return ( + self.rooted == other.rooted + and self.isNewTypeId == other.isNewTypeId + and self.typ == other.typ + and self.initList == other.initList + ) + + def __hash__(self) -> int: + return hash((self.rooted, self.isNewTypeId, self.typ, self.initList)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.rooted: @@ -888,6 +1134,18 @@ class ASTDeleteExpr(ASTExpression): self.array = array self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeleteExpr): + return NotImplemented + return ( + self.rooted == other.rooted + and self.array == other.array + and self.expr == other.expr + ) + + def __hash__(self) -> int: + return hash((self.rooted, self.array, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.rooted: @@ -925,6 +1183,17 @@ class ASTCastExpr(ASTExpression): self.typ = typ self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCastExpr): + return NotImplemented + return ( + self.typ == other.typ + and self.expr == other.expr + ) + + def __hash__(self) -> int: + return hash((self.typ, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = ['('] res.append(transform(self.typ)) @@ -950,6 +1219,17 @@ class ASTBinOpExpr(ASTExpression): self.exprs = exprs self.ops = ops + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBinOpExpr): + return NotImplemented + return ( + self.exprs == other.exprs + and self.ops == other.ops + ) + + def __hash__(self) -> int: + return hash((self.exprs, self.ops)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.exprs[0])) @@ -990,6 +1270,18 @@ class ASTConditionalExpr(ASTExpression): self.thenExpr = thenExpr self.elseExpr = elseExpr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTConditionalExpr): + return NotImplemented + return ( + self.ifExpr == other.ifExpr + and self.thenExpr == other.thenExpr + and self.elseExpr == other.elseExpr + ) + + def __hash__(self) -> int: + return hash((self.ifExpr, self.thenExpr, self.elseExpr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.ifExpr)) @@ -1027,6 +1319,14 @@ class ASTBracedInitList(ASTBase): self.exprs = exprs self.trailingComma = trailingComma + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBracedInitList): + return NotImplemented + return self.exprs == other.exprs and self.trailingComma == other.trailingComma + + def __hash__(self) -> int: + return hash((self.exprs, self.trailingComma)) + def get_id(self, version: int) -> str: return "il%sE" % ''.join(e.get_id(version) for e in self.exprs) @@ -1059,6 +1359,18 @@ class ASTAssignmentExpr(ASTExpression): self.op = op self.rightExpr = rightExpr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAssignmentExpr): + return NotImplemented + return ( + self.leftExpr == other.leftExpr + and self.op == other.op + and self.rightExpr == other.rightExpr + ) + + def __hash__(self) -> int: + return hash((self.leftExpr, self.op, self.rightExpr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.leftExpr)) @@ -1093,6 +1405,14 @@ class ASTCommaExpr(ASTExpression): assert len(exprs) > 0 self.exprs = exprs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCommaExpr): + return NotImplemented + return self.exprs == other.exprs + + def __hash__(self) -> int: + return hash(self.exprs) + def _stringify(self, transform: StringifyTransform) -> str: return ', '.join(transform(e) for e in self.exprs) @@ -1118,6 +1438,14 @@ class ASTFallbackExpr(ASTExpression): def __init__(self, expr: str) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFallbackExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return self.expr @@ -1137,11 +1465,16 @@ class ASTFallbackExpr(ASTExpression): ################################################################################ class ASTOperator(ASTBase): + is_anonymous: ClassVar[Literal[False]] = False + def __eq__(self, other: object) -> bool: raise NotImplementedError(repr(self)) + def __hash__(self) -> int: + raise NotImplementedError(repr(self)) + def is_anon(self) -> bool: - return False + return self.is_anonymous def is_operator(self) -> bool: return True @@ -1193,6 +1526,9 @@ class ASTOperatorBuildIn(ASTOperator): return NotImplemented return self.op == other.op + def __hash__(self) -> int: + return hash(self.op) + def get_id(self, version: int) -> str: if version == 1: ids = _id_operator_v1 @@ -1228,6 +1564,9 @@ class ASTOperatorLiteral(ASTOperator): return NotImplemented return self.identifier == other.identifier + def __hash__(self) -> int: + return hash(self.identifier) + def get_id(self, version: int) -> str: if version == 1: raise NoOldIdError @@ -1252,6 +1591,9 @@ class ASTOperatorType(ASTOperator): return NotImplemented return self.type == other.type + def __hash__(self) -> int: + return hash(self.type) + def get_id(self, version: int) -> str: if version == 1: return 'castto-%s-operator' % self.type.get_id(version) @@ -1275,6 +1617,14 @@ class ASTTemplateArgConstant(ASTBase): def __init__(self, value: ASTExpression) -> None: self.value = value + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateArgConstant): + return NotImplemented + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.value) @@ -1298,6 +1648,14 @@ class ASTTemplateArgs(ASTBase): self.args = args self.packExpansion = packExpansion + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateArgs): + return NotImplemented + return self.args == other.args and self.packExpansion == other.packExpansion + + def __hash__(self) -> int: + return hash((self.args, self.packExpansion)) + def get_id(self, version: int) -> str: if version == 1: res = [] @@ -1361,6 +1719,14 @@ class ASTTrailingTypeSpecFundamental(ASTTrailingTypeSpec): # the canonical name list is for ID lookup self.canonNames = canonNames + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecFundamental): + return NotImplemented + return self.names == other.names and self.canonNames == other.canonNames + + def __hash__(self) -> int: + return hash((self.names, self.canonNames)) + def _stringify(self, transform: StringifyTransform) -> str: return ' '.join(self.names) @@ -1394,6 +1760,12 @@ class ASTTrailingTypeSpecFundamental(ASTTrailingTypeSpec): class ASTTrailingTypeSpecDecltypeAuto(ASTTrailingTypeSpec): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTTrailingTypeSpecDecltypeAuto) + + def __hash__(self) -> int: + return hash('decltype(auto)') + def _stringify(self, transform: StringifyTransform) -> str: return 'decltype(auto)' @@ -1414,6 +1786,14 @@ class ASTTrailingTypeSpecDecltype(ASTTrailingTypeSpec): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecDecltype): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return 'decltype(' + transform(self.expr) + ')' @@ -1437,6 +1817,18 @@ class ASTTrailingTypeSpecName(ASTTrailingTypeSpec): self.nestedName = nestedName self.placeholderType = placeholderType + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecName): + return NotImplemented + return ( + self.prefix == other.prefix + and self.nestedName == other.nestedName + and self.placeholderType == other.placeholderType + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.nestedName, self.placeholderType)) + @property def name(self) -> ASTNestedName: return self.nestedName @@ -1480,6 +1872,14 @@ class ASTFunctionParameter(ASTBase): self.arg = arg self.ellipsis = ellipsis + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFunctionParameter): + return NotImplemented + return self.arg == other.arg and self.ellipsis == other.ellipsis + + def __hash__(self) -> int: + return hash((self.arg, self.ellipsis)) + def get_id( self, version: int, objectType: str | None = None, symbol: Symbol | None = None, ) -> str: @@ -1512,6 +1912,14 @@ class ASTNoexceptSpec(ASTBase): def __init__(self, expr: ASTExpression | None) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNoexceptSpec): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: if self.expr: return 'noexcept(' + transform(self.expr) + ')' @@ -1543,6 +1951,28 @@ class ASTParametersQualifiers(ASTBase): self.attrs = attrs self.initializer = initializer + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParametersQualifiers): + return NotImplemented + return ( + self.args == other.args + and self.volatile == other.volatile + and self.const == other.const + and self.refQual == other.refQual + and self.exceptionSpec == other.exceptionSpec + and self.trailingReturn == other.trailingReturn + and self.override == other.override + and self.final == other.final + and self.attrs == other.attrs + and self.initializer == other.initializer + ) + + def __hash__(self) -> int: + return hash(( + self.args, self.volatile, self.const, self.refQual, self.exceptionSpec, + self.trailingReturn, self.override, self.final, self.attrs, self.initializer + )) + @property def function_params(self) -> list[ASTFunctionParameter]: return self.args @@ -1681,6 +2111,14 @@ class ASTExplicitSpec(ASTBase): def __init__(self, expr: ASTExpression | None) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTExplicitSpec): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: res = ['explicit'] if self.expr is not None: @@ -1717,6 +2155,40 @@ class ASTDeclSpecsSimple(ASTBase): self.friend = friend self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecsSimple): + return NotImplemented + return ( + self.storage == other.storage + and self.threadLocal == other.threadLocal + and self.inline == other.inline + and self.virtual == other.virtual + and self.explicitSpec == other.explicitSpec + and self.consteval == other.consteval + and self.constexpr == other.constexpr + and self.constinit == other.constinit + and self.volatile == other.volatile + and self.const == other.const + and self.friend == other.friend + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash(( + self.storage, + self.threadLocal, + self.inline, + self.virtual, + self.explicitSpec, + self.consteval, + self.constexpr, + self.constinit, + self.volatile, + self.const, + self.friend, + self.attrs, + )) + def mergeWith(self, other: ASTDeclSpecsSimple) -> ASTDeclSpecsSimple: if not other: return self @@ -1811,6 +2283,24 @@ class ASTDeclSpecs(ASTBase): self.allSpecs = self.leftSpecs.mergeWith(self.rightSpecs) self.trailingTypeSpec = trailing + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecs): + return NotImplemented + return ( + self.outer == other.outer + and self.leftSpecs == other.leftSpecs + and self.rightSpecs == other.rightSpecs + and self.trailingTypeSpec == other.trailingTypeSpec + ) + + def __hash__(self) -> int: + return hash(( + self.outer, + self.leftSpecs, + self.rightSpecs, + self.trailingTypeSpec, + )) + def get_id(self, version: int) -> str: if version == 1: res = [] @@ -1873,6 +2363,14 @@ class ASTArray(ASTBase): def __init__(self, size: ASTExpression) -> None: self.size = size + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTArray): + return NotImplemented + return self.size == other.size + + def __hash__(self) -> int: + return hash(self.size) + def _stringify(self, transform: StringifyTransform) -> str: if self.size: return '[' + transform(self.size) + ']' @@ -1953,6 +2451,18 @@ class ASTDeclaratorNameParamQual(ASTDeclarator): self.arrayOps = arrayOps self.paramQual = paramQual + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameParamQual): + return NotImplemented + return ( + self.declId == other.declId + and self.arrayOps == other.arrayOps + and self.paramQual == other.paramQual + ) + + def __hash__(self) -> int: + return hash((self.declId, self.arrayOps, self.paramQual)) + @property def name(self) -> ASTNestedName: return self.declId @@ -2037,6 +2547,14 @@ class ASTDeclaratorNameBitField(ASTDeclarator): self.declId = declId self.size = size + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameBitField): + return NotImplemented + return self.declId == other.declId and self.size == other.size + + def __hash__(self) -> int: + return hash((self.declId, self.size)) + @property def name(self) -> ASTNestedName: return self.declId @@ -2087,6 +2605,19 @@ class ASTDeclaratorPtr(ASTDeclarator): self.const = const self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorPtr): + return NotImplemented + return ( + self.next == other.next + and self.volatile == other.volatile + and self.const == other.const + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.next, self.volatile, self.const, self.attrs)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2192,6 +2723,14 @@ class ASTDeclaratorRef(ASTDeclarator): self.next = next self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorRef): + return NotImplemented + return self.next == other.next and self.attrs == other.attrs + + def __hash__(self) -> int: + return hash((self.next, self.attrs)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2258,6 +2797,14 @@ class ASTDeclaratorParamPack(ASTDeclarator): assert next self.next = next + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorParamPack): + return NotImplemented + return self.next == other.next + + def __hash__(self) -> int: + return hash(self.next) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2326,6 +2873,19 @@ class ASTDeclaratorMemPtr(ASTDeclarator): self.volatile = volatile self.next = next + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorMemPtr): + return NotImplemented + return ( + self.className == other.className + and self.const == other.const + and self.volatile == other.volatile + and self.next == other.next + ) + + def __hash__(self) -> int: + return hash((self.className, self.const, self.volatile, self.next)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2424,6 +2984,14 @@ class ASTDeclaratorParen(ASTDeclarator): self.next = next # TODO: we assume the name, params, and qualifiers are in inner + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorParen): + return NotImplemented + return self.inner == other.inner and self.next == other.next + + def __hash__(self) -> int: + return hash((self.inner, self.next)) + @property def name(self) -> ASTNestedName: return self.inner.name @@ -2493,6 +3061,14 @@ class ASTPackExpansionExpr(ASTExpression): def __init__(self, expr: ASTExpression | ASTBracedInitList) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPackExpansionExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.expr) + '...' @@ -2510,6 +3086,14 @@ class ASTParenExprList(ASTBaseParenExprList): def __init__(self, exprs: list[ASTExpression | ASTBracedInitList]) -> None: self.exprs = exprs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExprList): + return NotImplemented + return self.exprs == other.exprs + + def __hash__(self) -> int: + return hash(self.exprs) + def get_id(self, version: int) -> str: return "pi%sE" % ''.join(e.get_id(version) for e in self.exprs) @@ -2538,6 +3122,14 @@ class ASTInitializer(ASTBase): self.value = value self.hasAssign = hasAssign + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTInitializer): + return NotImplemented + return self.value == other.value and self.hasAssign == other.hasAssign + + def __hash__(self) -> int: + return hash((self.value, self.hasAssign)) + def _stringify(self, transform: StringifyTransform) -> str: val = transform(self.value) if self.hasAssign: @@ -2562,6 +3154,14 @@ class ASTType(ASTBase): self.declSpecs = declSpecs self.decl = decl + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTType): + return NotImplemented + return self.declSpecs == other.declSpecs and self.decl == other.decl + + def __hash__(self) -> int: + return hash((self.declSpecs, self.decl)) + @property def name(self) -> ASTNestedName: return self.decl.name @@ -2671,6 +3271,14 @@ class ASTTemplateParamConstrainedTypeWithInit(ASTBase): self.type = type self.init = init + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamConstrainedTypeWithInit): + return NotImplemented + return self.type == other.type and self.init == other.init + + def __hash__(self) -> int: + return hash((self.type, self.init)) + @property def name(self) -> ASTNestedName: return self.type.name @@ -2712,6 +3320,14 @@ class ASTTypeWithInit(ASTBase): self.type = type self.init = init + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeWithInit): + return NotImplemented + return self.type == other.type and self.init == other.init + + def __hash__(self) -> int: + return hash((self.type, self.init)) + @property def name(self) -> ASTNestedName: return self.type.name @@ -2749,6 +3365,14 @@ class ASTTypeUsing(ASTBase): self.name = name self.type = type + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeUsing): + return NotImplemented + return self.name == other.name and self.type == other.type + + def __hash__(self) -> int: + return hash((self.name, self.type)) + def get_id(self, version: int, objectType: str | None = None, symbol: Symbol | None = None) -> str: if version == 1: @@ -2785,6 +3409,14 @@ class ASTConcept(ASTBase): self.nestedName = nestedName self.initializer = initializer + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTConcept): + return NotImplemented + return self.nestedName == other.nestedName and self.initializer == other.initializer + + def __hash__(self) -> int: + return hash((self.nestedName, self.initializer)) + @property def name(self) -> ASTNestedName: return self.nestedName @@ -2816,6 +3448,19 @@ class ASTBaseClass(ASTBase): self.virtual = virtual self.pack = pack + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBaseClass): + return NotImplemented + return ( + self.name == other.name + and self.visibility == other.visibility + and self.virtual == other.virtual + and self.pack == other.pack + ) + + def __hash__(self) -> int: + return hash((self.name, self.visibility, self.virtual, self.pack)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.visibility is not None: @@ -2851,6 +3496,19 @@ class ASTClass(ASTBase): self.bases = bases self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTClass): + return NotImplemented + return ( + self.name == other.name + and self.final == other.final + and self.bases == other.bases + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.final, self.bases, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -2899,6 +3557,14 @@ class ASTUnion(ASTBase): self.name = name self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnion): + return NotImplemented + return self.name == other.name and self.attrs == other.attrs + + def __hash__(self) -> int: + return hash((self.name, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: if version == 1: raise NoOldIdError @@ -2929,6 +3595,19 @@ class ASTEnum(ASTBase): self.underlyingType = underlyingType self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnum): + return NotImplemented + return ( + self.name == other.name + and self.scoped == other.scoped + and self.underlyingType == other.underlyingType + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.scoped, self.underlyingType, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: if version == 1: raise NoOldIdError @@ -2971,6 +3650,18 @@ class ASTEnumerator(ASTBase): self.init = init self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnumerator): + return NotImplemented + return ( + self.name == other.name + and self.init == other.init + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.init, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: if version == 1: raise NoOldIdError @@ -3035,6 +3726,19 @@ class ASTTemplateKeyParamPackIdDefault(ASTTemplateParam): self.parameterPack = parameterPack self.default = default + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateKeyParamPackIdDefault): + return NotImplemented + return ( + self.key == other.key + and self.identifier == other.identifier + and self.parameterPack == other.parameterPack + and self.default == other.default + ) + + def __hash__(self) -> int: + return hash((self.key, self.identifier, self.parameterPack, self.default)) + def get_identifier(self) -> ASTIdentifier: return self.identifier @@ -3086,6 +3790,14 @@ class ASTTemplateParamType(ASTTemplateParam): assert data self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamType): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3125,6 +3837,17 @@ class ASTTemplateParamTemplateType(ASTTemplateParam): self.nestedParams = nestedParams self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamTemplateType): + return NotImplemented + return ( + self.nestedParams == other.nestedParams + and self.data == other.data + ) + + def __hash__(self) -> int: + return hash((self.nestedParams, self.data)) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3166,6 +3889,14 @@ class ASTTemplateParamNonType(ASTTemplateParam): self.param = param self.parameterPack = parameterPack + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamNonType): + return NotImplemented + return ( + self.param == other.param + and self.parameterPack == other.parameterPack + ) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3221,6 +3952,14 @@ class ASTTemplateParams(ASTBase): self.params = params self.requiresClause = requiresClause + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParams): + return NotImplemented + return self.params == other.params and self.requiresClause == other.requiresClause + + def __hash__(self) -> int: + return hash((self.params, self.requiresClause)) + def get_id(self, version: int, excludeRequires: bool = False) -> str: assert version >= 2 res = [] @@ -3295,6 +4034,17 @@ class ASTTemplateIntroductionParameter(ASTBase): self.identifier = identifier self.parameterPack = parameterPack + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateIntroductionParameter): + return NotImplemented + return ( + self.identifier == other.identifier + and self.parameterPack == other.parameterPack + ) + + def __hash__(self) -> int: + return hash((self.identifier, self.parameterPack)) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3351,6 +4101,14 @@ class ASTTemplateIntroduction(ASTBase): self.concept = concept self.params = params + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateIntroduction): + return NotImplemented + return self.concept == other.concept and self.params == other.params + + def __hash__(self) -> int: + return hash((self.concept, self.params)) + def get_id(self, version: int) -> str: assert version >= 2 return ''.join([ @@ -3402,6 +4160,14 @@ class ASTTemplateDeclarationPrefix(ASTBase): # templates is None means it's an explicit instantiation of a variable self.templates = templates + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateDeclarationPrefix): + return NotImplemented + return self.templates == other.templates + + def __hash__(self) -> int: + return hash(self.templates) + def get_requires_clause_in_last(self) -> ASTRequiresClause | None: if self.templates is None: return None @@ -3436,6 +4202,14 @@ class ASTRequiresClause(ASTBase): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTRequiresClause): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return 'requires ' + transform(self.expr) @@ -3472,6 +4246,21 @@ class ASTDeclaration(ASTBase): # further changes will be made to this object self._newest_id_cache: str | None = None + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaration): + return NotImplemented + return ( + self.objectType == other.objectType + and self.directiveType == other.directiveType + and self.visibility == other.visibility + and self.templatePrefix == other.templatePrefix + and self.declaration == other.declaration + and self.trailingRequiresClause == other.trailingRequiresClause + and self.semicolon == other.semicolon + and self.symbol == other.symbol + and self.enumeratorScopedSymbol == other.enumeratorScopedSymbol + ) + def clone(self) -> ASTDeclaration: templatePrefixClone = self.templatePrefix.clone() if self.templatePrefix else None trailingRequiresClasueClone = self.trailingRequiresClause.clone() \ @@ -3627,6 +4416,14 @@ class ASTNamespace(ASTBase): self.nestedName = nestedName self.templatePrefix = templatePrefix + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNamespace): + return NotImplemented + return ( + self.nestedName == other.nestedName + and self.templatePrefix == other.templatePrefix + ) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.templatePrefix: |