# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

# This script generates jit/MIROpsGenerated.h from MIROps.yaml. The generated
# header contains the list of MIR opcodes and boilerplate class definitions
# for MIR ops.

import buildconfig
import six
import yaml
from mozbuild.preprocessor import Preprocessor

# %-style template for every generated header; generate_header() fills in the
# "includeguard" and "contents" keys.
HEADER_TEMPLATE = """\
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef %(includeguard)s
#define %(includeguard)s

/* This file is generated by jit/GenerateMIRFiles.py. Do not edit! */

%(contents)s

#endif // %(includeguard)s
"""


def generate_header(c_out, includeguard, contents):
    """Write *contents* to the text stream *c_out* wrapped in boilerplate.

    The boilerplate (HEADER_TEMPLATE) adds the MPL license banner, a
    do-not-edit notice and an #ifndef/#define/#endif include guard whose
    macro name is *includeguard*.
    """
    substitutions = dict(includeguard=includeguard, contents=contents)
    rendered = HEADER_TEMPLATE % substitutions
    c_out.write(rendered)


def load_yaml(yaml_path):
    """Preprocess the YAML file at *yaml_path* and return the parsed data.

    The file is first run through mozbuild's Preprocessor, with the build's
    ALLDEFINES loaded into its context and the "substitution" filter
    enabled, so that the YAML can use directives such as
    ``#ifdef JS_SIMULATOR``. The preprocessed text is then parsed with
    ``yaml.safe_load``.
    """
    preprocessor = Preprocessor()
    preprocessor.context.update(buildconfig.defines["ALLDEFINES"])
    # Collect the preprocessed output in memory rather than writing a file.
    preprocessor.out = six.StringIO()
    preprocessor.do_filter("substitution")
    preprocessor.do_include(yaml_path)
    preprocessed = preprocessor.out.getvalue()
    return yaml.safe_load(preprocessed)


# Maps a MIR operand type name, as written in MIROps.yaml, to the C++
# type-policy template used for an operand of that type.
type_policies = {
    "Object": "ObjectPolicy",
    "Value": "BoxPolicy",
    "Int32": "UnboxedInt32Policy",
    "BigInt": "BigIntPolicy",
    "Boolean": "BooleanPolicy",
    "Double": "DoublePolicy",
    "String": "StringPolicy",
    "Symbol": "SymbolPolicy",
}


def decide_type_policy(types, no_type_policy):
    """Return the type-policy base-class specifier for a MIR op.

    Args:
        types: Sequence of MIR operand type names, one per operand in operand
            order; each must be a key of ``type_policies``.
        no_type_policy: If truthy, ``types`` is ignored and NoTypePolicy is
            used.

    Returns:
        A C++ base-specifier string, e.g. ``public ObjectPolicy<0>::Data``
        for a single operand, or
        ``public MixPolicy<ObjectPolicy<0>, UnboxedInt32Policy<1>>::Data``
        for several operands.

    Raises:
        KeyError: if an operand type has no entry in ``type_policies``. The
            message names the type so the YAML entry or the table can be
            fixed.
    """
    if no_type_policy:
        return "public NoTypePolicy::Data"

    # Fail early with a descriptive message instead of a bare KeyError
    # carrying only the type name.
    for mir_type in types:
        if mir_type not in type_policies:
            raise KeyError(
                "no type policy for MIR operand type {!r}; add it to "
                "type_policies in GenerateMIRFiles.py".format(mir_type)
            )

    if len(types) == 1:
        return "public {}<0>::Data".format(type_policies[types[0]])

    # Each policy template is parameterized by the index of the operand it
    # applies to.
    mixed_type_policies = [
        "{}<{}>".format(type_policies[mir_type], operand_index)
        for operand_index, mir_type in enumerate(types)
    ]
    return "public MixPolicy<{}>::Data".format(", ".join(mixed_type_policies))


# Base class for a generated MIR op class, indexed by the op's operand count:
# gen_mir_class uses mir_base_class[len(operands)], so the order of this list
# is significant.
mir_base_class = [
    "MNullaryInstruction",  # 0 operands
    "MUnaryInstruction",  # 1 operand
    "MBinaryInstruction",  # 2 operands
    "MTernaryInstruction",  # 3 operands
    "MQuaternaryInstruction",  # 4 operands
]


