summaryrefslogtreecommitdiffstats
path: root/tools/glsl_preproc/statement.py
blob: 8641e94aa4462c69b6af712a3b437e5a2526fd74 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import re

from templates import GLSL_BLOCK_TEMPLATE
from variables import VarSet, slugify

# Matches one variable reference inside a literal GLSL line. Three forms:
#   ${type: expr}  -- capture the C expression `expr` (group 'expr'); the
#                     optional type prefix (group 'type') names an entry of
#                     VAR_TYPES. GLSLLine.handle_var() falls back to 'ident'
#                     when the prefix is omitted. `expr` may not contain
#                     '{' or '}'.
#   $name          -- shorthand: capture the C identifier `name` (group 'name')
#   @var           -- a variable declared by the generated C code itself,
#                     e.g. an @for loop counter (group 'var'); printed as %d
# NOTE: the pattern is compiled with re.VERBOSE, so the '#' annotations below
# are regex comments and whitespace inside the pattern is ignored.
VAR_PATTERN = re.compile(flags=re.VERBOSE, pattern=r'''
    # long form ${ ... } syntax
    \${ (?:\s*(?P<type>(?:              # optional type prefix
            ident                       # identifiers (always dynamic)
          | (?:(?:const|dynamic)\s+)?   # optional const/dynamic modifiers
                (?:float|u?int)         # base type
          | swizzle                     # swizzle mask
          | (?:i|u)?vecType             # vector type (for mask)
        )):)?
        (?P<expr>[^{}]+)
      }
|   \$(?P<name>\w+) # reference to captured variable
|   @(?P<var>\w+)   # reference to locally defined var
''')

class FmtSpec(object):
    """Describes how one kind of ${type: expr} variable is captured and printed.

    Attributes:
        ctype:     C type recorded for the captured variable.
        fmtstr:    printf() conversion substituted into the GLSL format string.
        wrap_expr: callable(name, expr) producing the C expression that gets
                   captured (identity by default).
        fmt_expr:  callable applied to the captured variable to produce the
                   printf() argument (identity by default).
    """

    def __init__(self, ctype='ident_t', fmtstr='_%hx',
                 wrap_expr=lambda name, expr: expr,
                 fmt_expr=lambda name: name):
        self.ctype, self.fmtstr = ctype, fmtstr
        self.wrap_expr, self.fmt_expr = wrap_expr, fmt_expr

    @staticmethod
    def wrap_var(type, dynamic=False):
        """Return a wrap_expr that passes `expr` through sh_const_<type>() or,
        when `dynamic` is true, through sh_var_<type>(..., true)."""
        if not dynamic:
            def wrap(name, expr):
                return f'sh_const_{type}(sh, "{name}", {expr})'
        else:
            def wrap(name, expr):
                return f'sh_var_{type}(sh, "{name}", {expr}, true)'
        return wrap

    @staticmethod
    def wrap_fn(fn):
        """Return a fmt_expr that calls the C function `fn` on its argument."""
        def call(name):
            return f'{fn}({name})'
        return call

# Maps each type prefix accepted by VAR_PATTERN to its FmtSpec.
VAR_TYPES = {
    # Plain identifiers are substituted as-is (FmtSpec defaults).
    'ident': FmtSpec(),

    # Bare numeric types: captured through sh_const_<type>().
    **{base: FmtSpec(wrap_expr=FmtSpec.wrap_var(base))
       for base in ('int', 'uint', 'float')},

    # 'const' numeric types: printed straight into the GLSL source text.
    'const int':   FmtSpec(ctype='int',      fmtstr='%d'),
    'const uint':  FmtSpec(ctype='unsigned', fmtstr='uint(%u)'),
    'const float': FmtSpec(ctype='float',    fmtstr='float(%f)'),

    # 'dynamic' numeric types: captured through sh_var_<type>(..., true).
    **{f'dynamic {base}': FmtSpec(wrap_expr=FmtSpec.wrap_var(base, dynamic=True))
       for base in ('int', 'uint', 'float')},

    # Component-mask types: a uint8_t mask rendered via a C helper function.
    'swizzle': FmtSpec(ctype='uint8_t', fmtstr='%s',
                       fmt_expr=FmtSpec.wrap_fn('sh_swizzle')),
    # NOTE(review): ivecType and uvecType also go through sh_float_type(). If
    # that helper only yields float vector names, these presumably want
    # integer counterparts -- confirm against the C helpers.
    **{vec: FmtSpec(ctype='uint8_t', fmtstr='%s',
                    fmt_expr=FmtSpec.wrap_fn('sh_float_type'))
       for vec in ('ivecType', 'uvecType', 'vecType')},
}

