summaryrefslogtreecommitdiffstats
path: root/gfx/skia/skia/src/sksl/codegen/SkSLMetalCodeGenerator.h
blob: 621b99edb411f45fa792092f2bea02efae0129b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
/*
 * Copyright 2016 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */

#ifndef SKSL_METALCODEGENERATOR
#define SKSL_METALCODEGENERATOR

#include "include/private/SkSLDefines.h"
#include "include/private/base/SkTArray.h"
#include "src/core/SkTHash.h"
#include "src/sksl/SkSLStringStream.h"
#include "src/sksl/codegen/SkSLCodeGenerator.h"
#include "src/sksl/ir/SkSLType.h"

#include <cstdint>
#include <initializer_list>
#include <string>
#include <string_view>
#include <vector>

namespace SkSL {

// Forward declarations for the SkSL types that are named in the declarations below. This header
// only declares functions and pointer/reference members that mention these types, so their
// complete definitions are not needed here.
class AnyConstructor;
class BinaryExpression;
class Block;
class ConstructorArrayCast;
class ConstructorCompound;
class ConstructorMatrixResize;
class Context;
class DoStatement;
class Expression;
class ExpressionStatement;
class Extension;
class FieldAccess;
class ForStatement;
class FunctionCall;
class FunctionDeclaration;
class FunctionDefinition;
class FunctionPrototype;
class IfStatement;
class IndexExpression;
class InterfaceBlock;
class Literal;
class Operator;
class OutputStream;
class Position;
class PostfixExpression;
class PrefixExpression;
class ProgramElement;
class ReturnStatement;
class Statement;
class StructDefinition;
class SwitchStatement;
class Swizzle;
class TernaryExpression;
class VarDeclaration;
class Variable;
class VariableReference;
// Opaque enum declarations: the fixed underlying type is what allows these to be declared
// without their enumerator lists.
enum class OperatorPrecedence : uint8_t;
enum IntrinsicKind : int8_t;
struct Layout;
struct Modifiers;
struct Program;

/**
 * Converts a Program into Metal code.
 */
/**
 * Converts a Program into Metal code.
 *
 * Derives from CodeGenerator and overrides generateCode(), which is the only public entry point
 * besides the constructor. Everything else is protected and is only declared here; the
 * definitions live outside this header. The section comments below group declarations by name
 * and signature only; consult the implementation for exact semantics.
 */
class MetalCodeGenerator : public CodeGenerator {
public:
    /**
     * Forwards `context`, `program` and `out` to the CodeGenerator base class, seeds
     * fReservedWords with a fixed list of identifiers, and sets the line ending to "\n".
     */
    MetalCodeGenerator(const Context* context, const Program* program, OutputStream* out)
            : INHERITED(context, program, out)
            , fReservedWords({"atan2", "rsqrt", "rint", "dfdx", "dfdy", "vertex", "fragment"})
            , fLineEnding("\n") {}

    /** Overrides CodeGenerator::generateCode(). Returns a bool status. */
    bool generateCode() override;

protected:
    using Precedence = OperatorPrecedence;

    // Bitmask type for function requirements. Each constant below occupies a distinct single
    // bit, so several requirements can be combined in one Requirements value.
    using Requirements = int;
    inline static constexpr Requirements kNo_Requirements          = 0;
    inline static constexpr Requirements kInputs_Requirement       = 1 << 0;
    inline static constexpr Requirements kOutputs_Requirement      = 1 << 1;
    inline static constexpr Requirements kUniforms_Requirement     = 1 << 2;
    inline static constexpr Requirements kGlobals_Requirement      = 1 << 3;
    inline static constexpr Requirements kFragCoord_Requirement    = 1 << 4;
    inline static constexpr Requirements kThreadgroups_Requirement = 1 << 5;

    // Visitor entry points. The visitor classes are only forward-declared in this header.
    class GlobalStructVisitor;
    void visitGlobalStruct(GlobalStructVisitor* visitor);

    class ThreadgroupStructVisitor;
    void visitThreadgroupStruct(ThreadgroupStructVisitor* visitor);

    // ---- Text output ----

    void write(std::string_view s);