# C++ argument types that are pointers to GC things. gen_mir_class stores
# arguments of these types in a CompilerGCPointer<> member. Every other
# argument type gets a static_assert (gen_non_gc_pointer_type_assertions)
# that it does not derive from gc::Cell, so a GC type missing from this list
# fails to compile rather than being stored unwrapped. Matching is by exact
# string, so entries must be spelled as they appear in MIROps.yaml.
gc_pointer_types = [
    "JSObject*",
    "NativeObject*",
    "JSFunction*",
    "BaseScript*",
    "PropertyName*",
    "Shape*",
    "GetterSetter*",
    "JSAtom*",
    "ClassBodyScope*",
    "VarScope*",
    "NamedLambdaObject*",
    "RegExpObject*",
    "JSScript*",
    "LexicalScope*",
]


def gen_mir_class(
    name,
    operands,
    arguments,
    no_type_policy,
    result,
    guard,
    movable,
    folds_to,
    congruent_to,
    alias_set,
    might_alias,
    possibly_calls,
    compute_range,
    can_recover,
    clone,
    can_consume_float32,
):
    """Generate the C++ class definition for a single MIR opcode.

    Every emitted line ends with a backslash continuation because the result
    is spliced into the MIR_OPCODE_CLASS_GENERATED macro.

    For example, an op GuardIndexIsValidUpdateOrAdd with operands
    ``{object: Object, index: Int32}``, ``guard``, ``movable``,
    ``result_type: Int32``, ``alias_set: none`` and
    ``congruent_to: if_operands_equal`` produces (reformatted, line
    continuations omitted):

      class MGuardIndexIsValidUpdateOrAdd
          : public MBinaryInstruction,
            public MixPolicy<ObjectPolicy<0>, UnboxedInt32Policy<1>>::Data {
        explicit MGuardIndexIsValidUpdateOrAdd(MDefinition* object,
                                               MDefinition* index)
            : MBinaryInstruction(classOpcode, object, index) {
          setGuard();
          setMovable();
          setResultType(MIRType::Int32);
        }
       public:
        INSTRUCTION_HEADER(GuardIndexIsValidUpdateOrAdd)
        TRIVIAL_NEW_WRAPPERS
        NAMED_OPERANDS((0, object), (1, index))
        AliasSet getAliasSet() const override { return AliasSet::None(); }
        bool congruentTo(const MDefinition* ins) const override {
          return congruentIfOperandsEqual(ins); }
      };

    Args:
        name: Opcode name without the ``M`` prefix.
        operands: Mapping of operand name to MIR type name (a key of
            ``type_policies``) in operand order, or None. At most four
            operands are supported.
        arguments: Mapping of extra constructor argument name to its C++
            type, or None. Each is stored in a private ``<name>_`` member
            with a public ``<name>()`` accessor; types listed in
            ``gc_pointer_types`` are stored as ``CompilerGCPointer<T>``.
        no_type_policy: If truthy, use NoTypePolicy instead of a policy
            derived from the operand types.
        result: MIRType enumerator name passed to setResultType, or None.
        guard: Emit setGuard() in the constructor when truthy.
        movable: Emit setMovable() in the constructor when truthy.
        folds_to: When truthy, declare a foldsTo override.
        congruent_to: "custom" declares congruentTo; "if_operands_equal"
            defines it via congruentIfOperandsEqual.
        alias_set: "custom" declares getAliasSet; "none" defines it to
            return AliasSet::None().
        might_alias: When truthy, declare a mightAlias override.
        possibly_calls: "custom" declares possiblyCalls; any other truthy
            value defines it to return true.
        compute_range: When truthy, declare a computeRange override.
        can_recover: When truthy, declare writeRecoverData; "custom" also
            declares canRecoverOnBailout, otherwise it is defined to return
            true.
        clone: Emit ALLOW_CLONE for the class when truthy.
        can_consume_float32: When truthy, define canConsumeFloat32 to return
            true.

    Returns:
        The class definition as a string.
    """
    operands = operands or {}
    arguments = arguments or {}

    # Constructor parameters for the operands, the matching trailing
    # arguments to the base-class constructor, and NAMED_OPERANDS entries,
    # all in operand order.
    mir_operands = ["MDefinition* " + oper_name for oper_name in operands]
    mir_base_class_operands = "".join(", " + oper_name for oper_name in operands)
    named_operands = [
        "({}, {})".format(index, oper_name)
        for index, oper_name in enumerate(operands)
    ]

    # The op's type policy is derived from its operand types, in order.
    type_policy = ""
    if operands:
        type_policy = decide_type_policy(list(operands.values()), no_type_policy)

    class_name = "M" + name

    # The base class is selected by operand count (see mir_base_class).
    assert len(mir_operands) < len(mir_base_class), (
        "MIR op {}: too many operands ({})".format(name, len(mir_operands))
    )
    base_class = mir_base_class[len(mir_operands)]
    if base_class != "MNullaryInstruction":
        # An op with operands always inherits from a type policy as well.
        assert type_policy
        type_policy = ", " + type_policy
    code = "class {} : public {}{} {{\\\n".format(class_name, base_class, type_policy)

    # Members for the constructor arguments; they are declared before
    # " public:" and are therefore private.
    mir_args = []
    for arg_name, arg_type_sig in arguments.items():
        mir_args.append(arg_type_sig + " " + arg_name)
        if arg_type_sig in gc_pointer_types:
            member_type = "CompilerGCPointer<" + arg_type_sig + ">"
        else:
            member_type = arg_type_sig
        code += "  " + member_type + " " + arg_name + "_;\\\n"

    # Constructor: operands go to the base class, arguments to the members.
    code += "  explicit {}({}) : {}(classOpcode{})".format(
        class_name,
        ", ".join(mir_operands + mir_args),
        base_class,
        mir_base_class_operands,
    )
    for arg_name in arguments:
        code += ", " + arg_name + "_(" + arg_name + ")"
    code += " {\\\n"
    if guard:
        code += "    setGuard();\\\n"
    if movable:
        code += "    setMovable();\\\n"
    if result:
        code += "    setResultType(MIRType::{});\\\n".format(result)
    code += "  }\\\n public:\\\n"

    # Public accessors returning each argument as its declared type.
    for arg_name, arg_type_sig in arguments.items():
        code += "  " + arg_type_sig + " " + arg_name + "() const { "
        code += "return " + arg_name + "_; }\\\n"
    code += "  INSTRUCTION_HEADER({})\\\n".format(name)
    code += "  TRIVIAL_NEW_WRAPPERS\\\n"
    if named_operands:
        code += "  NAMED_OPERANDS({})\\\n".format(", ".join(named_operands))
    if alias_set:
        if alias_set == "custom":
            code += "  AliasSet getAliasSet() const override;\\\n"
        else:
            assert alias_set == "none"
            code += (
                "  AliasSet getAliasSet() const override { "
                "return AliasSet::None(); }\\\n"
            )
    if might_alias:
        code += "  AliasType mightAlias(const MDefinition* store) const override;\\\n"
    if folds_to:
        code += "  MDefinition* foldsTo(TempAllocator& alloc) override;\\\n"
    if congruent_to:
        if congruent_to == "custom":
            code += "  bool congruentTo(const MDefinition* ins) const override;\\\n"
        else:
            assert congruent_to == "if_operands_equal"
            code += (
                "  bool congruentTo(const MDefinition* ins) const override { "
                "return congruentIfOperandsEqual(ins); }\\\n"
            )
    if possibly_calls:
        if possibly_calls == "custom":
            code += "  bool possiblyCalls() const override;\\\n"
        else:
            code += "  bool possiblyCalls() const override { return true; }\\\n"
    if compute_range:
        code += "  void computeRange(TempAllocator& alloc) override;\\\n"
    if can_recover:
        code += "  [[nodiscard]] bool writeRecoverData(\\\n"
        code += "    CompactBufferWriter& writer) const override;\\\n"
        if can_recover == "custom":
            code += "  bool canRecoverOnBailout() const override;\\\n"
        else:
            code += "  bool canRecoverOnBailout() const override { return true; }\\\n"
    if clone:
        code += "  ALLOW_CLONE(" + class_name + ")\\\n"
    if can_consume_float32:
        code += (
            "  bool canConsumeFloat32(MUse* use) const override { return true; }\\\n"
        )
    code += "};\\\n"
    return code