def stringify(value, strip):
    """Quote `value` as a C string literal.

    Backslashes and double quotes are escaped. Without `strip`, an escaped
    newline ("\\n") is appended inside the quotes. With `strip`, leading
    whitespace and single-line comments are removed and no newline is added.
    """
    end = '\\n"'
    if strip:
        end = '"'
        # Lazy '.*?' so block comments containing '*' (e.g. '/** x */') are
        # removed too. '//[^\n]*' (not '+') so an empty trailing '//' is also
        # removed: stripped lines carry no newline, so a leftover '//' could
        # comment out whatever text ends up after it.
        value = re.sub(r'(?:/\*.*?\*/|//[^\n]*|^\s*)', '', value)
    return '"' + value.replace('\\', '\\\\').replace('"', '\\"') + end

def commentify(value, strip):
    """Wrap `value` in a /* ... */ comment, or return '' when stripping.

    Any '/*' or '*/' inside `value` is rewritten to '[[' / ']]' (in that
    order) so the value cannot terminate the comment early.
    """
    if not strip:
        safe = value.replace('/*', '[[').replace('*/', ']]')
        return f'/*{safe}*/'
    return ''

# Represents a statement + its enclosed variables
class Statement(object):
    """Base class for one parsed source line and the variables it captures.

    Subclasses implement render(). The parse() placeholder is rebound to
    parse_line() at the end of this module.
    """

    def __init__(self, linenr=0):
        super().__init__()
        self.linenr = linenr
        self.vars = VarSet()

    def add_var(self, ctype, expr, name=None):
        """Register `expr` in this statement's VarSet, tagged with its line
        number, and return whatever VarSet.add_var() returns."""
        linenr = self.linenr
        return self.vars.add_var(ctype, expr, name, linenr)

    def render(self):
        """Return the C source text for this statement (abstract)."""
        raise NotImplementedError

    @staticmethod
    def parse(text_orig, **kwargs):
        """Placeholder; replaced by parse_line() at module load."""
        raise NotImplementedError

# Represents a single line of GLSL
class GLSLLine(Statement):
    """A single line of literal GLSL, possibly containing variable references.

    Two C string literals are built from the line: `rawstr`, the text as-is,
    and `fmtstr`, a printf()-style format string in which every variable
    reference is replaced by a conversion specifier. `refs` collects, in the
    same order, the C expressions supplying those printf() arguments.
    """

    class GLSLVar(object): # variable reference
        def __init__(self, fmt, var):
            self.fmt = fmt
            self.var = var

    # Single-line /* ... */ and // ... comments; removed in strip mode before
    # any variable reference is resolved.
    _COMMENT_RE = re.compile(r'/\*.*?\*/|//.*')

    def __init__(self, text, strip=False, **kwargs):
        super().__init__(**kwargs)
        self.refs = []
        self.strip = strip

        text = text.rstrip()
        if strip:
            # Drop comments *before* resolving variables. stringify() would
            # delete them anyway, but a variable referenced inside a comment
            # would then leave an entry in self.refs whose conversion
            # specifier is gone, misaligning every following printf() argument.
            text = self._COMMENT_RE.sub('', text)

        # produce two versions of line, one for printf() and one for append()
        self.rawstr = stringify(text, strip)
        self.fmtstr = stringify(self._to_format(text), strip)

    def _to_format(self, text):
        """Return `text` as a printf() format string with variables resolved.

        Only the literal GLSL between variable references has '%' escaped to
        '%%'. The C expressions inside ${...} are passed through untouched, so
        an operator such as '%' in '${int: i % 4}' stays '%' in the generated
        C code instead of becoming the invalid '%%'.
        """
        pieces = []
        pos = 0
        for match in VAR_PATTERN.finditer(text):
            pieces.append(text[pos:match.start()].replace('%', '%%'))
            pieces.append(self.handle_var(match))
            pos = match.end()
        pieces.append(text[pos:].replace('%', '%%'))
        return ''.join(pieces)

    def handle_var(self, match):
        """Resolve one VAR_PATTERN match and return its format-string text.

        Appends the corresponding printf() argument to self.refs; captured
        ($name / ${...}) variables are additionally registered via add_var().
        """
        # local @var: printed as an integer
        if match['var']:
            self.refs.append(match['var'])
            return '%d'

        # captured $var / ${type: expr}
        name = match['name']
        expr = match['expr'] or name
        name = name or slugify(expr)

        # VAR_PATTERN accepts any whitespace run between the const/dynamic
        # modifier and the base type; normalize it so the VAR_TYPES lookup
        # succeeds. No prefix means a plain identifier.
        vartype = ' '.join(match['type'].split()) if match['type'] else 'ident'
        fmt = VAR_TYPES[vartype]
        self.refs.append(fmt.fmt_expr(self.add_var(
            ctype = fmt.ctype,
            expr  = fmt.wrap_expr(name, expr),
            name  = name,
        )))

        # identifiers are preceded by a comment carrying their name
        # (commentify() returns '' in strip mode)
        if fmt.ctype == 'ident_t':
            return commentify(name, self.strip) + fmt.fmtstr
        return fmt.fmtstr

