summaryrefslogtreecommitdiffstats
path: root/markdown_it/ruler.py
blob: 11b937a0d8cab3386996deb90b93d3272097e745 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
class Ruler

Helper class, used by [[MarkdownIt#core]], [[MarkdownIt#block]] and
[[MarkdownIt#inline]] to manage sequences of functions (rules):

- keep rules in defined order
- assign the name to each rule
- enable/disable rules
- add/replace rules
- allow assign rules to additional named chains (in the same)
- caching lists of active rules

You will not need use this class directly until write plugins. For simple
rules control use [[MarkdownIt.disable]], [[MarkdownIt.enable]] and
[[MarkdownIt.use]].
"""
from __future__ import annotations

from collections.abc import Callable, Iterable, MutableMapping
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from markdown_it._compat import DATACLASS_KWARGS

if TYPE_CHECKING:
    from markdown_it import MarkdownIt


class StateBase:
    srcCharCode: tuple[int, ...]

    def __init__(self, src: str, md: MarkdownIt, env: MutableMapping):
        self.src = src
        self.env = env
        self.md = md

    @property
    def src(self) -> str:
        return self._src

    @src.setter
    def src(self, value: str) -> None:
        self._src = value
        self.srcCharCode = tuple(ord(c) for c in self.src)


# The first positional arg is always a subtype of `StateBase`. Other
# arguments may or may not exist, based on the rule's type (block,
# core, inline). Return type is either `None` or `bool` based on the
# rule's type.
RuleFunc = Callable


@dataclass(**DATACLASS_KWARGS)
class Rule:
    name: str
    enabled: bool
    fn: RuleFunc = field(repr=False)
    alt: list[str]


class Ruler:
    def __init__(self):
        # List of added rules.
        self.__rules__: list[Rule] = []
        # Cached rule chains.
        # First level - chain name, '' for default.
        # Second level - diginal anchor for fast filtering by charcodes.
        self.__cache__: dict[str, list[RuleFunc]] | None = None

    def __find__(self, name: str) -> int:
        """Find rule index by name"""
        for i, rule in enumerate(self.__rules__):
            if rule.name == name:
                return i
        return -1

    def __compile__(self) -> None:
        """Build rules lookup cache"""
        chains = {""}
        # collect unique names
        for rule in self.__rules__:
            if not rule.enabled:
                continue
            for name in rule.alt:
                chains.add(name)
        self.__cache__ = {}
        for chain in chains:
            self.__cache__[chain] = []
            for rule in self.__rules__:
                if not rule.enabled:
                    continue
                if chain and (chain not in rule.alt):
                    continue
                self.__cache__[chain].append(rule.fn)

    def at(self, ruleName: str, fn: RuleFunc, options=None):
        """Replace rule by name with new function & options.

        :param ruleName: rule name to replace.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(ruleName)
        options = options or {}
        if index == -1:
            raise KeyError(f"Parser rule not found: {ruleName}")
        self.__rules__[index].fn = fn
        self.__rules__[index].alt = options.get("alt", [])
        self.__cache__ = None

    def before(self, beforeName: str, ruleName: str, fn: RuleFunc, options=None):
        """Add new rule to chain before one with given name.

        :param beforeName: new rule will be added before this one.
        :param ruleName: new rule will be added before this one.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(beforeName)
        options = options or {}
        if index == -1:
            raise KeyError(f"Parser rule not found: {beforeName}")
        self.__rules__.insert(index, Rule(ruleName, True, fn, options.get("alt", [])))
        self.__cache__ = None

    def after(self, afterName: str, ruleName: str, fn: RuleFunc, options=None):
        """Add new rule to chain after one with given name.

        :param afterName: new rule will be added after this one.
        :param ruleName: new rule will be added after this one.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(afterName)
        options = options or {}
        if index == -1:
            raise KeyError(f"Parser rule not found: {afterName}")
        self.__rules__.insert(
            index + 1, Rule(ruleName, True, fn, options.get("alt", []))
        )
        self.__cache__ = None

    def push(self, ruleName: str, fn: RuleFunc, options=None):
        """Push new rule to the end of chain.

        :param ruleName: new rule will be added to the end of chain.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).

        """
        self.__rules__.append(Rule(ruleName, True, fn, (options or {}).get("alt", [])))
        self.__cache__ = None

    def enable(self, names: str | Iterable[str], ignoreInvalid: bool = False):
        """Enable rules with given names.

        :param names: name or list of rule names to enable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        if isinstance(names, str):
            names = [names]
        result = []
        for name in names:
            idx = self.__find__(name)
            if (idx < 0) and ignoreInvalid:
                continue
            if (idx < 0) and not ignoreInvalid:
                raise KeyError(f"Rules manager: invalid rule name {name}")
            self.__rules__[idx].enabled = True
            result.append(name)
        self.__cache__ = None
        return result

    def enableOnly(self, names: str | Iterable[str], ignoreInvalid: bool = False):
        """Enable rules with given names, and disable everything else.

        :param names: name or list of rule names to enable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        if isinstance(names, str):
            names = [names]
        for rule in self.__rules__:
            rule.enabled = False
        self.enable(names, ignoreInvalid)

    def disable(self, names: str | Iterable[str], ignoreInvalid: bool = False):
        """Disable rules with given names.

        :param names: name or list of rule names to enable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        if isinstance(names, str):
            names = [names]
        result = []
        for name in names:
            idx = self.__find__(name)
            if (idx < 0) and ignoreInvalid:
                continue
            if (idx < 0) and not ignoreInvalid:
                raise KeyError(f"Rules manager: invalid rule name {name}")
            self.__rules__[idx].enabled = False
            result.append(name)
        self.__cache__ = None
        return result

    def getRules(self, chainName: str) -> list[RuleFunc]:
        """Return array of active functions (rules) for given chain name.
        It analyzes rules configuration, compiles caches if not exists and returns result.

        Default chain name is `''` (empty string). It can't be skipped.
        That's done intentionally, to keep signature monomorphic for high speed.

        """
        if self.__cache__ is None:
            self.__compile__()
            assert self.__cache__ is not None
        # Chain can be empty, if rules disabled. But we still have to return Array.
        return self.__cache__.get(chainName, []) or []

    def get_all_rules(self) -> list[str]:
        """Return all available rule names."""
        return [r.name for r in self.__rules__]

    def get_active_rules(self) -> list[str]:
        """Return the active rule names."""
        return [r.name for r in self.__rules__ if r.enabled]