def gen_non_gc_pointer_type_assertions(seen_types):
    """Return one C++ static_assert string per type in *seen_types*.

    The types are processed in sorted order. Each assertion checks that the
    type, with any '*' characters stripped from both ends, does not derive
    from gc::Cell. A GC pointer argument type that was left out of
    gc_pointer_types therefore fails to compile.
    """

    def assertion_for(type_name):
        pointee = type_name.strip("*")
        return (
            "static_assert(!std::is_base_of_v<gc::Cell, " + pointee + ">, "
            '"Ensure that '
            + pointee
            + ' is added to the gc_pointer_types list in GenerateMIRFiles.py."'
            ");"
        )

    return [assertion_for(type_name) for type_name in sorted(seen_types)]


def _op_option(op, key, allowed):
    """Return ``op.get(key)``, asserting that it is one of *allowed*.

    Membership is tested with ``in`` on a tuple (i.e. by ``==``), exactly as
    a plain ``assert value in allowed`` would. The message names the op and
    key so a bad MIROps.yaml entry is easy to locate.
    """
    value = op.get(key, None)
    assert value in allowed, "MIR op {}: '{}' must be one of {!r}, got {!r}".format(
        op["name"], key, allowed, value
    )
    return value


def _op_typed_option(op, key, expected_type):
    """Return ``op.get(key)``, asserting that it is None or an *expected_type*."""
    value = op.get(key, None)
    assert value is None or isinstance(value, expected_type), (
        "MIR op {}: '{}' must be a {}, got {!r}".format(
            op["name"], key, expected_type.__name__, value
        )
    )
    return value