    // `s` defaults to an empty string_view when omitted.
    void writeLine(std::string_view s = std::string_view());

    void finishLine();

    // ---- Program-level sections ----

    void writeHeader();

    void writeSampler2DPolyfill();

    void writeUniformStruct();

    void writeInputStruct();

    void writeOutputStruct();

    void writeInterfaceBlocks();

    void writeStructDefinitions();

    void writeConstantVariables();

    // `parentIntf` is optional and defaults to nullptr.
    void writeFields(const std::vector<Type::Field>& fields, Position pos,
                     const InterfaceBlock* parentIntf = nullptr);

    // Both take the type and an `isPacked` flag and return an int; neither modifies the
    // generator (const member functions).
    int size(const Type* type, bool isPacked) const;

    int alignment(const Type* type, bool isPacked) const;

    void writeGlobalStruct();

    void writeGlobalInit();

    void writeThreadgroupStruct();

    void writeThreadgroupInit();

    void writePrecisionModifier();

    // ---- Types, declarations and functions ----

    std::string typeName(const Type& type);

    void writeStructDefinition(const StructDefinition& s);

    void writeType(const Type& type);

    void writeExtension(const Extension& ext);

    void writeInterfaceBlock(const InterfaceBlock& intf);

    // In the next two methods `separator` is a reference to the caller's pointer, so the callee
    // is able to replace it.
    void writeFunctionRequirementParams(const FunctionDeclaration& f,
                                        const char*& separator);

    void writeFunctionRequirementArgs(const FunctionDeclaration& f, const char*& separator);

    // NOTE(review): returns bool; the meaning of the result is not visible from this header.
    bool writeFunctionDeclaration(const FunctionDeclaration& f);

    void writeFunction(const FunctionDefinition& f);

    void writeFunctionPrototype(const FunctionPrototype& f);

    void writeLayout(const Layout& layout);

    void writeModifiers(const Modifiers& modifiers);

    void writeVarInitializer(const Variable& var, const Expression& value);

    void writeName(std::string_view name);

    void writeVarDeclaration(const VarDeclaration& decl);

    // ---- Expressions ----

    void writeFragCoord();

    void writeVariableReference(const VariableReference& ref);

    void writeExpression(const Expression& expr, Precedence parentPrecedence);

    // Note: takes both expressions by non-const reference.
    void writeMinAbsHack(Expression& absExpr, Expression& otherExpr);

    // The get* methods below return a std::string.
    // NOTE(review): presumably the name of a generated helper/variable; confirm in the .cpp.
    std::string getOutParamHelper(const FunctionCall& c,
                                  const ExpressionArray& arguments,
                                  const SkTArray<VariableReference*>& outVars);

    std::string getInversePolyfill(const ExpressionArray& arguments);

    std::string getBitcastIntrinsic(const Type& outType);

    std::string getTempVariable(const Type& varType);

    void writeFunctionCall(const FunctionCall& c);

    bool matrixConstructHelperIsNeeded(const ConstructorCompound& c);
    std::string getMatrixConstructHelper(const AnyConstructor& c);
    void assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns);
    void assembleMatrixFromExpressions(const AnyConstructor& ctor, int rows, int columns);

    void writeMatrixCompMult();

    void writeOuterProduct();

    void writeMatrixTimesEqualHelper(const Type& left, const Type& right, const Type& result);

    void writeMatrixDivisionHelpers(const Type& type);

    void writeMatrixEqualityHelpers(const Type& left, const Type& right);

    std::string getVectorFromMat2x2ConstructorHelper(const Type& matrixType);

    void writeArrayEqualityHelpers(const Type& type);

    void writeStructEqualityHelpers(const Type& type);

    void writeEqualityHelpers(const Type& leftType, const Type& rightType);

    void writeArgumentList(const ExpressionArray& arguments);

    void writeSimpleIntrinsic(const FunctionCall& c);