# Represents an entire GLSL block
class GLSLBlock(Statement):
    """A run of consecutive GLSLLine objects rendered together as one unit.

    The block accumulates the lines, their printf() argument refs (in line
    order) and the union of their captured variables.
    """

    def __init__(self, line):
        super().__init__(linenr=line.linenr)
        self.lines, self.refs = [], []
        self.append(line)

    def append(self, line):
        """Add `line` (which must be a GLSLLine) to the end of this block."""
        assert isinstance(line, GLSLLine)
        self.lines.append(line)
        self.refs.extend(line.refs)
        self.vars.merge(line.vars)

    def render(self):
        """Render the whole block through GLSL_BLOCK_TEMPLATE."""
        return GLSL_BLOCK_TEMPLATE.render(block=self)

# Represents a statement which can either take a single line or a block
class BlockStart(Statement):
    """Base for directives that govern either one following line or a block.

    `multiline` is truthy when the directive ended in '{'. In that case the
    rendered C header also opens a brace.
    """

    def __init__(self, multiline=False, **kwargs):
        super().__init__(**kwargs)
        self.multiline = multiline

    def add_brace(self, text):
        """Return `text`, suffixed with ' {' when this is a multi-line block."""
        return f'{text} {{' if self.multiline else text

# Represents an @if
class IfCond(BlockStart):
    """An @if directive.

    With `inner` set (the '@if @(cond)' form), `cond` is emitted verbatim.
    Otherwise it is captured as a 'bool' variable via add_var().
    """

    def __init__(self, cond, inner=False, **kwargs):
        super().__init__(**kwargs)
        if inner:
            self.cond = cond
        else:
            self.cond = self.add_var('bool', expr=cond)

    def render(self):
        header = f'if ({self.cond})'
        return self.add_brace(header)

# Represents an @else
class Else(BlockStart):
    """An @else directive; `closing` is set for the '@} else' form, which
    first closes the preceding block."""

    def __init__(self, closing, **kwargs):
        super().__init__(**kwargs)
        self.closing = closing

    def render(self):
        if self.closing:
            keyword = '} else'
        else:
            keyword = 'else'
        return self.add_brace(keyword)

# Represents a normal (integer) @for loop, or an (unsigned 8-bit) bitmask loop
class ForLoop(BlockStart):
    """An @for directive.

    With op '<' or '<=' it renders a counting loop from 0. With op ':' it
    iterates over the set bits of a uint8_t mask: each pass assigns the index
    of the lowest set bit (__builtin_ctz) to the loop variable and clears that
    bit afterwards.
    """

    def __init__(self, var, op, bound, **kwargs):
        super().__init__(**kwargs)
        self.comps = (op == ':')
        bound_type = 'uint8_t' if self.comps else 'int'
        self.bound = self.add_var(bound_type, expr=bound)
        self.var = var
        self.op = op

    def render(self):
        v, b = self.var, self.bound
        if not self.comps:
            clauses = (f'int {v} = 0', f'{v} {self.op} {b}', f'{v}++')
        else:
            clauses = (
                f'uint8_t _mask = {b}, {v}',
                f'_mask && ({v} = __builtin_ctz(_mask), 1)',
                f'_mask &= ~(1u << {v})',
            )
        return self.add_brace('for (' + '; '.join(clauses) + ')')

# Represents a @switch block
class Switch(Statement):
    """An @switch directive. The switch expression is captured as an
    'unsigned' variable, and the rendered header always opens a brace."""

    def __init__(self, expr, **kwargs):
        super().__init__(**kwargs)
        self.expr = self.add_var('unsigned', expr=expr)

    def render(self):
        return 'switch ({}) {{'.format(self.expr)