def generate_mir_header(c_out, yaml_path):
    """Generate MIROpsGenerated.h from MIROps.yaml.

    The generated header defines three macros:
      * MIR_OPCODE_LIST(_): ``_(Name)`` for every op in the YAML file.
      * MIR_OPCODE_CLASS_GENERATED: class definitions for every op whose
        ``gen_boilerplate`` is not false.
      * NON_GC_POINTER_TYPE_ASSERTIONS_GENERATED: static_asserts that the
        constructor argument types not listed in gc_pointer_types do not
        derive from gc::Cell.

    Args:
        c_out: Writable text stream that receives the header.
        yaml_path: Path to MIROps.yaml; it is preprocessed by load_yaml.

    Each op's options are validated with assertions whose messages name the
    offending op and key.
    """
    data = load_yaml(yaml_path)

    # MIR_OPCODE_LIST items: one "_(Name)" per op, boilerplate or not.
    ops_items = []

    # Generated MIR op class definitions.
    mir_op_classes = []

    # Unique and non gc pointer types seen for arguments to the MIR constructor.
    seen_non_gc_pointer_argument_types = set()

    for op in data:
        name = op["name"]

        ops_items.append("_({})".format(name))

        gen_boilerplate = op.get("gen_boilerplate", True)
        assert isinstance(gen_boilerplate, bool), (
            "MIR op {}: 'gen_boilerplate' must be a bool, got {!r}".format(
                name, gen_boilerplate
            )
        )
        # Ops with gen_boilerplate: false only get a MIR_OPCODE_LIST entry.
        if not gen_boilerplate:
            continue

        # Validate options in the same order as they are passed on.
        operands = _op_typed_option(op, "operands", dict)
        arguments = _op_typed_option(op, "arguments", dict)
        no_type_policy = _op_option(op, "type_policy", (None, "none"))
        result = _op_typed_option(op, "result_type", str)
        guard = _op_option(op, "guard", (None, True, False))
        movable = _op_option(op, "movable", (None, True, False))
        folds_to = _op_option(op, "folds_to", (None, "custom"))
        congruent_to = _op_option(
            op, "congruent_to", (None, "if_operands_equal", "custom")
        )
        alias_set = _op_option(op, "alias_set", (None, "none", "custom"))
        might_alias = _op_option(op, "might_alias", (None, "custom"))
        possibly_calls = _op_option(op, "possibly_calls", (None, True, "custom"))
        compute_range = _op_option(op, "compute_range", (None, "custom"))
        can_recover = _op_option(op, "can_recover", (None, True, False, "custom"))
        clone = _op_option(op, "clone", (None, True, False))
        can_consume_float32 = _op_option(
            op, "can_consume_float32", (None, True, False)
        )

        code = gen_mir_class(
            name,
            operands,
            arguments,
            no_type_policy,
            result,
            guard,
            movable,
            folds_to,
            congruent_to,
            alias_set,
            might_alias,
            possibly_calls,
            compute_range,
            can_recover,
            clone,
            can_consume_float32,
        )
        mir_op_classes.append(code)

        # Argument types not known to be GC pointers get a static_assert.
        if arguments:
            for arg_type in arguments.values():
                if arg_type not in gc_pointer_types:
                    seen_non_gc_pointer_argument_types.add(arg_type)

    contents = "#define MIR_OPCODE_LIST(_)\\\n"
    contents += "\\\n".join(ops_items)
    contents += "\n\n"

    contents += "#define MIR_OPCODE_CLASS_GENERATED \\\n"
    contents += "\\\n".join(mir_op_classes)
    contents += "\n\n"

    contents += "#define NON_GC_POINTER_TYPE_ASSERTIONS_GENERATED \\\n"
    contents += "\\\n".join(
        gen_non_gc_pointer_type_assertions(seen_non_gc_pointer_argument_types)
    )
    contents += "\n\n"

    generate_header(c_out, "jit_MIROpsGenerated_h", contents)