    // NOTE(review): returns bool; the meaning of the result is not visible from this header.
    bool writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind);

    void writeConstructorCompound(const ConstructorCompound& c, Precedence parentPrecedence);

    void writeConstructorCompoundVector(const ConstructorCompound& c, Precedence parentPrecedence);

    void writeConstructorCompoundMatrix(const ConstructorCompound& c, Precedence parentPrecedence);

    void writeConstructorMatrixResize(const ConstructorMatrixResize& c,
                                      Precedence parentPrecedence);

    void writeAnyConstructor(const AnyConstructor& c,
                             const char* leftBracket,
                             const char* rightBracket,
                             Precedence parentPrecedence);

    void writeCastConstructor(const AnyConstructor& c,
                              const char* leftBracket,
                              const char* rightBracket,
                              Precedence parentPrecedence);

    void writeConstructorArrayCast(const ConstructorArrayCast& c, Precedence parentPrecedence);

    void writeFieldAccess(const FieldAccess& f);

    void writeSwizzle(const Swizzle& swizzle);

    // Splats a scalar expression across a matrix of arbitrary size.
    void writeNumberAsMatrix(const Expression& expr, const Type& matrixType);

    void writeBinaryExpressionElement(const Expression& expr,
                                      Operator op,
                                      const Expression& other,
                                      Precedence precedence);

    void writeBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);

    void writeTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);

    void writeIndexExpression(const IndexExpression& expr);

    void writePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);

    void writePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);

    void writeLiteral(const Literal& f);

    // ---- Statements ----

    void writeStatement(const Statement& s);

    void writeStatements(const StatementArray& statements);

    void writeBlock(const Block& b);

    void writeIfStatement(const IfStatement& stmt);

    void writeForStatement(const ForStatement& f);

    void writeDoStatement(const DoStatement& d);

    void writeExpressionStatement(const ExpressionStatement& s);

    void writeSwitchStatement(const SwitchStatement& s);

    void writeReturnStatementFromMain();

    void writeReturnStatement(const ReturnStatement& r);

    void writeProgramElement(const ProgramElement& e);

    // ---- Requirements analysis ----

    // Both overloads return a Requirements bitmask (see the k*_Requirement constants above).
    Requirements requirements(const FunctionDeclaration& f);

    Requirements requirements(const Statement* s);

    // For compute shader main functions, writes and initializes the _in and _out structs (the
    // instances, not the types themselves)
    void writeComputeMainInputs();

    int getUniformBinding(const Modifiers& m);

    int getUniformSet(const Modifiers& m);

    // ---- State ----

    // Populated by the constructor with a fixed set of identifiers.
    // NOTE(review): stores string_views, so the referenced character data must outlive this
    // set; the constructor only inserts string literals.
    SkTHashSet<std::string_view> fReservedWords;
    SkTHashMap<const Type::Field*, const InterfaceBlock*> fInterfaceBlockMap;
    SkTHashMap<const InterfaceBlock*, std::string_view> fInterfaceBlockNameMap;
    int fAnonInterfaceCount = 0;
    int fPaddingCount = 0;
    // Set to "\n" by the constructor.
    const char* fLineEnding;
    std::string fFunctionHeader;
    StringStream fExtraFunctions;
    StringStream fExtraFunctionPrototypes;
    int fVarCount = 0;
    int fIndentation = 0;
    bool fAtLineStart = false;
    // true if we have run into usages of dFdx / dFdy
    bool fFoundDerivatives = false;
    // Maps a function declaration to a Requirements bitmask.
    SkTHashMap<const FunctionDeclaration*, Requirements> fRequirements;
    SkTHashSet<std::string> fHelpers;
    // NOTE(review): -1 is presumably a "not yet assigned" sentinel; confirm in the .cpp.
    int fUniformBuffer = -1;
    std::string fRTFlipName;
    const FunctionDeclaration* fCurrentFunction = nullptr;
    int fSwizzleHelperCount = 0;
    bool fIgnoreVariableReferenceModifiers = false;
    static constexpr char kTextureSuffix[] = "_Tex";
    static constexpr char kSamplerSuffix[] = "_Smplr";

    // Workaround/polyfill flags
    bool fWrittenInverse2 = false;
    bool fWrittenInverse3 = false;
    bool fWrittenInverse4 = false;
    bool fWrittenMatrixCompMult = false;
    bool fWrittenOuterProduct = false;

    using INHERITED = CodeGenerator;
};

}  // namespace SkSL

#endif