# Represents a @case label
class Case(Statement):
    """An @case label; `label` is emitted verbatim."""

    def __init__(self, label, **kwargs):
        super().__init__(**kwargs)
        self.label = label

    def render(self):
        return 'case {}:'.format(self.label)

# Represents a @default line
class Default(Statement):
    """An @default label inside an @switch block."""

    def render(self):
        label = 'default:'
        return label

# Represents a @break line
class Break(Statement):
    """An @break directive, rendered as a C break statement."""

    def render(self):
        stmt = 'break;'
        return stmt

# Represents a single closing brace
class EndBrace(Statement):
    """A lone '@}' line, closing the innermost open block."""

    def render(self):
        brace = '}'
        return brace

# Regex-based statement parser: one pattern per @-directive. parse_line()
# applies them with re.match() to the already-stripped line.
PATTERN_IF  = re.compile(flags=re.VERBOSE, pattern=r'''
@\s*if\s*                       # '@if'
(?P<inner>@)?                   # optional leading @
\((?P<cond>.+)\)\s*             # (condition)
(?P<multiline>{)?\s*            # optional trailing {
$''')

PATTERN_ELSE = re.compile(flags=re.VERBOSE, pattern=r'''
@\s*(?P<closing>})?\s*          # optional leading }
else\s*                         # 'else'
(?P<multiline>{)?\s*            # optional trailing {
$''')

PATTERN_FOR = re.compile(flags=re.VERBOSE, pattern=r'''
@\s*for\s+\(                    # '@for' (
(?P<var>\w+)\s*                 # loop variable name
(?P<op>(?:\<=?|:))(?=[\w\s])\s* # '<', '<=' or ':', followed by \s or \w
(?P<bound>[^\s].*?)\s*          # loop boundary expression (lazy, so that
                                # trailing whitespace is not captured)
\)\s*(?P<multiline>{)?\s*       # ) and optional trailing {
$''')

PATTERN_SWITCH = re.compile(flags=re.VERBOSE, pattern=r'''
@\s*switch\s*                   # '@switch'
\((?P<expr>.+)\)\s*{            # switch expression
$''')

PATTERN_CASE = re.compile(flags=re.VERBOSE, pattern=r'''
@\s*case\s*                     # '@case'
(?P<label>[^:]+):?              # case label, optionally followed by :
$''')

# Single-token directives, compiled like the patterns above for consistency.
PATTERN_BREAK   = re.compile(r'@\s*break;?\s*$')
PATTERN_DEFAULT = re.compile(r'@\s*default:?\s*$')
PATTERN_BRACE   = re.compile(r'@\s*}\s*$')

# Directive parsers, tried in insertion order by parse_line(). Each value
# takes the regex match plus extra keyword arguments (forwarded to the
# Statement constructor) and returns the constructed Statement.
PARSERS = {
    PATTERN_IF: lambda m, **kw: IfCond(
        m['cond'], inner=m['inner'], multiline=m['multiline'], **kw),
    PATTERN_ELSE: lambda m, **kw: Else(
        closing=m['closing'], multiline=m['multiline'], **kw),
    PATTERN_FOR: lambda m, **kw: ForLoop(
        m['var'], m['op'], m['bound'], multiline=m['multiline'], **kw),
    PATTERN_SWITCH: lambda m, **kw: Switch(m['expr'], **kw),
    PATTERN_CASE: lambda m, **kw: Case(m['label'], **kw),
    PATTERN_BREAK: lambda _m, **kw: Break(**kw),
    PATTERN_DEFAULT: lambda _m, **kw: Default(**kw),
    PATTERN_BRACE: lambda _m, **kw: EndBrace(**kw),
}

def parse_line(text_orig, strip, **kwargs):
    """Parse one source line into a Statement.

    Returns None for blank lines. Lines starting with '@' are matched against
    PARSERS in order; the first match builds the statement, and an
    unrecognized directive raises SyntaxError. Any other line becomes a
    GLSLLine built from the unstripped text. Extra keyword arguments are
    forwarded to the Statement constructor.
    """
    text = text_orig.strip()
    if not text:
        return None

    if not text.startswith('@'):
        return GLSLLine(text_orig, strip, **kwargs)

    for pattern, build in PARSERS.items():
        match = re.match(pattern, text)
        if match:
            return build(match, **kwargs)

    raise SyntaxError('Syntax error in directive: ' + text)

# Expose the parser as Statement.parse (replacing the placeholder).
Statement.parse = parse_line