diff options
Diffstat (limited to 'third_party/intgemm')
59 files changed, 23472 insertions, 0 deletions
diff --git a/third_party/intgemm/CMake/Catch.cmake b/third_party/intgemm/CMake/Catch.cmake new file mode 100644 index 0000000000..486e323318 --- /dev/null +++ b/third_party/intgemm/CMake/Catch.cmake @@ -0,0 +1,175 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#[=======================================================================[.rst: +Catch +----- + +This module defines a function to help use the Catch test framework. + +The :command:`catch_discover_tests` discovers tests by asking the compiled test +executable to enumerate its tests. This does not require CMake to be re-run +when tests change. However, it may not work in a cross-compiling environment, +and setting test properties is less convenient. + +This command is intended to replace use of :command:`add_test` to register +tests, and will create a separate CTest test for each Catch test case. Note +that this is in some cases less efficient, as common set-up and tear-down logic +cannot be shared by multiple test cases executing in the same instance. +However, it provides more fine-grained pass/fail information to CTest, which is +usually considered as more beneficial. By default, the CTest test name is the +same as the Catch name; see also ``TEST_PREFIX`` and ``TEST_SUFFIX``. + +.. command:: catch_discover_tests + + Automatically add tests with CTest by querying the compiled test executable + for available tests:: + + catch_discover_tests(target + [TEST_SPEC arg1...] + [EXTRA_ARGS arg1...] + [WORKING_DIRECTORY dir] + [TEST_PREFIX prefix] + [TEST_SUFFIX suffix] + [PROPERTIES name1 value1...] + [TEST_LIST var] + ) + + ``catch_discover_tests`` sets up a post-build command on the test executable + that generates the list of tests by parsing the output from running the test + with the ``--list-test-names-only`` argument. This ensures that the full + list of tests is obtained. Since test discovery occurs at build time, it is + not necessary to re-run CMake when the list of tests changes. + However, it requires that :prop_tgt:`CROSSCOMPILING_EMULATOR` is properly set + in order to function in a cross-compiling environment. + + Additionally, setting properties on tests is somewhat less convenient, since + the tests are not available at CMake time. Additional test properties may be + assigned to the set of tests as a whole using the ``PROPERTIES`` option. If + more fine-grained test control is needed, custom content may be provided + through an external CTest script using the :prop_dir:`TEST_INCLUDE_FILES` + directory property. The set of discovered tests is made accessible to such a + script via the ``<target>_TESTS`` variable. + + The options are: + + ``target`` + Specifies the Catch executable, which must be a known CMake executable + target. CMake will substitute the location of the built executable when + running the test. + + ``TEST_SPEC arg1...`` + Specifies test cases, wildcarded test cases, tags and tag expressions to + pass to the Catch executable with the ``--list-test-names-only`` argument. + + ``EXTRA_ARGS arg1...`` + Any extra arguments to pass on the command line to each test case. + + ``WORKING_DIRECTORY dir`` + Specifies the directory in which to run the discovered test cases. If this + option is not provided, the current binary directory is used. + + ``TEST_PREFIX prefix`` + Specifies a ``prefix`` to be prepended to the name of each discovered test + case. This can be useful when the same test executable is being used in + multiple calls to ``catch_discover_tests()`` but with different + ``TEST_SPEC`` or ``EXTRA_ARGS``. + + ``TEST_SUFFIX suffix`` + Similar to ``TEST_PREFIX`` except the ``suffix`` is appended to the name of + every discovered test case. Both ``TEST_PREFIX`` and ``TEST_SUFFIX`` may + be specified. + + ``PROPERTIES name1 value1...`` + Specifies additional properties to be set on all tests discovered by this + invocation of ``catch_discover_tests``. + + ``TEST_LIST var`` + Make the list of tests available in the variable ``var``, rather than the + default ``<target>_TESTS``. This can be useful when the same test + executable is being used in multiple calls to ``catch_discover_tests()``. + Note that this variable is only available in CTest. + +#]=======================================================================] + +#------------------------------------------------------------------------------ +function(catch_discover_tests TARGET) + cmake_parse_arguments( + "" + "" + "TEST_PREFIX;TEST_SUFFIX;WORKING_DIRECTORY;TEST_LIST" + "TEST_SPEC;EXTRA_ARGS;PROPERTIES" + ${ARGN} + ) + + if(NOT _WORKING_DIRECTORY) + set(_WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}") + endif() + if(NOT _TEST_LIST) + set(_TEST_LIST ${TARGET}_TESTS) + endif() + + ## Generate a unique name based on the extra arguments + string(SHA1 args_hash "${_TEST_SPEC} ${_EXTRA_ARGS}") + string(SUBSTRING ${args_hash} 0 7 args_hash) + + # Define rule to generate test list for aforementioned test executable + set(ctest_include_file "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}_include-${args_hash}.cmake") + set(ctest_tests_file "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}_tests-${args_hash}.cmake") + get_property(crosscompiling_emulator + TARGET ${TARGET} + PROPERTY CROSSCOMPILING_EMULATOR + ) + add_custom_command( + TARGET ${TARGET} POST_BUILD + BYPRODUCTS "${ctest_tests_file}" + COMMAND "${CMAKE_COMMAND}" + -D "TEST_TARGET=${TARGET}" + -D "TEST_EXECUTABLE=$<TARGET_FILE:${TARGET}>" + -D "TEST_EXECUTOR=${crosscompiling_emulator}" + -D "TEST_WORKING_DIR=${_WORKING_DIRECTORY}" + -D "TEST_SPEC=${_TEST_SPEC}" + -D "TEST_EXTRA_ARGS=${_EXTRA_ARGS}" + -D "TEST_PROPERTIES=${_PROPERTIES}" + -D "TEST_PREFIX=${_TEST_PREFIX}" + -D "TEST_SUFFIX=${_TEST_SUFFIX}" + -D "TEST_LIST=${_TEST_LIST}" + -D "CTEST_FILE=${ctest_tests_file}" + -P "${_CATCH_DISCOVER_TESTS_SCRIPT}" + VERBATIM + ) + + file(WRITE "${ctest_include_file}" + "if(EXISTS \"${ctest_tests_file}\")\n" + " include(\"${ctest_tests_file}\")\n" + "else()\n" + " add_test(${TARGET}_NOT_BUILT-${args_hash} ${TARGET}_NOT_BUILT-${args_hash})\n" + "endif()\n" + ) + + if(NOT ${CMAKE_VERSION} VERSION_LESS "3.10.0") + # Add discovered tests to directory TEST_INCLUDE_FILES + set_property(DIRECTORY + APPEND PROPERTY TEST_INCLUDE_FILES "${ctest_include_file}" + ) + else() + # Add discovered tests as directory TEST_INCLUDE_FILE if possible + get_property(test_include_file_set DIRECTORY PROPERTY TEST_INCLUDE_FILE SET) + if (NOT ${test_include_file_set}) + set_property(DIRECTORY + PROPERTY TEST_INCLUDE_FILE "${ctest_include_file}" + ) + else() + message(FATAL_ERROR + "Cannot set more than one TEST_INCLUDE_FILE" + ) + endif() + endif() + +endfunction() + +############################################################################### + +set(_CATCH_DISCOVER_TESTS_SCRIPT + ${CMAKE_CURRENT_LIST_DIR}/CatchAddTests.cmake +) diff --git a/third_party/intgemm/CMake/CatchAddTests.cmake b/third_party/intgemm/CMake/CatchAddTests.cmake new file mode 100644 index 0000000000..2220ce3ac6 --- /dev/null +++ b/third_party/intgemm/CMake/CatchAddTests.cmake @@ -0,0 +1,78 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +set(prefix "${TEST_PREFIX}") +set(suffix "${TEST_SUFFIX}") +set(spec ${TEST_SPEC}) +set(extra_args ${TEST_EXTRA_ARGS}) +set(properties ${TEST_PROPERTIES}) +set(script) +set(suite) +set(tests) + +function(add_command NAME) + set(_args "") + foreach(_arg ${ARGN}) + if(_arg MATCHES "[^-./:a-zA-Z0-9_]") + set(_args "${_args} [==[${_arg}]==]") # form a bracket_argument + else() + set(_args "${_args} ${_arg}") + endif() + endforeach() + set(script "${script}${NAME}(${_args})\n" PARENT_SCOPE) +endfunction() + +# Run test executable to get list of available tests +if(NOT EXISTS "${TEST_EXECUTABLE}") + message(FATAL_ERROR + "Specified test executable '${TEST_EXECUTABLE}' does not exist" + ) +endif() +execute_process( + COMMAND ${TEST_EXECUTOR} "${TEST_EXECUTABLE}" ${spec} --list-test-names-only + OUTPUT_VARIABLE output + RESULT_VARIABLE result +) +# Catch --list-test-names-only reports the number of tests, so 0 is... surprising +if(${result} EQUAL 0) + message(WARNING + "Test executable '${TEST_EXECUTABLE}' contains no tests!\n" + ) +elseif(${result} LESS 0) + message(FATAL_ERROR + "Error running test executable '${TEST_EXECUTABLE}':\n" + " Result: ${result}\n" + " Output: ${output}\n" + ) +endif() + +string(REPLACE "\n" ";" output "${output}") + +# Parse output +foreach(line ${output}) + set(test ${line}) + # use escape commas to handle properly test cases with commans inside the name + string(REPLACE "," "\\," test_name ${test}) + # ...and add to script + add_command(add_test + "${prefix}${test}${suffix}" + ${TEST_EXECUTOR} + "${TEST_EXECUTABLE}" + "${test_name}" + ${extra_args} + ) + add_command(set_tests_properties + "${prefix}${test}${suffix}" + PROPERTIES + WORKING_DIRECTORY "${TEST_WORKING_DIR}" + ${properties} + ) + list(APPEND tests "${prefix}${test}${suffix}") +endforeach() + +# Create a list of all discovered tests, which users may use to e.g. set +# properties on the tests +add_command(set ${TEST_LIST} ${tests}) + +# Write CTest script +file(WRITE "${CTEST_FILE}" "${script}") diff --git a/third_party/intgemm/CMakeLists.txt b/third_party/intgemm/CMakeLists.txt new file mode 100644 index 0000000000..c9f78fa663 --- /dev/null +++ b/third_party/intgemm/CMakeLists.txt @@ -0,0 +1,136 @@ +cmake_minimum_required(VERSION 3.5) +project(intgemm) +string(ASCII 27 Esc) +set(Orange "${Esc}[33m") +set(ColourReset "${Esc}[m") + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +set(CMAKE_CXX_STANDARD 11) + +if(MSVC) + add_compile_options(/W4 /WX) +else() + add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-unknown-pragmas) + if (COMPILE_WASM) + # Disabling Pthreads + memory growth warning to be an error for WASM + # Pthreads + memory growth causes JS accessing the wasm memory to be slow + # https://github.com/WebAssembly/design/issues/1271 + add_compile_options(-Wno-error=pthreads-mem-growth) + endif() +endif() + +# Check if compiler supports AVX2 (this should only catch emscripten) +try_compile(INTGEMM_COMPILER_SUPPORTS_AVX2 + ${CMAKE_CURRENT_BINARY_DIR}/compile_tests + ${CMAKE_CURRENT_SOURCE_DIR}/compile_test/avx2.cc) + +# Check if compiler supports AVX512BW +try_compile(INTGEMM_COMPILER_SUPPORTS_AVX512BW + ${CMAKE_CURRENT_BINARY_DIR}/compile_tests + ${CMAKE_CURRENT_SOURCE_DIR}/compile_test/avx512bw.cc) + +# Check if the compiler supports AVX512VNNI +try_compile(INTGEMM_COMPILER_SUPPORTS_AVX512VNNI + ${CMAKE_CURRENT_BINARY_DIR}/compile_tests + ${CMAKE_CURRENT_SOURCE_DIR}/compile_test/avx512vnni.cc) + +if (NOT INTGEMM_COMPILER_SUPPORTS_AVX2 OR NOT INTGEMM_COMPILER_SUPPORTS_AVX512BW OR NOT INTGEMM_COMPILER_SUPPORTS_AVX512VNNI) + set(UNSUPPORTED "Your compiler is too old to support") + if (NOT INTGEMM_COMPILER_SUPPORTS_AVX2) + set(UNSUPPORTED "${UNSUPPORTED} AVX2") + endif() + if (NOT INTGEMM_COMPILER_SUPPORTS_AVX512BW) + set(UNSUPPORTED "${UNSUPPORTED} AVX512BW") + endif() + if (NOT INTGEMM_COMPILER_SUPPORTS_AVX512VNNI) + set(UNSUPPORTED "${UNSUPPORTED} AVX512VNNI") + endif() + message(WARNING "${Orange}${UNSUPPORTED}. Multiplication will be slower on CPUs that support these instructions. For details rerun cmake with --debug-trycompile then try to build in compile_tests/CMakeFiles/CMakeTmp.${ColourReset}") +endif() + + +add_library(intgemm STATIC intgemm/intgemm.cc) + +# Generate configure file +configure_file(intgemm/intgemm_config.h.in intgemm/intgemm_config.h) +#Ensure it is included by users. +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +target_include_directories(intgemm PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + +# This isn't necessary since intgemm uses entirely relative paths but source code depending on it may want to #include <intgemm/intgemm.h> +target_include_directories(intgemm INTERFACE .) + +option(USE_OPENMP "Use OpenMP" OFF) +if (USE_OPENMP) + message(STATUS "Compiling with OpenMP") + find_package(OpenMP) + if (NOT ${OpenMP_CXX_FOUND}) + message(SEND_ERROR "OpenMP requested but C++ support not found") + endif() + add_compile_options(${OpenMP_CXX_FLAGS}) + target_link_libraries(intgemm PUBLIC OpenMP::OpenMP_CXX) +endif() + +if (COMPILE_WASM) + # A compile defintion to compile intgemm on WASM platform + target_compile_definitions(intgemm PUBLIC WASM) +endif() + +option(WORMHOLE "Use WASM wormhole https://bugzilla.mozilla.org/show_bug.cgi?id=1672160" OFF) +if (WORMHOLE) + target_compile_definitions(intgemm PUBLIC INTGEMM_WORMHOLE) +endif() + +option(INTGEMM_CPUID_ENVIRONMENT "Allow INTGEMM_CPUID environment variable to downgrade CPU model, which is mainly for testing." ON) +if (INTGEMM_CPUID_ENVIRONMENT) + target_compile_definitions(intgemm PRIVATE INTGEMM_CPUID_ENVIRONMENT) +endif() + +if(INTGEMM_DONT_BUILD_TESTS) + return() +endif() + +foreach(exe benchmark biasmultiply benchmark_quantizer) + add_executable(${exe} benchmarks/${exe}.cc) + target_link_libraries(${exe} intgemm) +endforeach() + +add_executable(example example.cc) +target_link_libraries(example intgemm) + +add_executable(tests + test/test.cc + + # General tests + test/add127_test.cc + test/multiply_test.cc + test/prepare_b_quantized_transposed.cc + test/prepare_b_transposed.cc + test/quantize_test.cc + test/utils_test.cc + + # Kernels tests + test/kernels/add_bias_test.cc + test/kernels/bitwise_not_test.cc + test/kernels/downcast_test.cc + test/kernels/exp_test.cc + test/kernels/floor_test.cc + test/kernels/multiply_test.cc + test/kernels/quantize_test.cc + test/kernels/relu_test.cc + test/kernels/rescale_test.cc + test/kernels/sigmoid_test.cc + test/kernels/tanh_test.cc + test/kernels/unquantize_test.cc + test/kernels/upcast_test.cc + test/kernels/write_test.cc +) +target_link_libraries(tests intgemm) + +#CTest integration with Catch2 +include(${CMAKE_CURRENT_SOURCE_DIR}/CMake/Catch.cmake) +include(CTest) +catch_discover_tests(tests) diff --git a/third_party/intgemm/LICENSE b/third_party/intgemm/LICENSE new file mode 100644 index 0000000000..0d57f7b940 --- /dev/null +++ b/third_party/intgemm/LICENSE @@ -0,0 +1,70 @@ +MIT License + +Copyright (c) 2017--2019 University of Edinburgh, Nikolay Bogoychev, Mateusz Chudyk, Kenneth Heafield, and Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + +test/3rd_party/catch.hpp +Copyright (c) 2019 Two Blue Cubes Ltd. All rights reserved. +Distributed under the Boost Software License, Version 1.0. (See accompanying +file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + + + +The original 16-bit SSE2 code came from: + +Sharp Models on Dull Hardware: Fast and Accurate Neural Machine Translation Decoding on the CPU by Jacob Devlin +https://arxiv.org/abs/1705.01991 + +Under a license: + +Copyright (c) 2017 Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/third_party/intgemm/README.md b/third_party/intgemm/README.md new file mode 100644 index 0000000000..b8388dc0bf --- /dev/null +++ b/third_party/intgemm/README.md @@ -0,0 +1,91 @@ +[![Build SSE](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/intgemm/job/intgemm-SSE.svg?label=SSE)](http://vali.inf.ed.ac.uk/jenkins/job/intgemm-SSE/) +[![Build AVX2](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/intgemm/job/intgemm-AVX2.svg?label=AVX2)](http://vali.inf.ed.ac.uk/jenkins/job/intgemm-AVX2/) +[![Build AVX512BW](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/intgemm/job/intgemm-AVX512BW.svg?label=AVX512BW)](http://vali.inf.ed.ac.uk/jenkins/job/intgemm-AVX512BW/) +![Build Ubuntu](https://github.com/kpu/intgemm/workflows/Ubuntu/badge.svg) +![Build Ubuntu debug](https://github.com/kpu/intgemm/workflows/Ubuntu%20debug/badge.svg) +![Build Ubuntu OpenMP](https://github.com/kpu/intgemm/workflows/Ubuntu%20OpenMP/badge.svg) +![Build Windows](https://github.com/kpu/intgemm/workflows/Windows/badge.svg) +![Build Mac](https://github.com/kpu/intgemm/workflows/Mac/badge.svg) +[![Intel Compiler](https://github.com/kpu/intgemm/actions/workflows/intel-19.yml/badge.svg)](https://github.com/kpu/intgemm/actions/workflows/intel-19.yml) + +# Integer Matrix Multiplication + +This repository implements 8-bit and 16-bit matrix multiplication: + +C = A * B + +It's designed with neural network inference in mind: A is typically activations, B is typically fixed parameters, and C is activations for the next layer. + +A can have any number of rows. Typically this is a batch size. +The shared dimension, A's columns and B's rows, must be a multiple of 32 (for 16-bit) or 64 (for 8-bit). +B's columns must be a multiple of 8. + +## Accuracy +16-bit multiplication accumulates into 32-bit integers WITHOUT SATURATION (because there is no 32-bit add with saturation). If width is too large (i.e. >2048) or many 16-bit values are large, there is substantial risk of overflow. Choose a smaller quantization multiplier to scale things down or implement periodic upcasting to 64-bit for me. + +8-bit multiplication accumulates into 16-bit integers with saturation. This saturates for larger widths (~1024) and is worst on SSSE3 because it accumulates in fewer values. It's possible to upcast to 32-bit every so often, but this has not been implemented yet. + +## Usage + +A full example appears in [example.cc](example.cc). + +Both A and B should be prepared before multiplication. +```C++ +#include "intgemm/intgemm.h" + +/* Not shown: allocate 64-byte aligned memory with e.g. aligned_alloc. + * A is A_rows x width. + * B is width x B_cols. + */ +/* Prepare A for multiplication. This might be offline or on the fly. */ +intgemm::Int16::PrepareA(A.begin(), A_prepared.begin(), quant_mult, A_rows, width); +/* Prepare B for multiplication. This is typically done offline. */ +intgemm::Int16::PrepareB(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); +/* Multiply and produce results in C */ +intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::UnquantizeAndWrite(1.0 / (quant_mult * quant_mult), C.begin())); +``` +For 8-bit, use `Int8` instead of `Int16`. + +When repesented as floats, all of A, B, and C are in row-major format. + +The last argument of `Multiply` is a callback which is usually used to performs postprocessing on the output matrix (C). Full set of built-in callbacks can be found in [callbacks/configs.h](callbacks/configs.h). You can also write your own callback. To do that you just need to: +1. Add configuration structure for your callback in [callbacks/configs.h](callbacks/configs.h). +2. Add your callback implementation: + - in [callbacks/implementations.inl](callbacks/implementations.inl) if you want to implement it for all architecturs at the same time. + - in `callbacks/ARCHITECTURE.h` (e.g. [callbacks/sse2.h](callbacks/sse2.h)) if you want to implement it only for the specific architecture. + +For 8-bit, you can make use a of a slightly faster implementation, assuming you can determine tha quantization multipliers and prepare the biases offline: + +```C++ +#include "intgemm/intgemm.h" + +/* Not shown: allocate 64-byte aligned memory with e.g. aligned_alloc. + * A is A_rows x width. + * B is width x B_cols. + * If you want to make use of the slightly faster 8bit codepath (assuming you can cache biases and quantization multipliers) + * This routine only supports C = A*B + Bias + * In practise it computes C = (A+127)*B + Bias - |127|*B + * Prepare A and B first: + */ +float alpha = 25; +float quant_mult = 127/alpha; +intgemm::Int8Shift::PrepareA(A.begin(), A_prepared.begin(), quant_mult, A_rows, width); +intgemm::Int8Shift::PrepareB(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); +/* Prepare the bias (inplace) */ +float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); +intgemm::Int8Shift::PrepareBias(B_prepared.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin())); +/* Multiply */ +intgemm::Int8Shift::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), C.begin())); +``` + +## Quantization +Floating-point values are multiplied by a user-specified constant then rounded to an integer. + +In 16 bit, Jacob Devlin recommends 1024.0 for neural networks to prevent the aforementioned overflow. + +In 8 bit, use 127.0 / the largest value (use MaxAbsolute). Quantization will saturate so it's possible to use larger multipliers to obtain clipping. + +## Acknowledgments +The original 16-bit SSE2 code came from: + +Sharp Models on Dull Hardware: Fast and Accurate Neural Machine Translation Decoding on the CPU by Jacob Devlin https://arxiv.org/abs/1705.01991 under the MIT license. diff --git a/third_party/intgemm/benchmarks/benchmark.cc b/third_party/intgemm/benchmarks/benchmark.cc new file mode 100644 index 0000000000..512d3ec39e --- /dev/null +++ b/third_party/intgemm/benchmarks/benchmark.cc @@ -0,0 +1,214 @@ +#include "../intgemm/aligned.h" +#include "intgemm/intgemm_config.h" +#include "../intgemm/avx512_gemm.h" +#include "../intgemm/sse2_gemm.h" +#include "../intgemm/avx2_gemm.h" +#include "../intgemm/ssse3_gemm.h" +#include "../intgemm/intgemm.h" +#include "../intgemm/stats.h" +#include "../intgemm/callbacks.h" + +#include <algorithm> +#include <cassert> +#include <chrono> +#include <cmath> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <iomanip> +#include <iostream> +#include <random> + +namespace intgemm { +namespace { + +struct RandomMatrices { + RandomMatrices(Index A_rows_in, Index width_in, Index B_cols_in) : + A_rows(A_rows_in), width(width_in), B_cols(B_cols_in), + A(A_rows * width), B(width * B_cols) { + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.f, 1.f); + gen.seed(45678); + + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + } + + const Index A_rows, width, B_cols; + AlignedVector<float> A, B; +}; + +template <class Backend> double Run(const RandomMatrices &m) { + using Integer = typename Backend::Integer; + float quant_mult = 127.0f / 2.0f; + float unquant_mult = 1.0f / (quant_mult * quant_mult); + AlignedVector<Integer> A_prepared(m.A_rows * m.width); + Backend::PrepareA(m.A.begin(), A_prepared.begin(), quant_mult, m.A_rows, m.width); + AlignedVector<Integer> B_prepared(m.width * m.B_cols); + Backend::PrepareB(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols); + AlignedVector<float> output(m.A_rows * m.B_cols); + // Burn in + Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::UnquantizeAndWrite(unquant_mult, output.begin())); + auto start = std::chrono::steady_clock::now(); + Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::UnquantizeAndWrite(unquant_mult, output.begin())); + return std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count(); +} + +template <class Backend> void RunAll(RandomMatrices *matrices, RandomMatrices *matrices_end, std::vector<std::vector<double>> &stats) { + if (Backend::kUses > kCPU) return; + std::size_t size = matrices_end - matrices; + if (stats.size() < size) + stats.resize(size); + for (std::size_t i = 0; i < size; ++i) { + stats[i].push_back(Run<Backend>(matrices[i])); + } +} + +struct BackendStats { + std::vector<std::vector<double>> ssse3_8bit; + std::vector<std::vector<double>> avx2_8bit; + std::vector<std::vector<double>> avx512_8bit; + std::vector<std::vector<double>> avx512vnni_8bit; + std::vector<std::vector<double>> sse2_16bit; + std::vector<std::vector<double>> avx2_16bit; + std::vector<std::vector<double>> avx512_16bit; +}; + +const float kOutlierThreshold = 0.75; +void Summarize(std::vector<double> &stats) { + // Throw out outliers. + std::vector<double>::iterator keep = stats.begin() + static_cast<std::size_t>(static_cast<float>(stats.size()) * kOutlierThreshold); + std::nth_element(stats.begin(), keep, stats.end()); + double avg = 0.0; + for (std::vector<double>::const_iterator i = stats.begin(); i != keep; ++i) { + avg += *i; + } + avg /= (keep - stats.begin()); + double stddev = 0.0; + for (std::vector<double>::const_iterator i = stats.begin(); i != keep; ++i) { + double off = (double)*i - avg; + stddev += off * off; + } + stddev = sqrt(stddev / (keep - stats.begin() - 1)); + std::cout << std::setw(10) << *std::min_element(stats.begin(), stats.end()) << '\t' << std::setw(8) << avg << '\t' << std::setw(8) << stddev; +} + +template <class Backend> void Print(std::vector<std::vector<double>> &stats, std::size_t index) { + if (stats.empty()) return; + std::cout << std::setw(16) << Backend::kName << '\t'; + Summarize(stats[index]); + std::cout << '\n'; +} + +} // namespace intgemm +} // namespace + +// Program takes no input +int main(int, char ** argv) { + std::cerr << "Remember to run this on a specific core:\ntaskset --cpu-list 0 " << argv[0] << std::endl; + + using namespace intgemm; + RandomMatrices matrices[] = { + {1, 64, 8}, + {8, 256, 256}, + {8, 2048, 256}, + {8, 256, 2048}, + {320, 256, 256}, + {472, 256, 256}, + {248, 256, 256}, + {200, 256, 256}, + // Additional stuff + {256, 256, 256}, + {512, 512, 512}, + {1024, 1024, 1024}, +/* {4096, 4096, 4096}, + {4096, 4096, 2048}, + {4096, 4096, 1024}, + {4096, 4096, 512}, + {4096, 4096, 256},*/ + {4096, 4096, 128} + }; + RandomMatrices *matrices_end = (RandomMatrices*)matrices + sizeof(matrices) / sizeof(RandomMatrices); + // Only do full sampling for <1024 rows. + RandomMatrices *full_sample; + for (full_sample = matrices_end - 1; full_sample >= matrices && full_sample->A_rows >= 1024; --full_sample) {} + ++full_sample; + + BackendStats stats; + const int kSamples = 100; + // Realistically, we don't expect different architectures or different precisions to run in the + // same run of an application. Benchmark per architecture and per precision level. + std::cerr << "SSSE3 8bit, 100 samples..." << std::endl; + for (int samples = 0; samples < kSamples; ++samples) { + RandomMatrices *end = (samples < 4) ? matrices_end : full_sample; + RunAll<SSSE3::Kernels8>(matrices, end, stats.ssse3_8bit); + } + + std::cerr << "SSE2 16bit, 100 samples..." << std::endl; + for (int samples = 0; samples < kSamples; ++samples) { + RandomMatrices *end = (samples < 4) ? matrices_end : full_sample; + RunAll<SSE2::Kernels16>(matrices, end, stats.sse2_16bit); + } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 + std::cerr << "AVX2 8bit, 100 samples..." << std::endl; + for (int samples = 0; samples < kSamples; ++samples) { + RandomMatrices *end = (samples < 4) ? matrices_end : full_sample; + RunAll<AVX2::Kernels8>(matrices, end, stats.avx2_8bit); + } + + std::cerr << "AVX2 16bit, 100 samples..." << std::endl; + for (int samples = 0; samples < kSamples; ++samples) { + RandomMatrices *end = (samples < 4) ? matrices_end : full_sample; + RunAll<AVX2::Kernels16>(matrices, end, stats.avx2_16bit); + } +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + std::cerr << "AVX512 8bit, 100 samples..." << std::endl; + for (int samples = 0; samples < kSamples; ++samples) { + RandomMatrices *end = (samples < 4) ? matrices_end : full_sample; + RunAll<AVX512BW::Kernels8>(matrices, end, stats.avx512_8bit); + } + + std::cerr << "AVX512 16bit, 100 samples..." << std::endl; + for (int samples = 0; samples < kSamples; ++samples) { + RandomMatrices *end = (samples < 4) ? matrices_end : full_sample; + RunAll<AVX512BW::Kernels16>(matrices, end, stats.avx512_16bit); + } +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI + std::cerr << "AVX512VNNI 8bit, 100 samples..." << std::endl; + for (int samples = 0; samples < kSamples; ++samples) { + RandomMatrices *end = (samples < 4) ? matrices_end : full_sample; + RunAll<AVX512VNNI::Kernels8>(matrices, end, stats.avx512vnni_8bit); + } +#endif + + if (stats.sse2_16bit.empty()) { + std::cerr << "No CPU support." << std::endl; + return 1; + } + for (std::size_t i = 0; i < sizeof(matrices) / sizeof(RandomMatrices); ++i) { + std::cout << "Multiply\t" << matrices[i].A_rows << '\t' << matrices[i].width << '\t' << matrices[i].B_cols << '\t' << "Samples=" << (kOutlierThreshold * stats.sse2_16bit[i].size()) << '\n'; + Print<SSSE3::Kernels8>(stats.ssse3_8bit, i); + Print<AVX2::Kernels8>(stats.avx2_8bit, i); +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + Print<AVX512BW::Kernels8>(stats.avx512_8bit, i); +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI + Print<AVX512VNNI::Kernels8>(stats.avx512vnni_8bit, i); +#endif + Print<SSE2::Kernels16>(stats.sse2_16bit, i); + Print<AVX2::Kernels16>(stats.avx2_16bit, i); +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + Print<AVX512BW::Kernels16>(stats.avx512_16bit, i); +#endif + } + return 0; +} + + diff --git a/third_party/intgemm/benchmarks/benchmark_quantizer.cc b/third_party/intgemm/benchmarks/benchmark_quantizer.cc new file mode 100644 index 0000000000..5235b1ea0d --- /dev/null +++ b/third_party/intgemm/benchmarks/benchmark_quantizer.cc @@ -0,0 +1,74 @@ +#include "../intgemm/intgemm.h" +#include "../intgemm/aligned.h" +#include "../intgemm/ssse3_gemm.h" +#include "../intgemm/avx2_gemm.h" +#include "../intgemm/avx512_gemm.h" + +#include <chrono> +#include <iomanip> +#include <iostream> +#include <random> +#include <vector> + +namespace { + +float MaxAbsoluteBaseline(const float *begin, const float *end) { + auto res = std::minmax_element(begin, end); + return std::max(std::fabs(*res.first), std::fabs(*res.second)); +} + +void BenchmarkMaxAbsolute() { + std::mt19937 gen; + std::uniform_real_distribution<float> dist(0.f, 1.f); + gen.seed(45678); + + intgemm::AlignedVector<float> v(4096 * 4096); + for (auto& it : v) { + it = dist(gen); + } + + // Hopefully these don't get optimized out... + MaxAbsoluteBaseline(v.begin(), v.end()); + auto start = std::chrono::steady_clock::now(); + MaxAbsoluteBaseline(v.begin(), v.end()); + double baseline = std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count(); + intgemm::MaxAbsolute(v.begin(), v.end()); + start = std::chrono::steady_clock::now(); + intgemm::MaxAbsolute(v.begin(), v.end()); + double optimized = std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count(); + std::cout << "MaxAbsolute baseline = " << baseline << " optimized = " << optimized << " speedup = " << (optimized / baseline) << '\n'; +} + +template <class Backend> void QuantizerBench(const float *in, int8_t *out, intgemm::Index count) { + if (intgemm::kCPU < Backend::kUses) return; + Backend::Quantize(in, out, 1.0, count); + const std::size_t kTries = 60; + auto start = std::chrono::steady_clock::now(); + for (std::size_t t = 0; t < kTries; ++t) { + Backend::Quantize(in, out, 1.0, count); + } + auto end = std::chrono::steady_clock::now(); + double took = std::chrono::duration<double>(end - start).count() / kTries; + std::cout << std::setw(9) << count << ' ' << std::fixed << std::setw(9) << std::setprecision(7) << took << ' ' << Backend::kName << std::endl; +} +} // namespace + +int main() { + BenchmarkMaxAbsolute(); + for (std::size_t count = 1; count < (1ULL<<30); count *= 2) { + intgemm::AlignedVector<float> in(count); + intgemm::AlignedVector<int8_t> out(count); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-129.0, 129.0); + for (float &element : in) { + element = dist(gen); + } + QuantizerBench<intgemm::SSSE3::Kernels8>(in.begin(), out.begin(), static_cast<intgemm::Index>(count)); +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 + QuantizerBench<intgemm::AVX2::Kernels8>(in.begin(), out.begin(), static_cast<intgemm::Index>(count)); +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + QuantizerBench<intgemm::AVX512BW::Kernels8>(in.begin(), out.begin(), static_cast<intgemm::Index>(count)); +#endif + } +} diff --git a/third_party/intgemm/benchmarks/biasmultiply.cc b/third_party/intgemm/benchmarks/biasmultiply.cc new file mode 100644 index 0000000000..c835b61649 --- /dev/null +++ b/third_party/intgemm/benchmarks/biasmultiply.cc @@ -0,0 +1,278 @@ +#include "../intgemm/intgemm.h" +#include "../intgemm/aligned.h" +#include <chrono> +#include <random> +#include <iostream> + +using namespace intgemm; + +template <class Routine> +void testOld(Index /*rows*/, Index /*cols*/) { +} + +template <class Routine> +std::chrono::duration<double> testNew(Index A_rows, Index width, Index B_cols) { + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = dist(gen); + } + + float alpha = 2.0f; + float quant_mult = 127.0f / alpha; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<uint8_t> A_prep(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on + Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin())); + auto start = std::chrono::system_clock::now(); + Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + auto end = std::chrono::system_clock::now(); + + std::chrono::duration<double> elapsed_seconds = end-start; + return elapsed_seconds; + +} + +template <class Routine> +std::chrono::duration<double> testOld(Index A_rows, Index width, Index B_cols) { + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = dist(gen); + } + + float alpha = 2.0f; + float quant_mult = 127.0f / alpha; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<int8_t> A_prep(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + auto start = std::chrono::system_clock::now(); + Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + auto end = std::chrono::system_clock::now(); + + std::chrono::duration<double> elapsed_seconds = end-start; + return elapsed_seconds; + +} + +template <class Routine> +std::chrono::duration<double> testOld_nobias(Index A_rows, Index width, Index B_cols) { + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + + float alpha = 2.0f; + float quant_mult = 127.0f / alpha; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<int8_t> A_prep(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + auto start = std::chrono::system_clock::now(); + Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin())); + auto end = std::chrono::system_clock::now(); + + std::chrono::duration<double> elapsed_seconds = end-start; + return elapsed_seconds; + +} + +int main(int argc, char ** argv) { + int repeat = 1000; + if (argc > 1) { + repeat = atoi(argv[1]); + } + + std::chrono::duration<double> oldSSSE3_nobias = testOld_nobias<SSSE3::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(8, 256, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(8, 2048, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(320, 256, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(472, 256, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(248, 256, 256); + oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of SSSE3 without bias took: " << oldSSSE3_nobias.count() << " seconds." << std::endl; + + std::chrono::duration<double> oldSSSE3 = testOld<SSSE3::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldSSSE3 += testOld<SSSE3::Kernels8>(8, 256, 256); + oldSSSE3 += testOld<SSSE3::Kernels8>(8, 2048, 256); + oldSSSE3 += testOld<SSSE3::Kernels8>(320, 256, 256); + oldSSSE3 += testOld<SSSE3::Kernels8>(472, 256, 256); + oldSSSE3 += testOld<SSSE3::Kernels8>(248, 256, 256); + oldSSSE3 += testOld<SSSE3::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of SSSE3 took: " << oldSSSE3.count() << " seconds." << std::endl; + + std::chrono::duration<double> newTimeSSSE3 = testOld<SSSE3::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + newTimeSSSE3 += testNew<SSSE3::Kernels8>(8, 256, 256); + newTimeSSSE3 += testNew<SSSE3::Kernels8>(8, 2048, 256); + newTimeSSSE3 += testNew<SSSE3::Kernels8>(320, 256, 256); + newTimeSSSE3 += testNew<SSSE3::Kernels8>(472, 256, 256); + newTimeSSSE3 += testNew<SSSE3::Kernels8>(248, 256, 256); + newTimeSSSE3 += testNew<SSSE3::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of Shifted SSSE3 took: " << newTimeSSSE3.count() << " seconds." << std::endl; + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 + std::chrono::duration<double> oldAVX2_nobias = testOld_nobias<AVX2::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(8, 256, 256); + oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(8, 2048, 256); + oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(320, 256, 256); + oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(472, 256, 256); + oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(248, 256, 256); + oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of AVX2 without bias took: " << oldAVX2_nobias.count() << " seconds." << std::endl; + + std::chrono::duration<double> oldAVX2 = testOld<AVX2::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX2 += testOld<AVX2::Kernels8>(8, 256, 256); + oldAVX2 += testOld<AVX2::Kernels8>(8, 2048, 256); + oldAVX2 += testOld<AVX2::Kernels8>(320, 256, 256); + oldAVX2 += testOld<AVX2::Kernels8>(472, 256, 256); + oldAVX2 += testOld<AVX2::Kernels8>(248, 256, 256); + oldAVX2 += testOld<AVX2::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of AVX2 took: " << oldAVX2.count() << " seconds." << std::endl; + + std::chrono::duration<double> newTimeAVX2 = testOld<AVX2::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + newTimeAVX2 += testNew<AVX2::Kernels8>(8, 256, 256); + newTimeAVX2 += testNew<AVX2::Kernels8>(8, 2048, 256); + newTimeAVX2 += testNew<AVX2::Kernels8>(320, 256, 256); + newTimeAVX2 += testNew<AVX2::Kernels8>(472, 256, 256); + newTimeAVX2 += testNew<AVX2::Kernels8>(248, 256, 256); + newTimeAVX2 += testNew<AVX2::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of Shifted AVX2 took: " << newTimeAVX2.count() << " seconds." << std::endl; +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + if (kCPU < CPUType::AVX512BW) return 0; + std::chrono::duration<double> oldAVX512_nobias = testOld_nobias<AVX512BW::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(8, 256, 256); + oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(8, 2048, 256); + oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(320, 256, 256); + oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(472, 256, 256); + oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(248, 256, 256); + oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of AVX512 without bias took: " << oldAVX512_nobias.count() << " seconds." << std::endl; + + std::chrono::duration<double> oldAVX512 = testOld<AVX512BW::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX512 += testOld<AVX512BW::Kernels8>(8, 256, 256); + oldAVX512 += testOld<AVX512BW::Kernels8>(8, 2048, 256); + oldAVX512 += testOld<AVX512BW::Kernels8>(320, 256, 256); + oldAVX512 += testOld<AVX512BW::Kernels8>(472, 256, 256); + oldAVX512 += testOld<AVX512BW::Kernels8>(248, 256, 256); + oldAVX512 += testOld<AVX512BW::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of AVX512 took: " << oldAVX512.count() << " seconds." << std::endl; + + std::chrono::duration<double> newTimeAVX512 = testOld<AVX512BW::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + newTimeAVX512 += testNew<AVX512BW::Kernels8>(8, 256, 256); + newTimeAVX512 += testNew<AVX512BW::Kernels8>(8, 2048, 256); + newTimeAVX512 += testNew<AVX512BW::Kernels8>(320, 256, 256); + newTimeAVX512 += testNew<AVX512BW::Kernels8>(472, 256, 256); + newTimeAVX512 += testNew<AVX512BW::Kernels8>(248, 256, 256); + newTimeAVX512 += testNew<AVX512BW::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of Shifted AVX512 took: " << newTimeAVX512.count() << " seconds." << std::endl; +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI + if (kCPU < CPUType::AVX512VNNI) return 0; + std::chrono::duration<double> oldAVX512VNNI_nobias = testOld_nobias<AVX512BW::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(8, 256, 256); + oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(8, 2048, 256); + oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(320, 256, 256); + oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(472, 256, 256); + oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(248, 256, 256); + oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of AVX512VNNI without bias took: " << oldAVX512VNNI_nobias.count() << " seconds." << std::endl; + + std::chrono::duration<double> oldAVX512VNNI = testOld<AVX512BW::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(8, 256, 256); + oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(8, 2048, 256); + oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(320, 256, 256); + oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(472, 256, 256); + oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(248, 256, 256); + oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of AVX512VNNI took: " << oldAVX512VNNI.count() << " seconds." << std::endl; + + std::chrono::duration<double> newTimeAVX512VNNI = testOld<AVX512BW::Kernels8>(1, 64, 8); + for (int i = 0; i<repeat; i++) { + newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(8, 256, 256); + newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(8, 2048, 256); + newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(320, 256, 256); + newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(472, 256, 256); + newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(248, 256, 256); + newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(200, 256, 256); + } + + std::cout << repeat << " iterations of Shifted AVX512VNNI took: " << newTimeAVX512VNNI.count() << " seconds." << std::endl; +#endif + +} diff --git a/third_party/intgemm/compile_test/avx2.cc b/third_party/intgemm/compile_test/avx2.cc new file mode 100644 index 0000000000..9ed534e929 --- /dev/null +++ b/third_party/intgemm/compile_test/avx2.cc @@ -0,0 +1,25 @@ +// Some compilers don't have AVX2 support. Test for them. +#include <immintrin.h> + +// clang-cl bug doesn't include these headers when pretending to be MSVC +// https://github.com/llvm/llvm-project/blob/e9a294449575a1e1a0daca470f64914695dc9adc/clang/lib/Headers/immintrin.h#L69-L72 +#if defined(_MSC_VER) && defined(__clang__) +#include <avxintrin.h> +#include <avx2intrin.h> +#include <smmintrin.h> +#endif + +#if defined(_MSC_VER) && !defined(__clang__) +#define INTGEMM_AVX2 +#else +#define INTGEMM_AVX2 __attribute__ ((target ("avx2"))) +#endif + +INTGEMM_AVX2 int Test() { + __m256i value = _mm256_set1_epi32(1); + value = _mm256_abs_epi8(value); + return *(int*)&value; +} + +int main() { +} diff --git a/third_party/intgemm/compile_test/avx512bw.cc b/third_party/intgemm/compile_test/avx512bw.cc new file mode 100644 index 0000000000..2361f757d5 --- /dev/null +++ b/third_party/intgemm/compile_test/avx512bw.cc @@ -0,0 +1,31 @@ +// Some compilers don't have AVX512BW support. Test for them. +#include <immintrin.h> + +// clang-cl bug doesn't include these headers when pretending to be MSVC +// https://github.com/llvm/llvm-project/blob/e9a294449575a1e1a0daca470f64914695dc9adc/clang/lib/Headers/immintrin.h#L69-L72 +#if defined(_MSC_VER) && defined(__clang__) +#include <avxintrin.h> +#include <avx2intrin.h> +#include <smmintrin.h> +#include <avx512fintrin.h> +#include <avx512dqintrin.h> +#include <avx512bwintrin.h> +#endif + +#if defined(_MSC_VER) && !defined(__clang__) +#define INTGEMM_AVX512BW +#elif defined(__INTEL_COMPILER) +#define INTGEMM_AVX512BW __attribute__ ((target ("avx512f"))) +#else +#define INTGEMM_AVX512BW __attribute__ ((target ("avx512bw"))) +#endif + +INTGEMM_AVX512BW int Test() { + // AVX512BW + __m512i value = _mm512_set1_epi32(1); + value = _mm512_maddubs_epi16(value, value); + return *(int*)&value; +} + +int main() { +} diff --git a/third_party/intgemm/compile_test/avx512vnni.cc b/third_party/intgemm/compile_test/avx512vnni.cc new file mode 100644 index 0000000000..59035e4778 --- /dev/null +++ b/third_party/intgemm/compile_test/avx512vnni.cc @@ -0,0 +1,36 @@ +#include <immintrin.h> + +// clang-cl bug doesn't include these headers when pretending to be MSVC +// https://github.com/llvm/llvm-project/blob/e9a294449575a1e1a0daca470f64914695dc9adc/clang/lib/Headers/immintrin.h#L69-L72 +#if defined(_MSC_VER) && defined(__clang__) +#include <avxintrin.h> +#include <avx2intrin.h> +#include <smmintrin.h> +#include <avx512fintrin.h> +#include <avx512dqintrin.h> +#include <avx512bwintrin.h> +#include <avx512vnniintrin.h> +#endif + +#if defined(_MSC_VER) && !defined(__clang__) +#elif defined(__INTEL_COMPILER) +__attribute__ ((target ("avx512f"))) +#else +__attribute__ ((target ("avx512f,avx512bw,avx512dq,avx512vnni"))) +#endif +bool Foo() { + // AVX512F + __m512i value = _mm512_set1_epi32(1); + // AVX512BW + value = _mm512_maddubs_epi16(value, value); + // AVX512DQ + __m256i value2 = _mm256_set1_epi8(1); + value = _mm512_inserti32x8(value, value2, 1); + // AVX512VNNI + value = _mm512_dpbusd_epi32(value, value, value); + return *(int*)&value; +} + +int main() { + return Foo(); +} diff --git a/third_party/intgemm/example.cc b/third_party/intgemm/example.cc new file mode 100644 index 0000000000..5f558d0fcd --- /dev/null +++ b/third_party/intgemm/example.cc @@ -0,0 +1,79 @@ +#include "intgemm/intgemm.h" +// This is just for AlignedVector, which helps managed 64-byte aligned memory. +// Feel free to manage memory yourself. +#include "intgemm/aligned.h" +#include "intgemm/callbacks.h" + +#include <cassert> +#include <cmath> +#include <random> + +int main() { + using intgemm::Index; + const Index A_rows = 1; + // The shared dimension: A's columns and B's rows. + const Index width = 64; + const Index B_cols = 8; + + // This is a simple vector class that allocates memory aligned to 64 bytes. + // You don't have to use it; just use aligned_alloc and friends directly. + using intgemm::AlignedVector; + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + + // Fill with random values in range [-2, 2]. + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-2.f, 2.f); + gen.seed(1); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + + // Compute the top left corner of C as a sanity check. + float top_left_reference = 0.0f; + for (Index w = 0; w < width; ++w) { + top_left_reference += A[w] * B[w * B_cols]; + } + + // 16-bit multiplication. + { + // For 16-bit, Jacob Devlin recommends 1024 so as to not overflow in 32-bit accumulation. + float quant_mult = 1024.0f; + AlignedVector<int16_t> A_prepared(A.size()); + AlignedVector<int16_t> B_prepared(B.size()); + // Quantize A. + intgemm::Int16::PrepareA(A.begin(), A_prepared.begin(), quant_mult, A_rows, width); + // Quantize and reshape B. + // Typically you will do this once when parameters are loaded, not every time. + intgemm::Int16::PrepareB(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); + + AlignedVector<float> C(A_rows * B_cols); + // Do the actual multiply. + intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::UnquantizeAndWrite(1.0f / (quant_mult * quant_mult), C.begin())); + // Sanity check. C will be row major. + assert(std::fabs(C[0] - top_left_reference) < 0.05f); + } + + // 8-bit multiplication. + { + // For 8-bit a good quantization multiplier is 127 / largest absolute value.. + float quant_mult = 127.0f / 2.0f; + AlignedVector<int8_t> A_prepared(A.size()); + AlignedVector<int8_t> B_prepared(B.size()); + // Quantize A. + intgemm::Int8::PrepareA(A.begin(), A_prepared.begin(), quant_mult, A_rows, width); + // Quantize and reshape B. + // Typically you will do this once when parameters are loaded, not every time. + intgemm::Int8::PrepareB(B.begin(), B_prepared.begin(), quant_mult, width, B_cols); + + AlignedVector<float> C(A_rows * B_cols); + // Do the actual multiply. + intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::UnquantizeAndWrite(1.0f / (quant_mult * quant_mult), C.begin())); + // Sanity check. C will be row major. + assert(std::fabs(C[0] - top_left_reference) < 0.05f); + } + return 0; +} diff --git a/third_party/intgemm/intgemm/aligned.h b/third_party/intgemm/intgemm/aligned.h new file mode 100644 index 0000000000..6b55ff2558 --- /dev/null +++ b/third_party/intgemm/intgemm/aligned.h @@ -0,0 +1,90 @@ +#pragma once +#include <cstdlib> +#include <new> +#ifdef _MSC_VER +// Ensure _HAS_EXCEPTIONS is defined +#include <vcruntime.h> +#include <malloc.h> +#endif + +#if !((defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS)) +#include <cstdlib> +#endif + +// Aligned simple vector. + +namespace intgemm { + +template <class T> class AlignedVector { + public: + AlignedVector() : mem_(nullptr), size_(0) {} + + explicit AlignedVector(std::size_t size, std::size_t alignment = 64 /* CPU cares about this */) + : size_(size) { +#ifdef _MSC_VER + mem_ = static_cast<T*>(_aligned_malloc(size * sizeof(T), alignment)); + if (!mem_) { +# if (defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS) + throw std::bad_alloc(); +# else + std::abort(); +# endif + } +#else + if (posix_memalign(reinterpret_cast<void **>(&mem_), alignment, size * sizeof(T))) { +# if (defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS) + throw std::bad_alloc(); +# else + std::abort(); +# endif + } +#endif + } + + AlignedVector(AlignedVector &&from) : mem_(from.mem_), size_(from.size_) { + from.mem_ = nullptr; + from.size_ = 0; + } + + AlignedVector &operator=(AlignedVector &&from) { + if (this == &from) return *this; + release(); + mem_ = from.mem_; + size_ = from.size_; + from.mem_ = nullptr; + from.size_ = 0; + return *this; + } + + AlignedVector(const AlignedVector&) = delete; + AlignedVector& operator=(const AlignedVector&) = delete; + + ~AlignedVector() { release(); } + + std::size_t size() const { return size_; } + + T &operator[](std::size_t offset) { return mem_[offset]; } + const T &operator[](std::size_t offset) const { return mem_[offset]; } + + T *begin() { return mem_; } + const T *begin() const { return mem_; } + T *end() { return mem_ + size_; } + const T *end() const { return mem_ + size_; } + + template <typename ReturnType> + ReturnType *as() { return reinterpret_cast<ReturnType*>(mem_); } + + private: + T *mem_; + std::size_t size_; + + void release() { +#ifdef _MSC_VER + _aligned_free(mem_); +#else + std::free(mem_); +#endif + } +}; + +} // namespace intgemm diff --git a/third_party/intgemm/intgemm/avx2_gemm.h b/third_party/intgemm/intgemm/avx2_gemm.h new file mode 100644 index 0000000000..d93ac8ecdb --- /dev/null +++ b/third_party/intgemm/intgemm/avx2_gemm.h @@ -0,0 +1,232 @@ +#pragma once + +#include "intgemm/intgemm_config.h" + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 + +#include "interleave.h" +#include "kernels.h" +#include "multiply.h" +#include "types.h" + +#include <cstdint> +#include <cstring> + +namespace intgemm { +namespace AVX2 { + +INTGEMM_AVX2 inline Register QuantizerGrab(const float *input, const __m256 quant_mult_reg) { + return kernels::quantize(loadu_ps<FRegister>(input), quant_mult_reg); +} + +INTGEMM_SELECT_COL_B(INTGEMM_AVX2, __m256i) + +class QuantizeTile16 { + public: + INTGEMM_AVX2 static inline Register Consecutive(FRegister mult_reg, const float *input) { + return Tile(mult_reg, input, input + 8); + } + + INTGEMM_AVX2 static inline Register ConsecutiveWithWrapping(FRegister mult_reg, const float *input, Index cols_left, Index cols, Index row_step) { + return Tile(mult_reg, + input, + input + 8 + (cols_left <= 8 ? cols * (row_step - 1) : 0)); + } + + INTGEMM_AVX2 static inline Register ForReshape(FRegister mult_reg, const float *input, Index cols) { + // 8 rows in the first 128-bit register, 8 in the second register. + return Tile(mult_reg, input, input + 8 * cols); + } + + private: + INTGEMM_AVX2 static inline Register Tile(FRegister mult_reg, const float *input0, const float *input1) { + Register g0 = QuantizerGrab(input0, mult_reg); + Register g1 = QuantizerGrab(input1, mult_reg); + Register packed = _mm256_packs_epi32(g0, g1); + // Reorder the packed values because Intel does 0 1 2 3 8 9 10 11 4 5 6 7 12 13 14 15. + // Technically this could be removed if the PrepareB did the same reordering internally. + return _mm256_permute4x64_epi64(packed, 0xd8 /* 0, 2, 1, 3 */); + } +}; + +struct Kernels16 { + typedef int16_t Integer; + + // Currently A is prepared by quantization but this could theoretically change. + INTGEMM_AVX2 static inline void PrepareA(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) { + Quantize(input, output, quant_mult, rows * cols); + } + + // Just quantize everything in order. + INTGEMM_AVX2 static void Quantize(const float *input, int16_t *output, float quant_mult, Index size) { + assert(size % 16 == 0); + assert(reinterpret_cast<uintptr_t>(input) % 32 == 0); + FRegister q = set1_ps<FRegister>(quant_mult); + const float *end = input + size; + for (; input != end; input += 16, output += 16) { + *reinterpret_cast<__m256i*>(output) = QuantizeTile16::Consecutive(q, input); + } + } + + // Tile size for B; B must be a multiple of this block size. + static const Index kBTileRow = 16; + static const Index kBTileCol = 8; +/* + INTGEMM_AVX2 static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) { + PrepareBFor16(input, output, AVX2::QuantizeTile16(quant_mult), rows, cols); + }*/ + INTGEMM_PREPARE_B_16(INTGEMM_AVX2, AVX2::QuantizeTile16) + INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_AVX2, int16_t) + INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX2, AVX2::QuantizeTile16, int16_t) + + INTGEMM_AVX2 static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { + AVX2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end); + } + + INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, CPUType::AVX2) + + constexpr static const char *const kName = "16-bit AVX2"; + + static const CPUType kUses = CPUType::AVX2; +}; + +/* Read 8 floats at a time from input0, input1, input2, and input3. Quantize + * them to 8-bit by multiplying with quant_mult_reg then rounding. Concatenate + * the result into one register and return it. + */ +class QuantizeTile8 { + public: + INTGEMM_AVX2 static inline Register Consecutive(FRegister quant_mult, const float *input) { + return Tile(quant_mult, input, input + 8, input + 16, input + 24); + } + + INTGEMM_AVX2 static inline Register ConsecutiveU(FRegister quant_mult, const float *input) { + return TileU(quant_mult, input, input + 8, input + 16, input + 24); + } + + INTGEMM_AVX2 static inline Register ConsecutiveWithWrapping(FRegister quant_mult, const float *input, Index cols_left, Index cols, Index row_step) { + const float* inputs[4]; + for (Index i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) { + while (cols_left < sizeof(Register) / sizeof(float)) { + input += cols * (row_step - 1); + cols_left += cols; + } + inputs[i] = input; + input += sizeof(Register) / sizeof(float); + cols_left -= sizeof(Register) / sizeof(float); + } + return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]); + } + + INTGEMM_AVX2 static inline Register ForReshape(FRegister quant_mult, const float *input, Index cols) { + // Put higher rows in the second half of the register. These will jumble + // around in the same way then conveniently land in the right place. + return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols, input + 18 * cols); + } + + INTGEMM_AVX2 static inline __m256i Tile(FRegister quant_mult, const float *input0, const float *input1, const float *input2, const float *input3) { + // Looking at the assembly, gcc has pulled this outside the loops calling this. + const __m256i neg127 = _mm256_set1_epi8(-127); + const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + // Grab 4 registers at a time in 32-bit format. + __m256i g0 = AVX2::QuantizerGrab(input0, quant_mult); + __m256i g1 = AVX2::QuantizerGrab(input1, quant_mult); + __m256i g2 = AVX2::QuantizerGrab(input2, quant_mult); + __m256i g3 = AVX2::QuantizerGrab(input3, quant_mult); + // Pack 32-bit to 16-bit. + __m256i packed0 = _mm256_packs_epi32(g0, g1); + __m256i packed1 = _mm256_packs_epi32(g2, g3); + // Pack 16-bit to 8-bit. + __m256i packed = _mm256_packs_epi16(packed0, packed1); + // Ban -128. + packed = _mm256_max_epi8(packed, neg127); + // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31 + // Or as 32-bit integers 0 2 4 6 1 3 5 7 + // Technically this could be removed so long as the rows are bigger than 16 + // and the values are only used for GEMM. + return _mm256_permutevar8x32_epi32(packed, shuffle_param); + } + + private: + //A version that produces uint8_ts + INTGEMM_AVX2 static inline Register TileU(FRegister quant_mult, const float *input0, const float *input1, const float *input2, const float *input3) { + // Looking at the assembly, gcc has pulled this outside the loops calling this. + const __m256i neg127 = _mm256_set1_epi8(-127); + const __m256i pos127 = _mm256_set1_epi8(127); + const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + // Grab 4 registers at a time in 32-bit format. + __m256i g0 = AVX2::QuantizerGrab(input0, quant_mult); + __m256i g1 = AVX2::QuantizerGrab(input1, quant_mult); + __m256i g2 = AVX2::QuantizerGrab(input2, quant_mult); + __m256i g3 = AVX2::QuantizerGrab(input3, quant_mult); + // Pack 32-bit to 16-bit. + __m256i packed0 = _mm256_packs_epi32(g0, g1); + __m256i packed1 = _mm256_packs_epi32(g2, g3); + // Pack 16-bit to 8-bit. + __m256i packed = _mm256_packs_epi16(packed0, packed1); + // Ban -128. + packed = _mm256_max_epi8(packed, neg127); //Could be removed if we use +128 + packed = _mm256_add_epi8(packed, pos127); + // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31 + // Or as 32-bit integers 0 2 4 6 1 3 5 7 + // Technically this could be removed so long as the rows are bigger than 16 + // and the values are only used for GEMM. + return _mm256_permutevar8x32_epi32(packed, shuffle_param); + } +}; + +struct Kernels8 { + typedef int8_t Integer; + + // Currently A is prepared by quantization but this could theoretically change. + INTGEMM_AVX2 static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { + Quantize(input, output, quant_mult, rows * cols); + } + private: + INTGEMM_QUANTIZE_THREAD(INTGEMM_AVX2) + public: + INTGEMM_QUANTIZE(INTGEMM_AVX2) + + // Currently A is prepared by quantization but this could theoretically change. + INTGEMM_AVX2 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) { + QuantizeU(input, output, quant_mult, rows * cols); + } + + // Just quantize everything in order. + INTGEMM_AVX2 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) { + assert(size % 32 == 0); + assert(reinterpret_cast<uintptr_t>(input) % 32 == 0); + FRegister q = set1_ps<FRegister>(quant_mult); + const float *end = input + size; + for (; input != end; input += 32, output += 32) { + *reinterpret_cast<__m256i*>(output) = QuantizeTile8::ConsecutiveU(q, input); + } + } + + // Tile size for B; B must be a multiple of this block size. + static const Index kBTileRow = 32; + static const Index kBTileCol = 8; + + INTGEMM_PREPARE_B_8(INTGEMM_AVX2, AVX2::QuantizeTile8) + INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_AVX2, int8_t) + INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX2, AVX2::QuantizeTile8, int8_t) + + INTGEMM_AVX2 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { + AVX2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end); + } + + INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::AVX2) + + INTGEMM_MULTIPLY8SHIFT(__m256i, INTGEMM_AVX2, CPUType::AVX2) + + INTGEMM_PREPAREBIASFOR8(__m256i, INTGEMM_AVX2, CPUType::AVX2) + + constexpr static const char *const kName = "8-bit AVX2"; + + static const CPUType kUses = CPUType::AVX2; +}; + +} // namespace AVX2 +} // namespace intgemm + +#endif diff --git a/third_party/intgemm/intgemm/avx512_gemm.h b/third_party/intgemm/intgemm/avx512_gemm.h new file mode 100644 index 0000000000..90f67ee5ed --- /dev/null +++ b/third_party/intgemm/intgemm/avx512_gemm.h @@ -0,0 +1,411 @@ +#pragma once + +#include "intgemm/intgemm_config.h" + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + +#include "interleave.h" +#include "kernels.h" +#include "multiply.h" +#include "types.h" + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> + +/* AVX512 implementation. + * This uses INTGEMM_AVX512BW, INTGEMM_AVX512DQ, and might use AVX512VL + * That means it supports mainstream CPUs with AVX512, starting with Skylake + * Xeons. + * It does not support any Knights / Xeon Phi processors. + * + * All memory must be 64-byte aligned. + */ + +namespace intgemm { + +// AVX512 has combined collapse and store instructions: +// _mm512_mask_cvtsepi32_storeu_epi16 +// _mm512_mask_cvtsepi32_storeu_epi8 +// So conversion in memory uses these, but I also implement a wider version for +// rearranging B. + +namespace AVX512BW { + +// Load from memory, multiply, and convert to int32_t. +/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ +INTGEMM_AVX512BW inline __m512i QuantizerGrab(const float *input, const __m512 quant_mult_reg) { + return kernels::quantize(loadu_ps<__m512>(input), quant_mult_reg); +} + +/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ +INTGEMM_SELECT_COL_B(INTGEMM_AVX512BW, __m512i) + +// For PrepareB we want to read 8 columns at a time. When converting 32-bit +// floats to 8-bit values, that's 32 bytes of floats. But AVX512 is 64 bytes +// wide so it reads off the edge of the tile. We could expand the tile size +// but then the memory written to won't be contiguous anyway so we'd be doing a +// scatter anyway. Easier to just read the 8 columns we wanted as 256 bits +// concatenate. +INTGEMM_AVX512DQ inline __m512 Concat(const __m256 first, const __m256 second) { + // INTGEMM_AVX512DQ but that goes with INTGEMM_AVX512BW anyway. + return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1); +} + +// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be controlled independently. +/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ +INTGEMM_AVX512BW inline __m512i QuantizerGrabHalves(const float *input0, const float *input1, const __m512 quant_mult_reg) { + __m512 appended = Concat(loadu_ps<__m256>(input0), loadu_ps<__m256>(input1)); + appended = _mm512_mul_ps(appended, quant_mult_reg); + return _mm512_cvtps_epi32(appended); +} + +// These are only used for reshaping due to the AVX512 instructions +// _mm512_mask_cvtsepi32_storeu_epi16 and _mm512_mask_cvtsepi32_storeu_epi8 +// being used for the quantizer. +class QuantizeTile16 { + public: + INTGEMM_AVX512BW static inline Register ConsecutiveWithWrapping(FRegister quant_mult, const float *input, Index cols_left, Index cols, Index row_step) { + auto input0 = input; + auto input1 = input + 16 + (cols_left <= 16 ? cols * (row_step - 1) : 0); + auto g0 = QuantizerGrabHalves(input0, input1, quant_mult); + auto g1 = QuantizerGrabHalves(input0 + 8, input1 + 8, quant_mult); + auto packed = packs_epi32(g0, g1); + return _mm512_permutex_epi64(packed, 0xd8 /* 0, 2, 1, 3 */); + } + + INTGEMM_AVX512BW static inline Register ForReshape(FRegister quant_mult, const float *input, Index cols) { + __m512i g0 = QuantizerGrabHalves(input, input + 16 * cols, quant_mult); + __m512i g1 = QuantizerGrabHalves(input + 8 * cols, input + 24 * cols, quant_mult); + __m512i packed = packs_epi32(g0, g1); + // Permute within 256-bit lanes, so same as INTGEMM_AVX2 + return _mm512_permutex_epi64(packed, 0xd8 /* 0, 2, 1, 3 */); + } +}; + +class QuantizeTile8 { + public: + INTGEMM_AVX512BW static inline Register ConsecutiveWithWrapping(FRegister quant_mult, const float *input, Index cols_left, Index cols, Index row_step) { + static const __m512i neg127 = _mm512_set1_epi8(-127); + static const __m512i shuffle_param = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + + const float* inputs[4]; + for (Index i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) { + while (cols_left < sizeof(Register) / sizeof(float)) { + input += cols * (row_step - 1); + cols_left += cols; + } + inputs[i] = input; + input += sizeof(Register) / sizeof(float); + cols_left -= sizeof(Register) / sizeof(float); + } + + auto g0 = QuantizerGrab(inputs[0], quant_mult); + auto g1 = QuantizerGrab(inputs[1], quant_mult); + auto g2 = QuantizerGrab(inputs[2], quant_mult); + auto g3 = QuantizerGrab(inputs[3], quant_mult); + + auto packed0 = packs_epi32(g0, g1); + auto packed1 = packs_epi32(g2, g3); + auto packed = _mm512_packs_epi16(packed0, packed1); + packed = _mm512_max_epi8(packed, neg127); + return _mm512_permutexvar_epi32(shuffle_param, packed); + } + + INTGEMM_AVX512BW static inline __m512i ForReshape(FRegister quant_mult, const float *input, Index cols) { + // TODO: try alternative: _mm512_cvtsepi32_epi8 ? + const __m512i neg127 = _mm512_set1_epi8(-127); + // In reverse order: grabbing the first 32-bit values from each 128-bit register, then the second 32-bit values, etc. + const __m512i shuffle_param = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + + // 32-bit format. + __m512i g0 = QuantizerGrabHalves(input, input + 2 * cols, quant_mult); + __m512i g1 = QuantizerGrabHalves(input + 16 * cols, input + 18 * cols, quant_mult); + __m512i g2 = QuantizerGrabHalves(input + 32 * cols, input + 34 * cols, quant_mult); + __m512i g3 = QuantizerGrabHalves(input + 48 * cols, input + 50 * cols, quant_mult); + // Pack 32-bit to 16-bit. + __m512i packed0 = packs_epi32(g0, g1); + __m512i packed1 = packs_epi32(g2, g3); + // Pack 16-bit to 8-bit. + __m512i packed = _mm512_packs_epi16(packed0, packed1); + // Ban -128. + packed = _mm512_max_epi8(packed, neg127); + // 0 1 2 3 16 17 18 19 32 33 34 35 48 49 50 51 4 5 6 7 20 21 22 23 36 37 38 39 52 53 54 55 8 9 10 11 24 25 26 27 40 41 42 43 56 57 58 59 12 13 14 15 28 29 30 31 44 45 46 47 60 61 62 63 + return _mm512_permutexvar_epi32(shuffle_param, packed); + } +}; + +struct Kernels16 { + typedef int16_t Integer; + + // Currently A is prepared by quantization but this could theoretically change. + // rows * cols must be a multiple of 16. + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_AVX512BW static inline void PrepareA(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) { + Quantize(input, output, quant_mult, rows * cols); + } + + // Technically output can be unaligned in Quantize. + // But then it will need to be aligned for Multiply. + // size must be a multiple of 16. + // Convert to 16-bit signed integers. + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_AVX512BW static void Quantize(const float *input, int16_t *output, float quant_mult, Index size) { + assert(size % 16 == 0); + assert(reinterpret_cast<uintptr_t>(input) % 64 == 0); + // Fill with the quantization multiplier. + const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult); + const float *end = input + size; + for (; input != end; input += 16, output += 16) { + // There doesn't seem to be an unmasked version. + _mm512_mask_cvtsepi32_storeu_epi16(output, 0xffff, QuantizerGrab(input, quant_mult_reg)); + } + } + + + // Tile size for B; B must be a multiple of this block size. + static const Index kBTileRow = 32; + static const Index kBTileCol = 8; + + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_PREPARE_B_16(INTGEMM_AVX512BW, QuantizeTile16) + INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_AVX512BW, int16_t) + INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX512BW, QuantizeTile16, int16_t) + + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_AVX512BW static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { + SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end); + } + + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, CPUType::AVX2) + + constexpr static const char *const kName = "16-bit AVX512"; + + static const CPUType kUses = CPUType::AVX512BW; +}; + +struct Kernels8 { + typedef int8_t Integer; + + // Currently A is prepared by quantization but this could theoretically change. + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_AVX512BW static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { + Quantize(input, output, quant_mult, rows * cols); + } + + private: + /* g++ (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0 does not carry target attributes + * to the hidden function it creates in implementing #pragma omp parallel for. + * So intrinstics were not working inside the for loop when compiled with + * OMP. Also, passing register types across #pragma omp parallel for + * generated an internal compiler error. + * The problem does not occur in g++-8 (Ubuntu 8.3.0-6ubuntu1~18.04.1) 8.3.0. + * As a workaround, I split into #pragma omp parallel with boring types + * passed across the boundary then call this function with target attributes. + */ + INTGEMM_AVX512BW static void QuantizeThread(const float *input, int8_t *output, float quant_mult, std::size_t count) { + const __m512i neg127 = _mm512_set1_epi32(-127); + const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult); + const std::size_t kBatch = sizeof(__m512i) / sizeof(float); +#pragma omp for + for (std::size_t i = 0; i < count; i += kBatch) { + __m512i asint = QuantizerGrab(input + i, quant_mult_reg); + asint = _mm512_max_epi32(asint, neg127); + // There doesn't seem to be an unmasked version. + _mm512_mask_cvtsepi32_storeu_epi8(output + i, 0xffff, asint); + } + } + + public: + // Technically output can be unaligned in Quantize. + // But then it will need to be aligned for Multiply. + // Convert to 8-bit signed integers. + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_AVX512BW static void Quantize(const float *input, int8_t *output, float quant_mult, Index size) { + assert(reinterpret_cast<uintptr_t>(input) % sizeof(__m512i) == 0); + const std::size_t kBatch = sizeof(__m512i) / sizeof(float); + std::size_t fast_size = (size & ~(kBatch - 1)); + const float *fast_input_end = input + fast_size; + int8_t *fast_output_end = output + fast_size; +#pragma omp parallel + { + QuantizeThread(input, output, quant_mult, fast_size); + } + std::size_t overhang = size & (kBatch - 1); + if (!overhang) return; // We needed a branch anyway for the empty case. + const __m512i neg127 = _mm512_set1_epi32(-127); + const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult); + __m512i asint = QuantizerGrab(fast_input_end, quant_mult_reg); + asint = _mm512_max_epi32(asint, neg127); + _mm512_mask_cvtsepi32_storeu_epi8(fast_output_end, (1 << overhang) - 1, asint); + } + + // Preparing A for the signed/unsigned multiplication. Using add 127 + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_AVX512BW static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) { + QuantizeU(input, output, quant_mult, rows * cols); + } + + // Technically output can be unaligned in Quantize. + // But then it will need to be aligned for Multiply. + // Convert to 8-bit signed integers. + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + + INTGEMM_AVX512BW static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) { + assert(size % 16 == 0); + assert(reinterpret_cast<uintptr_t>(input) % 64 == 0); + const __m512i pos127 = _mm512_set1_epi32(127); + const __m512i zero = _mm512_setzero_si512(); + const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult); + const float *end = input + size; + for (; input < end; input += 16, output += 16) { + __m512i asint = QuantizerGrab(input, quant_mult_reg); + asint = _mm512_min_epi32(asint, pos127); + asint = _mm512_add_epi32(asint, pos127); + asint = _mm512_max_epi32(asint, zero); + _mm512_mask_cvtusepi32_storeu_epi8(output, 0xffff, asint); + } + } + + // Tile size for B; B must be a multiple of this block size. + static const Index kBTileRow = 64; + static const Index kBTileCol = 8; + + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_PREPARE_B_8(INTGEMM_AVX512BW, QuantizeTile8) + INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_AVX512BW, int8_t) + INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX512BW, QuantizeTile8, int8_t) + + /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ + INTGEMM_AVX512BW static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { + SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows, cols_begin, cols_end); + } + + // Special AVX512 implementation due to having 32 registers (so I don't have to + // allocate registers manually) and no sign instruction. + template <typename Callback> + INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + // This is copy-paste from Multiply8_SSE2OrAVX2. + assert(width % sizeof(Register) == 0); + assert(B_cols % 8 == 0); + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); + // There's 8 results for INTGEMM_AVX2 to handle. + auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); + const Index simd_width = width / sizeof(Register); + // Added for AVX512. + Register zeros = setzero_si<Register>(); + // Go over 8 columns of B at a time. +#pragma omp for + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { + const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width; + // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. + for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { + // Iterate over shared (inner) dimension. + const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width); + const Register *A_end = A_live + simd_width; + const Register *B_live = B0_col; + + // Do the first iteration to initialize the sums. + __m512i a = *A_live; + __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128)); + __m512i a_positive = _mm512_abs_epi8(a); + // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A. + Register sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0])); + Register sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1])); + Register sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2])); + Register sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3])); + Register sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4])); + Register sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5])); + Register sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6])); + Register sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7])); + + ++A_live; + B_live += 8; + + // Use A as the loop variable so the add can be done where gcc likes it + // for branch prediction. + for (; A_live != A_end; ++A_live, B_live += 8) { + // Unique code here: can we do an inline function? + // Retrieve a. We will use this as the unsigned part. + a = *A_live; + // Retrieve the conveniently consecutive values of B. + __m512i b0 = *B_live; + __m512i b1 = *(B_live + 1); + __m512i b2 = *(B_live + 2); + __m512i b3 = *(B_live + 3); + __m512i b4 = *(B_live + 4); + __m512i b5 = *(B_live + 5); + __m512i b6 = *(B_live + 6); + __m512i b7 = *(B_live + 7); + + // Get a mask where a is negative. + // Didn't seem to make a difference definining sign bits here vs at top + neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128)); + a_positive = _mm512_abs_epi8(a); + + // Negate by subtracting from zero with a mask. + b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0); + b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1); + b2 = _mm512_mask_sub_epi8(b2, neg_mask, zeros, b2); + b3 = _mm512_mask_sub_epi8(b3, neg_mask, zeros, b3); + b4 = _mm512_mask_sub_epi8(b4, neg_mask, zeros, b4); + b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5); + b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6); + b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7); + // The magic 8-bit multiply then horizontal sum into 16-bit. + b0 = _mm512_maddubs_epi16(a_positive, b0); + b1 = _mm512_maddubs_epi16(a_positive, b1); + b2 = _mm512_maddubs_epi16(a_positive, b2); + b3 = _mm512_maddubs_epi16(a_positive, b3); + b4 = _mm512_maddubs_epi16(a_positive, b4); + b5 = _mm512_maddubs_epi16(a_positive, b5); + b6 = _mm512_maddubs_epi16(a_positive, b6); + b7 = _mm512_maddubs_epi16(a_positive, b7); + // Now we have 16-bit results that are the sum of two multiplies. + // Choosing to approximate and do adds. + // Perhaps every so often we could accumulate by upcasting. + sum0 = _mm512_adds_epi16(sum0, b0); + sum1 = _mm512_adds_epi16(sum1, b1); + sum2 = _mm512_adds_epi16(sum2, b2); + sum3 = _mm512_adds_epi16(sum3, b3); + sum4 = _mm512_adds_epi16(sum4, b4); + sum5 = _mm512_adds_epi16(sum5, b5); + sum6 = _mm512_adds_epi16(sum6, b6); + sum7 = _mm512_adds_epi16(sum7, b7); + // Unique code ends: can we do an inline function? + } + // Upcast to 32-bit and horizontally add. + Register ones = set1_epi16<Register>(1); + sum0 = madd_epi16(sum0, ones); + sum1 = madd_epi16(sum1, ones); + sum2 = madd_epi16(sum2, ones); + sum3 = madd_epi16(sum3, ones); + sum4 = madd_epi16(sum4, ones); + sum5 = madd_epi16(sum5, ones); + sum6 = madd_epi16(sum6, ones); + sum7 = madd_epi16(sum7, ones); + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); + + auto total = PermuteSummer(pack0123, pack4567); + callback_impl.Run(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); + } + } + } + + INTGEMM_MULTIPLY8SHIFT(__m512i, INTGEMM_AVX512BW, CPUType::AVX2) + + INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2) + + constexpr static const char *const kName = "8-bit AVX512BW"; + + static const CPUType kUses = CPUType::AVX512BW; +}; + +} // namespace AVX512BW +} // namespace intgemm + +#endif diff --git a/third_party/intgemm/intgemm/avx512vnni_gemm.h b/third_party/intgemm/intgemm/avx512vnni_gemm.h new file mode 100644 index 0000000000..28e8c14dda --- /dev/null +++ b/third_party/intgemm/intgemm/avx512vnni_gemm.h @@ -0,0 +1,168 @@ +#pragma once + +#include "intgemm/intgemm_config.h" + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI +#include "avx512_gemm.h" +#include "types.h" + +namespace intgemm { +namespace AVX512VNNI { + +// Workaround extra vmovdqa64 https://gcc.gnu.org/bugzilla/show_bug.cgi?id=94663 +INTGEMM_AVX512VNNI static inline void VNNI8(__m512i &c, __m512i a, __m512i b) { +#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER) + asm ("vpdpbusds %2, %1, %0" : "+x"(c) : "x"(a), "mx"(b)); +#else + c = _mm512_dpbusds_epi32(c, a, b); +#endif +} + +struct Kernels8 : public AVX512BW::Kernels8 { + template <typename Callback> + INTGEMM_AVX512VNNI static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + assert(width % sizeof(Register) == 0); + assert(B_cols % 8 == 0); + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); + auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); + const Index simd_width = width / sizeof(Register); + Register zeros = setzero_si<Register>(); + // Go over 8 columns of B at a time. +#pragma omp for + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { + const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width; + // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. + for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { + // Iterate over shared (inner) dimension. + const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width); + const Register *A_end = A_live + simd_width; + const Register *B_live = B0_col; + // TODO: separate first step. + Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + for (; A_live != A_end; ++A_live, B_live += 8) { + Register a = *A_live; + // Retrieve the conveniently consecutive values of B. + Register b0 = *B_live; + Register b1 = *(B_live + 1); + Register b2 = *(B_live + 2); + Register b3 = *(B_live + 3); + Register b4 = *(B_live + 4); + Register b5 = *(B_live + 5); + Register b6 = *(B_live + 6); + Register b7 = *(B_live + 7); + // Get a mask where a is negative. + __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128)); + Register a_positive = _mm512_abs_epi8(a); + // Negate by subtracting from zero with a mask. + b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0); + b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1); + b2 = _mm512_mask_sub_epi8(b2, neg_mask, zeros, b2); + b3 = _mm512_mask_sub_epi8(b3, neg_mask, zeros, b3); + b4 = _mm512_mask_sub_epi8(b4, neg_mask, zeros, b4); + b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5); + b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6); + b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7); + VNNI8(sum0, a_positive, b0); + VNNI8(sum1, a_positive, b1); + VNNI8(sum2, a_positive, b2); + VNNI8(sum3, a_positive, b3); + VNNI8(sum4, a_positive, b4); + VNNI8(sum5, a_positive, b5); + VNNI8(sum6, a_positive, b6); + VNNI8(sum7, a_positive, b7); + } + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); + auto total = PermuteSummer(pack0123, pack4567); + callback_impl.Run(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); + } + } + } + + template <typename Callback> + INTGEMM_AVX512VNNI static void Multiply8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + assert(width % sizeof(Register) == 0); + assert(B_cols % 8 == 0); + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); + auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); + const Index simd_width = width / sizeof(Register); + Register zeros = setzero_si<Register>(); + // Go over 8 columns of B at a time. +#pragma omp for + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { + const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width; + // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. + for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { + // Iterate over shared (inner) dimension. + const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width); + const Register *A_end = A_live + simd_width; + const Register *B_live = B0_col; + // TODO: separate first step. + Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + for (; A_live != A_end; ++A_live, B_live += 8) { + Register a = *A_live; + //MultiplyAdd + VNNI8(sum0, a, *B_live); + VNNI8(sum1, a, *(B_live + 1)); + VNNI8(sum2, a, *(B_live + 2)); + VNNI8(sum3, a, *(B_live + 3)); + VNNI8(sum4, a, *(B_live + 4)); + VNNI8(sum5, a, *(B_live + 5)); + VNNI8(sum6, a, *(B_live + 6)); + VNNI8(sum7, a, *(B_live + 7)); + } + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); + auto total = PermuteSummer(pack0123, pack4567); + callback_impl.Run(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); + } + } + } + + template <typename Callback> + INTGEMM_AVX512VNNI static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) { + assert(width % sizeof(Register) == 0); + assert(B_cols % 8 == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); + auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); + Index simd_width = width / sizeof(Register); + Register zeros = setzero_si<Register>(); + const Register a = set1_epi8<Register>(1); + // Go over 8 columns of B at a time. +#pragma omp for + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { + const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width; + const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function + const Register *B_end = B_live + simd_width*8; + + // TODO: separate first step. + Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + for (; B_live != B_end; B_live += 8) { + // Retrieve the conveniently consecutive values of B. + VNNI8(sum0, a, *B_live); + VNNI8(sum1, a, *(B_live + 1)); + VNNI8(sum2, a, *(B_live + 2)); + VNNI8(sum3, a, *(B_live + 3)); + VNNI8(sum4, a, *(B_live + 4)); + VNNI8(sum5, a, *(B_live + 5)); + VNNI8(sum6, a, *(B_live + 6)); + VNNI8(sum7, a, *(B_live + 7)); + } + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); + auto total = PermuteSummer(pack0123, pack4567); + callback_impl.Run(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols)); + } + } + + constexpr static const char *const kName = "8-bit AVX512VNNI"; + + static const CPUType kUses = CPUType::AVX512VNNI; +}; + +} // namespace AVX512VNNI +} // namespace intgemm + +#endif diff --git a/third_party/intgemm/intgemm/callbacks.h b/third_party/intgemm/intgemm/callbacks.h new file mode 100644 index 0000000000..c304466111 --- /dev/null +++ b/third_party/intgemm/intgemm/callbacks.h @@ -0,0 +1,28 @@ +#pragma once + +#include "callbacks/configs.h" +#include "callbacks/output_buffer_info.h" + +#include "intgemm/intgemm_config.h" +#include "intrinsics.h" +#include "kernels.h" +#include "types.h" +#include "utils.h" +#include "vec_traits.h" + +#define CALLBACKS_THIS_IS_SSE2 +#include "callbacks/implementations.inl" +#undef CALLBACKS_THIS_IS_SSE2 + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +#define CALLBACKS_THIS_IS_AVX2 +#include "callbacks/implementations.inl" +#undef CALLBACKS_THIS_IS_AVX2 +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +#define CALLBACKS_THIS_IS_AVX512BW +#include "callbacks/implementations.inl" +#undef CALLBACKS_THIS_IS_AVX512BW +#endif + diff --git a/third_party/intgemm/intgemm/callbacks/configs.h b/third_party/intgemm/intgemm/callbacks/configs.h new file mode 100644 index 0000000000..d2fbe98de7 --- /dev/null +++ b/third_party/intgemm/intgemm/callbacks/configs.h @@ -0,0 +1,73 @@ +#pragma once + +#include <tuple> + +namespace intgemm { +namespace callbacks { + +/* + * Sequence meta-config + */ +template <typename... Configs> +std::tuple<Configs...> Sequence(const Configs&... configs) { + return std::make_tuple(configs...); +} + +/* + * Configs + */ +struct Dummy { +}; + +template <typename Type> +struct Write { + Type* output_addr; + + Write(Type* output_addr) : output_addr(output_addr) {} +}; + +struct Unquantize { + float unquant_mult; + + Unquantize(float unquant_mult) : unquant_mult(unquant_mult) {} +}; + +struct UnquantizeAndWrite { + float unquant_mult; + float* output_addr; + + UnquantizeAndWrite(float unquant_mult, float* output_addr) : unquant_mult(unquant_mult), output_addr(output_addr) {} +}; + +struct UnquantizeAndWriteRelu { + float unquant_mult; + float* output_addr; + + UnquantizeAndWriteRelu(float unquant_mult, float* output_addr) : unquant_mult(unquant_mult), output_addr(output_addr) {} +}; + +struct AddBiasAndWrite { + const int* bias_addr; + int* output_addr; + + AddBiasAndWrite(const int* bias_addr, int* output_addr) : bias_addr(bias_addr), output_addr(output_addr) {} +}; + +struct UnquantizeAndAddBiasAndWrite { + float unquant_mult; + const float* bias_addr; + float* output_addr; + + UnquantizeAndAddBiasAndWrite(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {} +}; + +struct UnquantizeAndAddBiasAndWriteRelu { + float unquant_mult; + const float* bias_addr; + float* output_addr; + + UnquantizeAndAddBiasAndWriteRelu(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {} +}; + +} +} diff --git a/third_party/intgemm/intgemm/callbacks/implementations.inl b/third_party/intgemm/intgemm/callbacks/implementations.inl new file mode 100644 index 0000000000..126701ddc3 --- /dev/null +++ b/third_party/intgemm/intgemm/callbacks/implementations.inl @@ -0,0 +1,258 @@ +/* This file is included multiple times, once per architecture. */ +#if defined(CALLBACKS_THIS_IS_SSE2) + #define CPU_NAME SSE2 + #define INTGEMM_TARGET INTGEMM_SSE2 +#elif defined(CALLBACKS_THIS_IS_AVX2) + #define CPU_NAME AVX2 + #define INTGEMM_TARGET INTGEMM_AVX2 +#elif defined(CALLBACKS_THIS_IS_AVX512BW) + #define CPU_NAME AVX512BW + #define INTGEMM_TARGET INTGEMM_AVX512BW +#else + #error "Only SSE2, AVX2 and AVX512BW are supported" +#endif + +#if defined(CALLBACKS_THIS_IS_SSE2) + #define vi vector_t<CPUType::SSE2, int> + #define vf vector_t<CPUType::SSE2, float> + #define vd vector_t<CPUType::SSE2, double> +#else + #define vi vector_t<CPUType::AVX2, int> + #define vf vector_t<CPUType::AVX2, float> + #define vd vector_t<CPUType::AVX2, double> +#endif + +/* Intel compiler 19.1.0.166 20191121 fails to link constructors with target attributes */ +#ifdef __INTEL_COMPILER +#define INTGEMM_TARGET_CONSTRUCTOR +#else +#define INTGEMM_TARGET_CONSTRUCTOR INTGEMM_TARGET +#endif + +namespace intgemm { +namespace callbacks { + +template <CPUType CpuType, typename CallbackConfig> +class CallbackImpl; + +}} + +/* + * Callbacks implementations.... + */ +namespace intgemm { +namespace callbacks { + +/* + * Sequence + */ +template <typename... Configs> +class CallbackImpl<CPUType::CPU_NAME, std::tuple<Configs...>> { +public: + explicit CallbackImpl(const std::tuple<Configs...>& configs) : callbacks(init_callbacks(configs, make_sequence<sizeof...(Configs)>())) {} + + INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) { + run_callbacks(input, info, callbacks, make_sequence<sizeof...(Configs)>()); + } + +private: + using CallbacksTupleType = std::tuple<CallbackImpl<CPUType::CPU_NAME, Configs>...>; + + CallbacksTupleType callbacks; + + template <unsigned... Indices> + CallbacksTupleType init_callbacks(const std::tuple<Configs...>& configs, sequence<Indices...>) { + return std::make_tuple(CallbackImpl<CPUType::CPU_NAME, typename std::tuple_element<Indices, std::tuple<Configs...>>::type>(std::get<Indices>(configs))...); + } + +#define RUN_CALLBACKS_PIPELINE_IMPL(vtype) \ + template <unsigned FirstIndex> \ + INTGEMM_TARGET static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex>) { \ + std::get<FirstIndex>(tuple)(input, info); \ + } \ + template <unsigned FirstIndex, unsigned SecondIndex, unsigned... RestIndices> \ + INTGEMM_TARGET static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex, SecondIndex, RestIndices...>) { \ + auto output = std::get<FirstIndex>(tuple)(input, info); \ + run_callbacks(output, info, tuple, sequence<SecondIndex, RestIndices...>()); \ + } + + RUN_CALLBACKS_PIPELINE_IMPL(vi) + RUN_CALLBACKS_PIPELINE_IMPL(vf) + RUN_CALLBACKS_PIPELINE_IMPL(vd) + +#undef RUN_CALLBACKS_PIPELINE_IMPL +}; + +/* + * Dummy + */ +template <> class CallbackImpl<CPUType::CPU_NAME, Dummy> { +public: + explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const Dummy&) {} + INTGEMM_TARGET void Run(vi, const OutputBufferInfo&) {} +}; + +/* + * Write + */ +template <typename Type> +class CallbackImpl<CPUType::CPU_NAME, Write<Type>> { +public: + explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const Write<Type>& config) : config(config) {} + + INTGEMM_TARGET void Run(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo& info) { + kernels::write(input, config.output_addr, info.row_idx * info.cols + info.col_idx); + } + +private: + Write<Type> config; +}; + +/* + * Unquantize + */ +template <> class CallbackImpl<CPUType::CPU_NAME, Unquantize> { +public: + explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const Unquantize& config) : config(config) { + unquant_mult = set1_ps<vf>(config.unquant_mult); + } + + INTGEMM_TARGET vf Run(vi input, const OutputBufferInfo&) { + return kernels::unquantize(input, unquant_mult); + } + +private: + vf unquant_mult; + Unquantize config; +}; + +/* + * UnquantizeAndWrite + */ +template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWrite> { +public: + explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndWrite& config) : config(config) { + unquant_mult = set1_ps<vf>(config.unquant_mult); + } + + INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) { + // Workaround gcc 5 internal compiler error that can't read register members in debug. + vf mult_reg; +#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER) + asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult)); +#else + mult_reg = unquant_mult; +#endif + auto result = kernels::unquantize(input, mult_reg); + kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); + } + +private: + vf unquant_mult; + UnquantizeAndWrite config; +}; + +/* + * UnquantizeAndWriteRelu + */ +template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWriteRelu> { +public: + explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndWriteRelu& config) : config(config) { + unquant_mult = set1_ps<vf>(config.unquant_mult); + } + + INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) { + // Workaround gcc 5 internal compiler error that can't read register members in debug. + vf mult_reg; +#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER) + asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult)); +#else + mult_reg = unquant_mult; +#endif + auto result = kernels::relu<float>(kernels::unquantize(input, mult_reg)); + kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); + } + +private: + vf unquant_mult; + UnquantizeAndWriteRelu config; +}; + + +/* + * AddBiasAndWrite + */ +template <> class CallbackImpl<CPUType::CPU_NAME, AddBiasAndWrite> { +public: + explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const AddBiasAndWrite& config) : config(config) {} + + INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) { + auto result = kernels::add_bias(input, config.bias_addr, info.col_idx); + kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); + } + +private: + AddBiasAndWrite config; +}; + +/* + * UnquantizeAndAddBiasAndWrite + */ +template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWrite> { +public: + explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndAddBiasAndWrite& config) : config(config) { + unquant_mult = set1_ps<vf>(config.unquant_mult); + } + + INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) { + // Workaround gcc 5 internal compiler error that can't read register members in debug. + vf mult_reg; +#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER) + asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult)); +#else + mult_reg = unquant_mult; +#endif + auto result = kernels::unquantize(input, mult_reg); + result = kernels::add_bias(result, config.bias_addr, info.col_idx); + kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); + } +private: + vf unquant_mult; + UnquantizeAndAddBiasAndWrite config; +}; + +/* + * UnquantizeAndAddBiasAndWrite + */ +template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWriteRelu> { +public: + explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndAddBiasAndWriteRelu& config) : config(config) { + unquant_mult = set1_ps<vf>(config.unquant_mult); + } + + INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) { + // Workaround gcc 5 internal compiler error that can't read register members in debug. + vf mult_reg; +#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER) + asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult)); +#else + mult_reg = unquant_mult; +#endif + auto result = kernels::unquantize(input, mult_reg); + result = kernels::add_bias(result, config.bias_addr, info.col_idx); + result = kernels::relu<float>(result); + kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); + } +private: + vf unquant_mult; + UnquantizeAndAddBiasAndWriteRelu config; +}; + +} +} + +#undef CPU_NAME +#undef INTGEMM_TARGET +#undef vi +#undef vf +#undef vd diff --git a/third_party/intgemm/intgemm/callbacks/output_buffer_info.h b/third_party/intgemm/intgemm/callbacks/output_buffer_info.h new file mode 100644 index 0000000000..213aef4afc --- /dev/null +++ b/third_party/intgemm/intgemm/callbacks/output_buffer_info.h @@ -0,0 +1,20 @@ +#pragma once + +#include "../types.h" + +namespace intgemm { +namespace callbacks { + +struct OutputBufferInfo { + Index row_idx; + Index col_idx; + + Index rows; // = A_rows + Index cols; // = B_cols + + OutputBufferInfo(Index row_idx, Index col_idx, Index rows, Index cols) + : row_idx(row_idx), col_idx(col_idx), rows(rows), cols(cols) {} +}; + +} +} diff --git a/third_party/intgemm/intgemm/interleave.h b/third_party/intgemm/intgemm/interleave.h new file mode 100644 index 0000000000..95f05cebd9 --- /dev/null +++ b/third_party/intgemm/intgemm/interleave.h @@ -0,0 +1,317 @@ +#pragma once + +#include "intgemm/intgemm_config.h" +#include "intrinsics.h" +#include "types.h" + +#include <algorithm> +#include <cassert> + +namespace intgemm { + +/* + * Interleave vectors. + */ +#define INTGEMM_INTERLEAVE_N(target, type, N) \ +target static inline void Interleave##N(type &first, type &second) { \ + type temp = unpacklo_epi##N(first, second); \ + second = unpackhi_epi##N(first, second); \ + first = temp; \ +} + +#define INTGEMM_INTERLEAVE(target, type) \ +INTGEMM_INTERLEAVE_N(target, type, 8) \ +INTGEMM_INTERLEAVE_N(target, type, 16) \ +INTGEMM_INTERLEAVE_N(target, type, 32) \ +INTGEMM_INTERLEAVE_N(target, type, 64) + +INTGEMM_INTERLEAVE(INTGEMM_SSE2, __m128i) + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +INTGEMM_INTERLEAVE(INTGEMM_AVX2, __m256i) +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +INTGEMM_INTERLEAVE(INTGEMM_AVX512BW, __m512i) +#endif + +/* + * Swap vectors. + */ +#define INTGEMM_SWAP(target, Register) \ +target static inline void Swap(Register &a, Register &b) { \ + Register tmp = a; \ + a = b; \ + b = tmp; \ +} \ + +INTGEMM_SWAP(INTGEMM_SSE2, __m128i) +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +INTGEMM_SWAP(INTGEMM_AVX2, __m256i) +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ +INTGEMM_SWAP(INTGEMM_AVX512BW, __m512i) +#endif + +/* Transpose registers containing 8 packed 16-bit integers. + * Each 128-bit lane is handled independently. + */ +#define INTGEMM_TRANSPOSE16(target, Register) \ +target static inline void Transpose16InLane(Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7) { \ + /* r0: columns 0 1 2 3 4 5 6 7 from row 0 + r1: columns 0 1 2 3 4 5 6 7 from row 1*/ \ + Interleave16(r0, r1); \ + Interleave16(r2, r3); \ + Interleave16(r4, r5); \ + Interleave16(r6, r7); \ + /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1 + r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1 + r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3 + r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3 + r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5 + r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5 + r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7 + r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/ \ + Interleave32(r0, r2); \ + Interleave32(r1, r3); \ + Interleave32(r4, r6); \ + Interleave32(r5, r7); \ + /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3 + r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3 + r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3 + r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3 + r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7 + r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7 + r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7 + r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/ \ + Interleave64(r0, r4); \ + Interleave64(r1, r5); \ + Interleave64(r2, r6); \ + Interleave64(r3, r7); \ + /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7 + r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7 + r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7 + r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7 + r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7 + r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7*/ \ + /* Empirically gcc is able to remove these movs and just rename the outputs of Interleave64. */ \ + Swap(r1, r4); \ + Swap(r3, r6); \ +} \ + +INTGEMM_TRANSPOSE16(INTGEMM_SSE2, __m128i) +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +INTGEMM_TRANSPOSE16(INTGEMM_AVX2, __m256i) +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ +INTGEMM_TRANSPOSE16(INTGEMM_AVX512BW, __m512i) +#endif + +/* Tranpose registers containing 16 packed 8-bit integers. + * Each 128-bit lane is handled independently. + */ +template <class Register> static inline void Transpose8InLane( + Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7, + Register &r8, Register &r9, Register &r10, Register &r11, Register &r12, Register &r13, Register &r14, Register &r15) { + // Get 8-bit values to 16-bit values so they can travel together. + Interleave8(r0, r1); + // r0: columns 0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 from rows 0 and 1. + // r1: columns 8 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15 from rows 0 and 1. + Interleave8(r2, r3); + // r2: columns 0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 from rows 2 and 3. + Interleave8(r4, r5); + Interleave8(r6, r7); + Interleave8(r8, r9); + Interleave8(r10, r11); + Interleave8(r12, r13); + Interleave8(r14, r15); + Transpose16InLane(r0, r2, r4, r6, r8, r10, r12, r14); + Transpose16InLane(r1, r3, r5, r7, r9, r11, r13, r15); + // Permute into correct order. This is free because the outputs just get pemuted. + Register tmp; + tmp = r2; + r2 = r4; + r4 = r8; + r8 = r1; + r1 = tmp; + tmp = r3; + r3 = r6; + r6 = r12; + r12 = r9; + r9 = tmp; + tmp = r5; + r5 = r10; + r10 = tmp; + tmp = r7; + r7 = r14; + r14 = r13; + r13 = r11; + r11 = tmp; +} + +// PREPARE B: quantize and rearrange. B is presumed to be constantparameters +// so we can take our time rearranging it in order to save during the multiply. +// +// We presume B starts in row-major order. +// +// In INTGEMM_AVX2, a register holds 32 8-bit values or 16 16-bit values and we want +// that many values from the same column in the register. +// +// The multiplier reads 8 rows at a time and we want these reads to be +// contiguous. +// +// Each 8x32 (for 8-bit) or 8x16 (for 16-bit) tile of B is transposed. +// The tiles are stored in column major order. +// +// For INTGEMM_AVX2, this matrix shows what index each value of B will be stored at: +// 0 16 ... 240 +// 1 17 ... 241 +// 2 18 ... 242 +// 3 19 ... 243 +// 4 20 ... 244 +// 5 21 ... 245 +// 6 22 ... 246 +// 7 23 ... 247 +// 8 24 ... 248 +// 9 25 ... 249 +// 10 26 ... 250 +// 11 27 ... 251 +// 12 28 ... 252 +// 13 29 ... 253 +// 14 30 ... 254 +// 15 31 ... 255 +// 256 272 +// 257 273 +// ... ... +#define INTGEMM_PREPARE_B_8(target, QuantClass) \ +target static inline void PrepareB(const float *input, int8_t *output_shadow, float quant_mult, Index rows, Index cols) { \ + FRegister q = set1_ps<FRegister>(quant_mult); \ + /* Currently all multipliers have a stride of 8 columns.*/ \ + const Index kColStride = 8; \ + assert(cols % kColStride == 0); \ + assert(rows % sizeof(Register) == 0); \ + assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \ + Register *output = reinterpret_cast<Register*>(output_shadow); \ + assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \ + for (Index c = 0; c < cols; c += kColStride) { \ + for (Index r = 0; r < rows; r += sizeof(Register), output += 8) { \ + /* Quantize and perform a transpose with height sizeof(Register) and width 8. \ + This isn't quite Transpose8InLane because it's half the number of columns, \ + so each register starts with two rows instead of being one row. \ + The quantizers know to skip a row.*/ \ + output[0] = QuantClass::ForReshape(q, input + cols * (r ) + c, cols); \ + output[1] = QuantClass::ForReshape(q, input + cols * (r + 1) + c, cols); \ + output[2] = QuantClass::ForReshape(q, input + cols * (r + 4) + c, cols); \ + output[3] = QuantClass::ForReshape(q, input + cols * (r + 5) + c, cols); \ + output[4] = QuantClass::ForReshape(q, input + cols * (r + 8) + c, cols); \ + output[5] = QuantClass::ForReshape(q, input + cols * (r + 9) + c, cols); \ + output[6] = QuantClass::ForReshape(q, input + cols * (r + 12) + c, cols); \ + output[7] = QuantClass::ForReshape(q, input + cols * (r + 13) + c, cols); \ + Interleave8(output[0], output[1]); \ + Interleave8(output[2], output[3]); \ + Interleave8(output[4], output[5]); \ + Interleave8(output[6], output[7]); \ + Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]); \ + } \ + } \ +} \ + +#define INTGEMM_PREPARE_B_16(target, QuantClass) \ +target static inline void PrepareB(const float *input, int16_t *output_shadow, float quant_mult, Index rows, Index cols) { \ + FRegister q = set1_ps<FRegister>(quant_mult); \ + assert(cols % 8 == 0); \ + assert(rows % (sizeof(Register) / sizeof(int16_t)) == 0); \ + assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \ + Register *output = reinterpret_cast<Register*>(output_shadow); \ + assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \ + for (Index c = 0; c < cols; c += 8) { \ + for (Index r = 0; r < rows; r += (sizeof(Register) / sizeof(int16_t)), output += 8) { \ + /* gcc unrolls this loop and uses registers for output[k]*/ \ + for (Index k = 0; k < 8; ++k) { \ + output[k] = QuantClass::ForReshape(q, input + cols * (r + k) + c, cols); \ + } \ + Transpose16InLane(output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7]); \ + } \ + } \ +} + +/* + * Prepare B matrix. + * B matrix has to be transposed and quantized. + * Cols has to be a multiple of sizeof(Register) / sizeof(Integer). + * + * cols and rows describe size of transposed B. + */ +#define INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(target, Integer) \ +target static inline void PrepareBQuantizedTransposed(const Integer* input, Integer* output, Index cols, Index rows) { \ + const Index RegisterElems = sizeof(Register) / sizeof(Integer); \ + const Index kColStride = 8; \ + \ + assert(cols % RegisterElems == 0); \ + assert(rows % kColStride == 0); \ + assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \ + assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \ + \ + Register* output_it = reinterpret_cast<Register*>(output); \ + for (Index r = 0; r < rows; r += kColStride) \ + for (Index c = 0; c < cols; c += RegisterElems) \ + for (Index ri = 0; ri < 8; ++ri) \ + *output_it++ = *reinterpret_cast<const Register*>(input + (r + ri) * cols + c); \ +} + +/* + * Prepare B matrix. + * B matrix has to be transposed. + * Cols has to be a multiple of sizeof(Register) / sizeof(float). + * + * cols and rows describe size of transposed B. + */ +#define INTGEMM_PREPARE_B_TRANSPOSED(target, Quantizer, Integer) \ +target static inline void PrepareBTransposed(const float* input, Integer* output, float quant_mult, Index cols, Index rows) { \ + const Index RegisterElemsInt = sizeof(Register) / sizeof(Integer); \ + const Index kColStride = 8; \ + \ + assert(cols % (sizeof(Register) / sizeof(float)) == 0); \ + assert(rows % kColStride == 0); \ + assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \ + assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \ + \ + FRegister q = set1_ps<FRegister>(quant_mult); \ + Register* output_it = reinterpret_cast<Register*>(output); \ + Index r = 0; \ + Index c = 0; \ + while (r < rows) { \ + for (Index ri = 0; ri < 8; ++ri) \ + *output_it++ = Quantizer::ConsecutiveWithWrapping(q, input + (r + ri) * cols + c, cols - c, cols, 8); \ + c += RegisterElemsInt; \ + while (c >= cols) { \ + r += kColStride; \ + c -= cols; \ + } \ + } \ +} + +/* Select columns of B from PrepareB format to PrepareB format. + */ +#define INTGEMM_SELECT_COL_B(target, Register) \ +target static inline void SelectColumnsOfB(const Register *input, Register *output, Index rows_bytes /* number of bytes in a row */, const Index *cols_begin, const Index *cols_end) { \ + assert(rows_bytes % sizeof(Register) == 0); \ + assert((cols_end - cols_begin) % 8 == 0); \ + /* Do columns for multiples of 8.*/ \ + Index register_rows = rows_bytes / sizeof(Register); \ + const Register *starts[8]; \ + for (; cols_begin != cols_end; cols_begin += 8) { \ + for (Index k = 0; k < 8; ++k) { \ + starts[k] = input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows; \ + } \ + for (Index r = 0; r < register_rows; ++r) { \ + for (Index k = 0; k < 8; ++k) { \ + *(output++) = *starts[k]; \ + starts[k] += 8; \ + } \ + } \ + } \ +} + +} // namespace intgemm diff --git a/third_party/intgemm/intgemm/intgemm.cc b/third_party/intgemm/intgemm/intgemm.cc new file mode 100644 index 0000000000..d6c26b93b4 --- /dev/null +++ b/third_party/intgemm/intgemm/intgemm.cc @@ -0,0 +1,207 @@ +#if defined(WASM) +// No header for CPUID since it's hard-coded. +#elif defined(__INTEL_COMPILER) +#include <immintrin.h> +#elif defined(_MSC_VER) +#include <intrin.h> +#else +// Assume GCC and clang style. +#include <cpuid.h> +#endif + +#include "intgemm.h" +#include "stats.h" + +#include <stdio.h> +#include <stdlib.h> + +namespace intgemm { + +namespace { + +// Return the maximum CPU model that's found and supported at compile time. +CPUType RealCPUID() { +#if defined(WASM) + // emscripten does SSE4.1 but we only use up to SSSE3. + return CPUType::SSSE3; +#elif defined(__INTEL_COMPILER) +# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI + if (_may_i_use_cpu_feature(_FEATURE_AVX512_VNNI)) return CPUType::AVX512VNNI; +# endif +# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + if (_may_i_use_cpu_feature(_FEATURE_AVX512BW)) return CPUType::AVX512BW; +# endif +# ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 + if (_may_i_use_cpu_feature(_FEATURE_AVX2)) return CPUType::AVX2; +# endif + if (_may_i_use_cpu_feature(_FEATURE_SSSE3)) return CPUType::SSSE3; + if (_may_i_use_cpu_feature(_FEATURE_SSE2)) return CPUType::SSE2; + return CPUType::UNSUPPORTED; +#else +// Not emscripten, not Intel compiler +# if defined(_MSC_VER) + int regs[4]; + int &eax = regs[0], &ebx = regs[1], &ecx = regs[2], &edx = regs[3]; + __cpuid(regs, 0); + int m = eax; +# else + /* gcc and clang. + * If intgemm is compiled by gcc 6.4.1 then dlopened into an executable + * compiled by gcc 7.3.0, there will be a undefined symbol __cpu_info. + * Work around this by calling the intrinsics more directly instead of + * __builtin_cpu_supports. + * + * clang 6.0.0-1ubuntu2 supports vnni but doesn't have + * __builtin_cpu_supports("avx512vnni") + * so use the hand-coded CPUID for clang. + */ + unsigned int m = __get_cpuid_max(0, 0); + unsigned int eax, ebx, ecx, edx; +# endif + if (m >= 7) { +# if defined(_MSC_VER) + __cpuid(regs, 7); +# else + __cpuid_count(7, 0, eax, ebx, ecx, edx); +# endif +# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI + if (ecx & (1 << 11)) return CPUType::AVX512VNNI; +# endif +# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + if (ebx & (1 << 30)) return CPUType::AVX512BW; +# endif +# ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 + if (ebx & (1 << 5)) return CPUType::AVX2; +# endif + } + if (m >= 1) { +# if defined(_MSC_VER) + __cpuid(regs, 1); +# else + __cpuid_count(1, 0, eax, ebx, ecx, edx); +# endif + if (ecx & (1 << 9)) return CPUType::SSSE3; + if (edx & (1 << 26)) return CPUType::SSE2; + } + return CPUType::UNSUPPORTED; +#endif +} + +#ifdef INTGEMM_CPUID_ENVIRONMENT +CPUType EnvironmentCPUID() { +# if defined(_MSC_VER) + char env_override[11]; + size_t len = 0; + if (getenv_s(&len, env_override, sizeof(env_override), "INTGEMM_CPUID")) return CPUType::AVX512VNNI; + if (!len) return CPUType::AVX512VNNI; +# else + const char *env_override = getenv("INTGEMM_CPUID"); + if (!env_override) return CPUType::AVX512VNNI; /* This will be capped to actual ID */ +# endif + if (!strcmp(env_override, "AVX512VNNI")) return CPUType::AVX512VNNI; + if (!strcmp(env_override, "AVX512BW")) return CPUType::AVX512BW; + if (!strcmp(env_override, "AVX2")) return CPUType::AVX2; + if (!strcmp(env_override, "SSSE3")) return CPUType::SSSE3; + if (!strcmp(env_override, "SSE2")) return CPUType::SSE2; + fprintf(stderr, "Ignoring unrecognized INTGEMM_CPUID %s\n", env_override); + return CPUType::AVX512VNNI; +} +#endif + +} // namespace + +CPUType GetCPUID() { + static const CPUType kLocalCPU = +#ifdef INTGEMM_CPUID_ENVIRONMENT + std::min(RealCPUID(), EnvironmentCPUID()); +#else + RealCPUID(); +#endif + return kLocalCPU; +} + +const CPUType kCPU = GetCPUID(); + +void UnsupportedCPUError() { +#if (defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS) + throw UnsupportedCPU(); +#else + fprintf(stderr, "intgemm does not support this CPU.\n"); + abort(); +#endif +} + +float Unsupported_MaxAbsolute(const float * /*begin*/, const float * /*end*/) { + UnsupportedCPUError(); + return 0.0f; +} + +MeanStd Unsupported_VectorMeanStd(const float * /*begin*/, const float * /*end*/, bool /*absolute*/) { + UnsupportedCPUError(); + return MeanStd(); +} + +void (*Int16::Quantize)(const float *input, int16_t *output, float quant_mult, Index size) = ChooseCPU(AVX512BW::Kernels16::Quantize, AVX512BW::Kernels16::Quantize, AVX2::Kernels16::Quantize, SSE2::Kernels16::Quantize, SSE2::Kernels16::Quantize, Unsupported_16bit::Quantize); + +void (*Int16::PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512BW::Kernels16::PrepareB, AVX512BW::Kernels16::PrepareB, AVX2::Kernels16::PrepareB, SSE2::Kernels16::PrepareB, SSE2::Kernels16::PrepareB, Unsupported_16bit::PrepareB); + +void (*Int16::PrepareBQuantizedTransposed)(const int16_t *input, int16_t *output, Index inner, Index B_untransposed_cols) = ChooseCPU(AVX512BW::Kernels16::PrepareBQuantizedTransposed, AVX512BW::Kernels16::PrepareBQuantizedTransposed, AVX2::Kernels16::PrepareBQuantizedTransposed, SSE2::Kernels16::PrepareBQuantizedTransposed, SSE2::Kernels16::PrepareBQuantizedTransposed, Unsupported_16bit::PrepareBQuantizedTransposed); + +void (*Int16::PrepareBTransposed)(const float *input, int16_t *output, float quant_mult, Index inner, Index B_untransposed_cols) = ChooseCPU(AVX512BW::Kernels16::PrepareBTransposed, AVX512BW::Kernels16::PrepareBTransposed, AVX2::Kernels16::PrepareBTransposed, SSE2::Kernels16::PrepareBTransposed, SSE2::Kernels16::PrepareBTransposed, Unsupported_16bit::PrepareBTransposed); + +void (*Int16::SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512BW::Kernels16::SelectColumnsB, AVX512BW::Kernels16::SelectColumnsB, AVX2::Kernels16::SelectColumnsB, SSE2::Kernels16::SelectColumnsB, SSE2::Kernels16::SelectColumnsB, Unsupported_16bit::SelectColumnsB); + +const char *const Int16::kName = ChooseCPU(AVX512BW::Kernels16::kName, AVX512BW::Kernels16::kName, AVX2::Kernels16::kName, SSE2::Kernels16::kName, SSE2::Kernels16::kName, Unsupported_16bit::kName); + +void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512VNNI::Kernels8::Quantize, AVX512BW::Kernels8::Quantize, AVX2::Kernels8::Quantize, SSSE3::Kernels8::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize); + +void (*Int8::QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512VNNI::Kernels8::QuantizeU, AVX512BW::Kernels8::QuantizeU, AVX2::Kernels8::QuantizeU, SSSE3::Kernels8::QuantizeU, Unsupported_8bit::QuantizeU, Unsupported_8bit::QuantizeU); + +void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512VNNI::Kernels8::PrepareB, AVX512BW::Kernels8::PrepareB, AVX2::Kernels8::PrepareB, SSSE3::Kernels8::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB); + +void (*Int8::PrepareBQuantizedTransposed)(const int8_t *input, int8_t *output, Index inner, Index B_untransposed_cols) = ChooseCPU(AVX512BW::Kernels8::PrepareBQuantizedTransposed, AVX512BW::Kernels8::PrepareBQuantizedTransposed, AVX2::Kernels8::PrepareBQuantizedTransposed, SSSE3::Kernels8::PrepareBQuantizedTransposed, Unsupported_8bit::PrepareBQuantizedTransposed, Unsupported_8bit::PrepareBQuantizedTransposed); + +void (*Int8::PrepareBTransposed)(const float *input, int8_t *output, float quant_mult, Index inner, Index B_untransposed_cols) = ChooseCPU(AVX512BW::Kernels8::PrepareBTransposed, AVX512BW::Kernels8::PrepareBTransposed, AVX2::Kernels8::PrepareBTransposed, SSSE3::Kernels8::PrepareBTransposed, Unsupported_8bit::PrepareBTransposed, Unsupported_8bit::PrepareBTransposed); + +void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512VNNI::Kernels8::SelectColumnsB, AVX512BW::Kernels8::SelectColumnsB, AVX2::Kernels8::SelectColumnsB, SSSE3::Kernels8::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB); + +const char *const Int8::kName = ChooseCPU(AVX512VNNI::Kernels8::kName, AVX512BW::Kernels8::kName, AVX2::Kernels8::kName, SSSE3::Kernels8::kName, Unsupported_8bit::kName, Unsupported_8bit::kName); + +void (*Int8Shift::QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512VNNI::Kernels8::QuantizeU, AVX512BW::Kernels8::QuantizeU, AVX2::Kernels8::QuantizeU, SSSE3::Kernels8::QuantizeU, Unsupported_8bit::QuantizeU, Unsupported_8bit::QuantizeU); + +const char *const Int8Shift::kName = ChooseCPU(AVX512VNNI::Kernels8::kName, AVX512BW::Kernels8::kName, AVX2::Kernels8::kName, SSSE3::Kernels8::kName, Unsupported_8bit::kName, Unsupported_8bit::kName); + +#if !defined(INTGEMM_COMPILER_SUPPORTS_AVX2) +namespace AVX2{ +using SSE2::MaxAbsolute; +using SSE2::VectorMeanStd; +} // namespace AVX2 +#endif +#if !defined(INTGEMM_COMPILER_SUPPORTS_AVX512BW) +namespace AVX512BW { +using AVX2::MaxAbsolute; +using AVX2::VectorMeanStd; +} // namespace AVX512BW +#endif + +float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(AVX512BW::MaxAbsolute, AVX512BW::MaxAbsolute, AVX2::MaxAbsolute, SSE2::MaxAbsolute, SSE2::MaxAbsolute, Unsupported_MaxAbsolute); + +MeanStd (*VectorMeanStd)(const float *begin, const float *end, bool absolute) = ChooseCPU(AVX512BW::VectorMeanStd, AVX512BW::VectorMeanStd, AVX2::VectorMeanStd, SSE2::VectorMeanStd, SSE2::VectorMeanStd, Unsupported_VectorMeanStd); + +constexpr const char *const Unsupported_16bit::kName; +constexpr const char *const Unsupported_8bit::kName; +constexpr const char *const SSE2::Kernels16::kName; +constexpr const char *const SSSE3::Kernels8::kName; +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +constexpr const char *const AVX2::Kernels8::kName; +constexpr const char *const AVX2::Kernels16::kName; +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +constexpr const char *const AVX512BW::Kernels8::kName; +constexpr const char *const AVX512BW::Kernels16::kName; +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI +constexpr const char *const AVX512VNNI::Kernels8::kName; +#endif + +} diff --git a/third_party/intgemm/intgemm/intgemm.h b/third_party/intgemm/intgemm/intgemm.h new file mode 100644 index 0000000000..2528fdbfe4 --- /dev/null +++ b/third_party/intgemm/intgemm/intgemm.h @@ -0,0 +1,365 @@ +#pragma once +/* Main interface for integer matrix multiplication. + * + * We are computing C = A * B with an optional scaling factor. + * + * A is typically activations. + * Rows a multiple of 1 (no restriction) + * Columns a multiple of 64 for 8-bit or 32 for 16-bit. + * Use PrepareA to prepare A for multiplication. This is meant to be fast. + * + * B is typically fixed model parameters. + * Rows a multiple of 64 for 8-bit or 32 for 16-bit. + * Columns a multiple of: 8 + * Use PrepareB to prepare B for multiplication. This is slower, with the + * intention that it will be prepared once and remembered. + * + * C is row major. + * + * Once both A and B are prepared, call Multiply. + * + * All memory (A, B, and C in float or prepared form) must be 64-byte aligned. + * It's easy to write code that works on your CPU with lower alignment, but + * breaks on AVX512. + * + * When preparing, you provide a quantization multiplier. Values will be + * multiplied by this then rounded to an integer. + * For 16-bit neural networks, Jacob Devlin recommends 1024.0. + * For 8-bit, use 127 / largest absolute value. + * + * Note that quantization saturates. However, 16-bit does accumulation in + * 32-bit which can overflow if you use too big of a multiplier. + * + * The multiply routine expects an unquantization multiplier. + * This should be unquant_mult = 1.0 / (A_quant_mult * B_quant_mult). + * Where A_quant_mult is what you passed to PrepareA and B_quant_mult is what you + * passed to PrepareB. + * + * Feel free to multiply in a scaling factor to compute C = \lambda A * B by + * passing unquant_mult = \lambda / (A_quant_mult * B_quant_mult). + */ + +#include <cstdint> + +#include "types.h" +#include "sse2_gemm.h" +#include "ssse3_gemm.h" +#include "avx2_gemm.h" +#include "avx512_gemm.h" +#include "avx512vnni_gemm.h" + +/* Dispatch to functions based on runtime CPUID. This adds one call-by-variable to each call. */ + +namespace intgemm { + +void UnsupportedCPUError(); + +struct Unsupported_16bit { + static void Quantize(const float *, int16_t *, float, Index) { + UnsupportedCPUError(); + } + static void PrepareB(const float *, int16_t *, float, Index, Index) { + UnsupportedCPUError(); + } + static void PrepareBQuantizedTransposed(const int16_t *, int16_t *, Index, Index) { + UnsupportedCPUError(); + } + static void PrepareBTransposed(const float *, int16_t *, float, Index, Index) { + UnsupportedCPUError(); + } + static void SelectColumnsB(const int16_t *, int16_t *, Index, const Index *, const Index *) { + UnsupportedCPUError(); + } + template <typename Callback> + static void Multiply(const int16_t *, const int16_t *, Index, Index, Index, Callback) { + UnsupportedCPUError(); + } + constexpr static const char *const kName = "16-bit Unsupported"; +}; + +struct Unsupported_8bit { + static void Quantize(const float *, int8_t *, float, Index) { + UnsupportedCPUError(); + } + static void QuantizeU(const float *, uint8_t *, float, Index) { + UnsupportedCPUError(); + } + static void PrepareA(const float *, int8_t *, float, Index, Index) { + UnsupportedCPUError(); + } + static void PrepareBQuantizedTransposed(const int8_t *, int8_t *, Index, Index) { + UnsupportedCPUError(); + } + static void PrepareBTransposed(const float *, int8_t *, float, Index, Index) { + UnsupportedCPUError(); + } + static void PrepareB(const float *, int8_t *, float, Index, Index) { + UnsupportedCPUError(); + } + template<class Callback> + static void PrepareBias(const int8_t *, Index, Index, Callback) { + UnsupportedCPUError(); + } + static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) { + UnsupportedCPUError(); + } + template <typename Callback> + static void Multiply(const int8_t *, const int8_t *, Index, Index, Index, Callback) { + UnsupportedCPUError(); + } + template<class Callback> + static void Multiply8Shift(const uint8_t *, const int8_t *, Index, Index, Index, Callback) { + UnsupportedCPUError(); + } + + constexpr static const char *const kName = "8-bit Unsupported"; +}; + +#ifndef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI +// These won't ever be called in this capacity, but it does let the code below compile. +namespace AVX512VNNI { +typedef Unsupported_8bit Kernels8; +} // namespace AVX512VNNI +#endif +#ifndef INTGEMM_COMPILER_SUPPORTS_AVX512BW +namespace AVX512BW { +typedef Unsupported_8bit Kernels8; +typedef Unsupported_16bit Kernels16; +} // namespace AVX512BW +#endif +#ifndef INTGEMM_COMPILER_SUPPORTS_AVX2 +namespace AVX2 { +typedef Unsupported_8bit Kernels8; +typedef Unsupported_16bit Kernels16; +} // namespace AVX2 +#endif + +CPUType GetCPUID(); + +/* Returns: + * axx512vnni if the CPU supports AVX512VNNI + * + * avx512bw if the CPU supports AVX512BW + * + * avx2 if the CPU supports AVX2 + * + * ssse3 if the CPU supports SSSE3 (this distinction from SSE2 matters for 8-bit) + * + * sse2 if the CPU supports SSE2 + * + * unsupported otherwise + */ +template <class T> T ChooseCPU(T avx512vnni, T avx512bw, T avx2, T ssse3, T sse2, T unsupported) { + const T ret[] = {unsupported, sse2, ssse3, avx2, avx512bw, avx512vnni}; + return ret[(int)GetCPUID()]; +} + +struct TileInfo { + const Index a_rows; + const Index a_cols; + const Index b_rows; + const Index b_cols; +}; + +/* + * 8-bit matrix multiplication + */ +struct Int8 { + using Integer = int8_t; + + // A's size must be a multiple of 1x64, B's size must be a multiple of 64x8. + static constexpr TileInfo tile_info{1, 64, 64, 8}; + + // Currently A is prepared by quantization but this could theoretically change. + // A's columns must be a multiple of 8. + // The number of rows is anything. + static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { + Quantize(input, output, quant_mult, rows * cols); + } + + // Multiply floats by quant_mult then convert to 8-bit integers with saturation. + static void (*Quantize)(const float *input, int8_t *output, float quant_mult, Index size); + + // Multiply floats by quant_mult then convert to 8-bit integers with saturation. + // A version that adds 127 to each number, making sure that all numbers are positive + static void (*QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size); + + // Warning: the output of PrepareB depends on the CPU. + // It will match the Multiply function on the same CPU though. + static void (*PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols); + + // Convert from a B that was already transposed (routine not provided) and + // quantized (e.g. with Quantize) to the CPU-dependent format used for + // Multiply. This is useful for storing a quantized model on disk then in a + // CPU-independent fashion. + static void (*PrepareBQuantizedTransposed)(const int8_t *input, int8_t *output, Index inner, Index B_untransposed_cols); + + // Convert from a B that was already transposed (routine not provided) to + // the CPU-dependent format used for Multiply. This is useful for storing + // a quantized model on disk then in a CPU-independent fashion. + static void (*PrepareBTransposed)(const float *input, int8_t *output, float quant_mul, Index inner, Index B_untransposed_cols); + + // Select columns from a prepared B matrix. The number of selected columns must be a multiple of 8. + static void (*SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end); + + // Multiply C = A * B, presuming A and B have been prepared. + template <typename Callback> + static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + MultiplyImpl<Callback>::run(A, B, A_rows, width, B_cols, callback); + } + + static const char *const kName; + +private: + template <typename Callback> + struct MultiplyImpl { + static void (*run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback); + }; +}; + +template <typename Callback> +void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI::Kernels8>, OMPParallelWrap<Callback, AVX512BW::Kernels8>, OMPParallelWrap<Callback, AVX2::Kernels8>, OMPParallelWrap<Callback, SSSE3::Kernels8>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>); + +/* + * 8-bit matrix multiplication with shifting A by 127 + */ +struct Int8Shift { + using Integer = int8_t; + + // A's size must be a multiple of 1x64, B's size must be a multiple of 64x8. + static constexpr TileInfo tile_info{1, 64, 64, 8}; + + // Identical to the Int8 Version, except it adds 127 to each number, making sure that all numbers are positive. + static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { + QuantizeU(input, reinterpret_cast<uint8_t *>(output), quant_mult, rows * cols); + } + + // Multiply floats by quant_mult then convert to 8-bit integers with saturation. + // A version that adds 127 to each number, making sure that all numbers are positive + static void (*QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size); + + // Warning: the output of PrepareB depends on the CPU. + // It will match the Multiply function on the same CPU though. + static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { + Int8::PrepareB(input, output, quant_mult, rows, cols); + } + + // Select columns from a prepared B matrix. The number of selected columns must be a multiple of 8. + static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { + Int8::SelectColumnsB(input, output, rows, cols_begin, cols_end); + } + + // A slightly faster version compared to the Int8 one (assuming a bias is used) because of better handling of the sign bit + // Multiply C = A * B + Bias, presuming A, B and Bias have all been prepared (for A, PrepareAnew should be used + template<class Callback> + static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + MultiplyImpl<Callback>::run((const uint8_t *)A, B, A_rows, width, B_cols, callback); + } + + // This function prepares the bias for the Multiply routine that does unsigned * signed multiplication. + // The function takes: + // a preparedB matrix, width, B_cols and + // the callback UnquantizeAndAddBiasAndWrite(unquant_mult, Bias_matrix, Bias_matrix) + // unquant_mult is computed by (-1)*(alpha)*(alpha)/(127.0f); + template<class Callback> + static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) { + PrepareBiasImpl<Callback>::run(B, width, B_cols, callback); + } + + static const char *const kName; + +private: + template <typename Callback> + struct MultiplyImpl { + static void (*run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback); + }; + + template <typename Callback> + struct PrepareBiasImpl { + static void (*run)(const int8_t *B, Index width, Index B_cols, Callback callback); + }; +}; + +template <class Callback> +void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU( + OMPParallelWrap8Shift<Callback, AVX512VNNI::Kernels8>, + OMPParallelWrap8Shift<Callback, AVX512BW::Kernels8>, + OMPParallelWrap8Shift<Callback, AVX2::Kernels8>, + OMPParallelWrap8Shift<Callback, SSSE3::Kernels8>, + Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>); + +template <class Callback> +void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI::Kernels8::PrepareBias<Callback>, AVX512BW::Kernels8::PrepareBias<Callback>, AVX2::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, Unsupported_8bit::PrepareBias); + +/* + * 16-bit matrix multiplication + */ +struct Int16 { + using Integer = int16_t; + + // A's size must be a multiple of 1x32, B's size must be a multiple of 32x8. + static constexpr TileInfo tile_info{1, 32, 32, 8}; + + // Currently A is prepared by quantization but this could theoretically change. + // A's columns must be a multiple of 8. + // The number of rows is anything. + static inline void PrepareA(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) { + Quantize(input, output, quant_mult, rows * cols); + } + + // Multiply floats by quant_mult then convert to 16-bit integers with saturation. + // input + static void (*Quantize)(const float *input, int16_t *output, float quant_mult, Index size); + + // Warning: the output of PrepareB depends on the CPU. + // It will match the Multiply function on the same CPU though. + static void (*PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols); + + // Convert from a B that was already transposed (routine not provided) and + // quantized (e.g. with Quantize) to the CPU-dependent format used for + // Multiply. This is useful for storing a quantized model on disk then in a + // CPU-independent fashion. + static void (*PrepareBQuantizedTransposed)(const int16_t *input, int16_t *output, Index inner, Index B_untransposed_cols); + + // Convert from a B that was already transposed (routine not provided) to + // the CPU-dependent format used for Multiply. This is useful for storing + // a quantized model on disk then in a CPU-independent fashion. + static void (*PrepareBTransposed)(const float *input, int16_t *output, float quant_mul, Index inner, Index B_untransposed_cols); + + // Select columns from a prepared B matrix. The number of selected columns must be a multiple of 8. + static void (*SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end); + + // Multiply C = A * B, presuming A and B have been prepared. + template <typename Callback> + static void Multiply(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + MultiplyImpl<Callback>::run(A, B, A_rows, width, B_cols, callback); + } + + static const char *const kName; + +private: + template <typename Callback> + struct MultiplyImpl { + static void (*run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback); + }; +}; + +template <typename Callback> +void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512BW::Kernels16> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512BW::Kernels16>, OMPParallelWrap<Callback, AVX2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, Unsupported_16bit::Multiply<Callback>); + +extern const CPUType kCPU; + +// Get the maximum absolute value of an array of floats. The number of floats must be a multiple of 16 and 64-byte aligned. +extern float (*MaxAbsolute)(const float *begin, const float *end); + +// Get a Quantization value that is equant to the mean of the data +N standard deviations. Use 2 by default +extern MeanStd (*VectorMeanStd)(const float *begin, const float *end, bool); + +/* Returns the Mean and the Standard deviation of a vector. + * If "absolute" is set to true, it computes the mean and the standard deviation of the absolute values of the vector */ +static inline MeanStd GetVectorMeanStd(const float * begin, const float * end, bool absolute=false) { + return VectorMeanStd(begin, end, absolute); +} + + +} // namespace intgemm diff --git a/third_party/intgemm/intgemm/intgemm_config.h.in b/third_party/intgemm/intgemm/intgemm_config.h.in new file mode 100644 index 0000000000..a2c8cbd347 --- /dev/null +++ b/third_party/intgemm/intgemm/intgemm_config.h.in @@ -0,0 +1,5 @@ +#pragma once + +#cmakedefine INTGEMM_COMPILER_SUPPORTS_AVX2 +#cmakedefine INTGEMM_COMPILER_SUPPORTS_AVX512BW +#cmakedefine INTGEMM_COMPILER_SUPPORTS_AVX512VNNI diff --git a/third_party/intgemm/intgemm/intrinsics.h b/third_party/intgemm/intgemm/intrinsics.h new file mode 100644 index 0000000000..9f370cd719 --- /dev/null +++ b/third_party/intgemm/intgemm/intrinsics.h @@ -0,0 +1,611 @@ +#pragma once + +#include "intgemm/intgemm_config.h" +#include "types.h" + +#include <tmmintrin.h> +#include <emmintrin.h> +#include <xmmintrin.h> +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +#include <immintrin.h> +#endif +#ifdef INTGEMM_WORMHOLE +#include <wasm_simd128.h> +#endif + +#include <cstdint> + +/* + * NOTE: Please keep intrinsics in alphabetical order. + */ +namespace intgemm { + +/* + * Define a bunch of intrinstics as overloaded functions so they work with + * templates. + */ +template <class Register> static inline Register load_ps(float const* from); +template <class Register> static inline Register loadu_ps(const float* mem_addr); +template <class Register> static inline Register set1_epi16(int16_t to); +template <class Register> static inline Register set1_epi32(int32_t to); +template <class Register> static inline Register set1_epi8(int8_t to); +template <class Register> static inline Register set1_pd(double to); +template <class Register> static inline Register set1_ps(float to); +template <class Register> static inline Register setzero_pd(); +template <class Register> static inline Register setzero_ps(); +template <class Register> static inline Register setzero_si(); + +/* + * + * SSE2 + * + */ +INTGEMM_SSSE3 static inline __m128i abs_epi8(__m128i arg) { + return _mm_abs_epi8(arg); +} +INTGEMM_SSE2 static inline __m128i add_epi8(__m128i a, __m128i b) { + return _mm_add_epi8(a, b); +} +INTGEMM_SSE2 static inline __m128i add_epi16(__m128i a, __m128i b) { + return _mm_add_epi16(a, b); +} +INTGEMM_SSE2 static inline __m128i add_epi32(__m128i first, __m128i second) { + return _mm_add_epi32(first, second); +} +INTGEMM_SSE2 static inline __m128i adds_epi16(__m128i first, __m128i second) { + return _mm_adds_epi16(first, second); +} +INTGEMM_SSE2 static inline __m128d add_pd(__m128d a, __m128d b) { + return _mm_add_pd(a, b); +} +INTGEMM_SSE2 static inline __m128 add_ps(__m128 a, __m128 b) { + return _mm_add_ps(a, b); +} +INTGEMM_SSE2 static inline __m128 and_ps(__m128 first, __m128 second) { + return _mm_and_ps(first, second); +} +INTGEMM_SSE2 static inline __m128 andnot_ps(__m128 a, __m128 b) { + return _mm_andnot_ps(a, b); +} +INTGEMM_SSE2 static inline __m128i and_si(__m128i a, __m128i b) { + return _mm_and_si128(a, b); +} +INTGEMM_SSE2 static inline __m128 cast_ps(__m128i a) { + return _mm_castsi128_ps(a); +} +INTGEMM_SSE2 static inline __m128 cvtepi32_ps(__m128i arg) { + return _mm_cvtepi32_ps(arg); +} +INTGEMM_SSE2 static inline __m128i cvtps_epi32(__m128 arg) { + return _mm_cvtps_epi32(arg); +} +INTGEMM_SSE2 static inline __m128i cvttps_epi32(__m128 a) { + return _mm_cvttps_epi32(a); +} +INTGEMM_SSE2 static inline __m128 div_ps(__m128 a, __m128 b) { + return _mm_div_ps(a, b); +} +/* + * Missing i32gather_ps for SSE2 + */ +template <> INTGEMM_SSE2 inline __m128 load_ps<__m128>(const float* from) { + return _mm_load_ps(from); +} +template <> INTGEMM_SSE2 inline __m128 loadu_ps(const float* mem_addr) { + return _mm_loadu_ps(mem_addr); +} +INTGEMM_SSE2 static inline __m128i madd_epi16(__m128i first, __m128i second) { +// https://bugzilla.mozilla.org/show_bug.cgi?id=1672160 +#ifdef INTGEMM_WORMHOLE + return wasm_v8x16_shuffle(first, second, 31, 0, 30, 2, 29, 4, 28, 6, 27, 8, 26, 10, 25, 12, 24, 2 /* PMADDWD */); +#else + return _mm_madd_epi16(first, second); +#endif +} +INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second) { +// https://bugzilla.mozilla.org/show_bug.cgi?id=1672160 +#ifdef INTGEMM_WORMHOLE + return wasm_v8x16_shuffle(first, second, 31, 0, 30, 2, 29, 4, 28, 6, 27, 8, 26, 10, 25, 12, 24, 1 /* PMADDUBSW */); +#else + return _mm_maddubs_epi16(first, second); +#endif +} +/* + * Missing max_epi8 for SSE2 + */ +INTGEMM_SSE2 static inline __m128i max_epi16(__m128i first, __m128i second) { + return _mm_max_epi16(first, second); +} +INTGEMM_SSE2 static inline __m128d max_pd(__m128d first, __m128d second) { + return _mm_max_pd(first, second); +} +INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) { + return _mm_max_ps(first, second); +} +INTGEMM_SSE2 static inline __m128 min_ps(__m128 a, __m128 b) { + return _mm_min_ps(a, b); +} +INTGEMM_SSE2 static inline __m128i mul_epu32(__m128i a, __m128i b) { + return _mm_mul_epu32(a, b); +} +INTGEMM_SSE2 static inline __m128d mul_pd(__m128d a, __m128d b) { + return _mm_mul_pd(a, b); +} +INTGEMM_SSE2 static inline __m128 mul_ps(__m128 a, __m128 b) { + return _mm_mul_ps(a, b); +} +INTGEMM_SSE2 static inline __m128i mulhi_epi16(__m128i a, __m128i b) { + return _mm_mulhi_epi16(a, b); +} +INTGEMM_SSE2 static inline __m128i mullo_epi16(__m128i a, __m128i b) { + return _mm_mullo_epi16(a, b); +} +INTGEMM_SSE2 static inline __m128i or_si(__m128i a, __m128i b) { + return _mm_or_si128(a, b); +} +INTGEMM_SSE2 static inline __m128i packs_epi16(__m128i a, __m128i b) { + return _mm_packs_epi16(a, b); +} +INTGEMM_SSE2 static inline __m128i packs_epi32(__m128i a, __m128i b) { + return _mm_packs_epi32(a, b); +} +template <> INTGEMM_SSE2 inline __m128i set1_epi8<__m128i>(int8_t to) { + return _mm_set1_epi8(to); +} +template <> INTGEMM_SSE2 inline __m128i set1_epi16<__m128i>(int16_t to) { + return _mm_set1_epi16(to); +} +template <> INTGEMM_SSE2 inline __m128i set1_epi32<__m128i>(int32_t to) { + return _mm_set1_epi32(to); +} +template <> INTGEMM_SSE2 inline __m128d set1_pd<__m128d>(double to) { + return _mm_set1_pd(to); +} +template <> INTGEMM_SSE2 inline __m128 set1_ps<__m128>(float to) { + return _mm_set1_ps(to); +} +template <> INTGEMM_SSE2 inline __m128d setzero_pd<__m128d>() { + return _mm_setzero_pd(); +} +template <> INTGEMM_SSE2 inline __m128 setzero_ps<__m128>() { + return _mm_setzero_ps(); +} +template <> INTGEMM_SSE2 inline __m128i setzero_si<__m128i>() { + return _mm_setzero_si128(); +} +INTGEMM_SSSE3 static inline __m128i sign_epi8(__m128i first, __m128i second) { + return _mm_sign_epi8(first, second); +} +template <int imm8> INTGEMM_SSE2 static inline __m128i slli_epi16(__m128i a) { + return _mm_slli_epi16(a, imm8); +} +template <int imm8> INTGEMM_SSE2 static inline __m128i srai_epi16(__m128i a) { + return _mm_srai_epi16(a, imm8); +} +template <int imm8> INTGEMM_SSE2 static inline __m128i srai_epi32(__m128i a) { + return _mm_srai_epi32(a, imm8); +} +template <int imm8> INTGEMM_SSE2 static inline __m128i srli_epi16(__m128i a) { + return _mm_srli_epi16(a, imm8); +} +INTGEMM_SSE2 static inline void storeu_ps(float* mem_addr, __m128 a) { + _mm_storeu_ps(mem_addr, a); +} +INTGEMM_SSE2 static inline __m128d sub_pd(__m128d a, __m128d b) { + return _mm_sub_pd(a, b); +} +INTGEMM_SSE2 static inline __m128 sub_ps(__m128 a, __m128 b) { + return _mm_sub_ps(a, b); +} +INTGEMM_SSE2 static inline __m128i unpacklo_epi8(__m128i a, __m128i b) { + return _mm_unpacklo_epi8(a, b); +} +INTGEMM_SSE2 static inline __m128i unpackhi_epi8(__m128i a, __m128i b) { + return _mm_unpackhi_epi8(a, b); +} +INTGEMM_SSE2 static inline __m128i unpacklo_epi16(__m128i a, __m128i b) { + return _mm_unpacklo_epi16(a, b); +} +INTGEMM_SSE2 static inline __m128i unpackhi_epi16(__m128i a, __m128i b) { + return _mm_unpackhi_epi16(a, b); +} +INTGEMM_SSE2 static inline __m128i unpacklo_epi32(__m128i a, __m128i b) { + return _mm_unpacklo_epi32(a, b); +} +INTGEMM_SSE2 static inline __m128i unpackhi_epi32(__m128i a, __m128i b) { + return _mm_unpackhi_epi32(a, b); +} +INTGEMM_SSE2 static inline __m128i unpacklo_epi64(__m128i a, __m128i b) { + return _mm_unpacklo_epi64(a, b); +} +INTGEMM_SSE2 static inline __m128i unpackhi_epi64(__m128i a, __m128i b) { + return _mm_unpackhi_epi64(a, b); +} +INTGEMM_SSE2 static inline __m128i xor_si(__m128i a, __m128i b) { + return _mm_xor_si128(a, b); +} + +/* + * + * AVX2 + * + */ + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +INTGEMM_AVX2 static inline __m256i abs_epi8(__m256i arg) { + return _mm256_abs_epi8(arg); +} +INTGEMM_AVX2 static inline __m256i add_epi8(__m256i a, __m256i b) { + return _mm256_add_epi8(a, b); +} +INTGEMM_AVX2 static inline __m256i add_epi16(__m256i a, __m256i b) { + return _mm256_add_epi16(a, b); +} +INTGEMM_AVX2 static inline __m256i add_epi32(__m256i first, __m256i second) { + return _mm256_add_epi32(first, second); +} +INTGEMM_AVX2 static inline __m256i adds_epi16(__m256i first, __m256i second) { + return _mm256_adds_epi16(first, second); +} +INTGEMM_AVX2 static inline __m256d add_pd(__m256d a, __m256d b) { + return _mm256_add_pd(a, b); +} +INTGEMM_AVX2 static inline __m256 add_ps(__m256 a, __m256 b) { + return _mm256_add_ps(a, b); +} +INTGEMM_AVX2 static inline __m256 and_ps(__m256 first, __m256 second) { + return _mm256_and_ps(first, second); +} +INTGEMM_AVX2 static inline __m256 andnot_ps(__m256 a, __m256 b) { + return _mm256_andnot_ps(a, b); +} +INTGEMM_AVX2 static inline __m256i and_si(__m256i a, __m256i b) { + return _mm256_and_si256(a, b); +} +INTGEMM_AVX2 static inline __m256 cast_ps(__m256i a) { + return _mm256_castsi256_ps(a); +} +INTGEMM_AVX2 static inline __m256 cvtepi32_ps(__m256i arg) { + return _mm256_cvtepi32_ps(arg); +} +INTGEMM_AVX2 static inline __m256i cvtps_epi32(__m256 arg) { + return _mm256_cvtps_epi32(arg); +} +INTGEMM_AVX2 static inline __m256i cvttps_epi32(__m256 a) { + return _mm256_cvttps_epi32(a); +} +INTGEMM_AVX2 static inline __m256 div_ps(__m256 a, __m256 b) { + return _mm256_div_ps(a, b); +} +template <unsigned Scale> +INTGEMM_AVX2 static inline __m256 i32gather_ps(float const *base_addr, __m256i vindex) { + return _mm256_i32gather_ps(base_addr, vindex, Scale); +} +template <> INTGEMM_AVX2 inline __m256 loadu_ps(const float* mem_addr) { + return _mm256_loadu_ps(mem_addr); +} +template <> INTGEMM_AVX2 inline __m256 load_ps<__m256>(const float* from) { + return _mm256_load_ps(from); +} +INTGEMM_AVX2 static inline __m256i madd_epi16(__m256i first, __m256i second) { + return _mm256_madd_epi16(first, second); +} +INTGEMM_AVX2 static inline __m256i maddubs_epi16(__m256i first, __m256i second) { + return _mm256_maddubs_epi16(first, second); +} +INTGEMM_AVX2 static inline __m256i max_epi8(__m256i first, __m256i second) { + return _mm256_max_epi8(first, second); +} +INTGEMM_AVX2 static inline __m256i max_epi16(__m256i first, __m256i second) { + return _mm256_max_epi16(first, second); +} +INTGEMM_AVX2 static inline __m256d max_pd(__m256d first, __m256d second) { + return _mm256_max_pd(first, second); +} +INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) { + return _mm256_max_ps(first, second); +} +INTGEMM_AVX2 static inline __m256 min_ps(__m256 a, __m256 b) { + return _mm256_min_ps(a, b); +} +INTGEMM_AVX2 static inline __m256i mul_epu32(__m256i a, __m256i b) { + return _mm256_mul_epu32(a, b); +} +INTGEMM_AVX2 static inline __m256d mul_pd(__m256d a, __m256d b) { + return _mm256_mul_pd(a, b); +} +INTGEMM_AVX2 static inline __m256 mul_ps(__m256 a, __m256 b) { + return _mm256_mul_ps(a, b); +} +INTGEMM_AVX2 static inline __m256i mulhi_epi16(__m256i a, __m256i b) { + return _mm256_mulhi_epi16(a, b); +} +INTGEMM_AVX2 static inline __m256i mullo_epi16(__m256i a, __m256i b) { + return _mm256_mullo_epi16(a, b); +} +INTGEMM_AVX2 static inline __m256i or_si(__m256i a, __m256i b) { + return _mm256_or_si256(a, b); +} +INTGEMM_AVX2 static inline __m256i packs_epi16(__m256i a, __m256i b) { + return _mm256_packs_epi16(a, b); +} +INTGEMM_AVX2 static inline __m256i packs_epi32(__m256i a, __m256i b) { + return _mm256_packs_epi32(a, b); +} +template <> INTGEMM_AVX2 inline __m256i set1_epi8<__m256i>(int8_t to) { + return _mm256_set1_epi8(to); +} +template <> INTGEMM_AVX2 inline __m256i set1_epi16<__m256i>(int16_t to) { + return _mm256_set1_epi16(to); +} +template <> INTGEMM_AVX2 inline __m256i set1_epi32<__m256i>(int32_t to) { + return _mm256_set1_epi32(to); +} +template <> INTGEMM_AVX2 inline __m256d set1_pd<__m256d>(double to) { + return _mm256_set1_pd(to); +} +template <> INTGEMM_AVX2 inline __m256 set1_ps<__m256>(float to) { + return _mm256_set1_ps(to); +} +template <> INTGEMM_AVX2 inline __m256d setzero_pd<__m256d>() { + return _mm256_setzero_pd(); +} +template <> INTGEMM_AVX2 inline __m256 setzero_ps<__m256>() { + return _mm256_setzero_ps(); +} +template <> INTGEMM_AVX2 inline __m256i setzero_si<__m256i>() { + return _mm256_setzero_si256(); +} +INTGEMM_AVX2 static inline __m256i sign_epi8(__m256i first, __m256i second) { + return _mm256_sign_epi8(first, second); +} +template <int imm8> INTGEMM_AVX2 static inline __m256i slli_epi16(__m256i a) { + return _mm256_slli_epi16(a, imm8); +} +template <int imm8> INTGEMM_AVX2 static inline __m256i srai_epi16(__m256i a) { + return _mm256_srai_epi16(a, imm8); +} +template <int imm8> INTGEMM_AVX2 static inline __m256i srai_epi32(__m256i a) { + return _mm256_srai_epi32(a, imm8); +} +template <int imm8> INTGEMM_AVX2 static inline __m256i srli_epi16(__m256i a) { + return _mm256_srli_epi16(a, imm8); +} +INTGEMM_AVX2 static inline void storeu_ps(float* mem_addr, __m256 a) { + _mm256_storeu_ps(mem_addr, a); +} +INTGEMM_AVX2 static inline __m256d sub_pd(__m256d a, __m256d b) { + return _mm256_sub_pd(a, b); +} +INTGEMM_AVX2 static inline __m256 sub_ps(__m256 a, __m256 b) { + return _mm256_sub_ps(a, b); +} +INTGEMM_AVX2 static inline __m256i unpacklo_epi8(__m256i a, __m256i b) { + return _mm256_unpacklo_epi8(a, b); +} +INTGEMM_AVX2 static inline __m256i unpackhi_epi8(__m256i a, __m256i b) { + return _mm256_unpackhi_epi8(a, b); +} +INTGEMM_AVX2 static inline __m256i unpacklo_epi16(__m256i a, __m256i b) { + return _mm256_unpacklo_epi16(a, b); +} +INTGEMM_AVX2 static inline __m256i unpackhi_epi16(__m256i a, __m256i b) { + return _mm256_unpackhi_epi16(a, b); +} +INTGEMM_AVX2 static inline __m256i unpacklo_epi32(__m256i a, __m256i b) { + return _mm256_unpacklo_epi32(a, b); +} +INTGEMM_AVX2 static inline __m256i unpackhi_epi32(__m256i a, __m256i b) { + return _mm256_unpackhi_epi32(a, b); +} +INTGEMM_AVX2 static inline __m256i unpacklo_epi64(__m256i a, __m256i b) { + return _mm256_unpacklo_epi64(a, b); +} +INTGEMM_AVX2 static inline __m256i unpackhi_epi64(__m256i a, __m256i b) { + return _mm256_unpackhi_epi64(a, b); +} +INTGEMM_AVX2 static inline __m256i xor_si(__m256i a, __m256i b) { + return _mm256_xor_si256(a, b); +} +#endif + +/* + * + * AVX512 + * + */ +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + +INTGEMM_AVX512BW static inline __m512i abs_epi8(__m512i arg) { + return _mm512_abs_epi8(arg); +} +INTGEMM_AVX512BW static inline __m512i add_epi8(__m512i a, __m512i b) { + return _mm512_add_epi8(a, b); +} +INTGEMM_AVX512BW static inline __m512i add_epi16(__m512i a, __m512i b) { + return _mm512_add_epi16(a, b); +} +INTGEMM_AVX512BW static inline __m512i add_epi32(__m512i first, __m512i second) { + return _mm512_add_epi32(first, second); +} +INTGEMM_AVX512BW static inline __m512i adds_epi16(__m512i first, __m512i second) { + return _mm512_adds_epi16(first, second); +} +INTGEMM_AVX512BW static inline __m512d add_pd(__m512d a, __m512d b) { + return _mm512_add_pd(a, b); +} +INTGEMM_AVX512BW static inline __m512 add_ps(__m512 a, __m512 b) { + return _mm512_add_ps(a, b); +} +INTGEMM_AVX512DQ static inline __m512 and_ps(__m512 first, __m512 second) { + return _mm512_and_ps(first, second); +} +INTGEMM_AVX512DQ static inline __m512 andnot_ps(__m512 a, __m512 b) { + return _mm512_andnot_ps(a, b); +} +INTGEMM_AVX512BW static inline __m512i and_si(__m512i a, __m512i b) { + return _mm512_and_si512(a, b); +} +INTGEMM_AVX512F static inline __m512 cast_ps(__m512i a) { + return _mm512_castsi512_ps(a); +} +INTGEMM_AVX512BW static inline __m512 cvtepi32_ps(__m512i arg) { + return _mm512_cvtepi32_ps(arg); +} +INTGEMM_AVX512BW static inline __m512i cvtps_epi32(__m512 arg) { + return _mm512_cvtps_epi32(arg); +} +INTGEMM_AVX512BW static inline __m512i cvttps_epi32(__m512 a) { + return _mm512_cvttps_epi32(a); +} +INTGEMM_AVX512BW static inline __m512 div_ps(__m512 a, __m512 b) { + return _mm512_div_ps(a, b); +} +template <unsigned Scale> +INTGEMM_AVX512BW static inline __m512 i32gather_ps(float const *base_addr, __m512i vindex) { + return _mm512_i32gather_ps(vindex, base_addr, Scale); +} +template <> INTGEMM_AVX512BW inline __m512 loadu_ps(const float* mem_addr) { + return _mm512_loadu_ps(mem_addr); +} +INTGEMM_AVX512BW static inline __m512i madd_epi16(__m512i first, __m512i second) { + return _mm512_madd_epi16(first, second); +} +INTGEMM_AVX512BW static inline __m512i maddubs_epi16(__m512i first, __m512i second) { + return _mm512_maddubs_epi16(first, second); +} +INTGEMM_AVX512BW static inline __m512i max_epi8(__m512i first, __m512i second) { + return _mm512_max_epi8(first, second); +} +INTGEMM_AVX512BW static inline __m512i max_epi16(__m512i first, __m512i second) { + return _mm512_max_epi16(first, second); +} +INTGEMM_AVX512BW static inline __m512d max_pd(__m512d first, __m512d second) { + return _mm512_max_pd(first, second); +} +INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) { + return _mm512_max_ps(first, second); +} +INTGEMM_AVX512BW static inline __m512 min_ps(__m512 a, __m512 b) { + return _mm512_min_ps(a, b); +} +INTGEMM_AVX512BW static inline __m512i mul_epu32(__m512i a, __m512i b) { + return _mm512_mul_epu32(a, b); +} +INTGEMM_AVX512BW static inline __m512d mul_pd(__m512d a, __m512d b) { + return _mm512_mul_pd(a, b); +} +INTGEMM_AVX512BW static inline __m512 mul_ps(__m512 a, __m512 b) { + return _mm512_mul_ps(a, b); +} +INTGEMM_AVX512BW static inline __m512i mulhi_epi16(__m512i a, __m512i b) { + return _mm512_mulhi_epi16(a, b); +} +INTGEMM_AVX512BW static inline __m512i mullo_epi16(__m512i a, __m512i b) { + return _mm512_mullo_epi16(a, b); +} +INTGEMM_AVX512BW static inline __m512i or_si(__m512i a, __m512i b) { + return _mm512_or_si512(a, b); +} +INTGEMM_AVX512BW static inline __m512i packs_epi16(__m512i a, __m512i b) { + return _mm512_packs_epi16(a, b); +} +/* g++ (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609 has a bug: + * /usr/lib/gcc/x86_64-linux-gnu/5/include/avx512bwintrin.h is missing + * _mm512_packs_epi32 when compiled with debugging. + */ +#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && (__GNUC_MINOR__ == 4) +INTGEMM_AVX512BW static inline __attribute__ ((__gnu_inline__, __always_inline__, __artificial__)) __m512i packs_epi32(__m512i a, __m512i b) { + return reinterpret_cast<__m512i>(__builtin_ia32_packssdw512_mask( + reinterpret_cast<__v16si>(a), + reinterpret_cast<__v16si>(b), + reinterpret_cast<__v32hi>(_mm512_setzero_si512()), + 0xffffffff)); +} +#else +INTGEMM_AVX512BW static inline __m512i packs_epi32(__m512i a, __m512i b) { + return _mm512_packs_epi32(a, b); +} +#endif +template <> inline INTGEMM_AVX512BW __m512i set1_epi8<__m512i>(int8_t to) { + return _mm512_set1_epi8(to); +} +template <> inline INTGEMM_AVX512BW __m512i set1_epi16<__m512i>(int16_t to) { + return _mm512_set1_epi16(to); +} +template <> inline INTGEMM_AVX512BW __m512i set1_epi32<__m512i>(int32_t to) { + return _mm512_set1_epi32(to); +} +template <> inline INTGEMM_AVX512BW __m512d set1_pd<__m512d>(double to) { + return _mm512_set1_pd(to); +} +template <> inline INTGEMM_AVX512BW __m512 set1_ps<__m512>(float to) { + return _mm512_set1_ps(to); +} +template <> INTGEMM_AVX512BW inline __m512d setzero_pd<__m512d>() { + return _mm512_setzero_pd(); +} +template <> INTGEMM_AVX512BW inline __m512 setzero_ps<__m512>() { + return _mm512_setzero_ps(); +} +template <> INTGEMM_AVX512BW inline __m512i setzero_si<__m512i>() { + return _mm512_setzero_si512(); +} +template <> INTGEMM_AVX512BW inline __m512 load_ps<__m512>(const float* from) { + return _mm512_load_ps(from); +} +/* + * Missing sign_epi8 + */ +template <int imm8> INTGEMM_AVX512BW static inline __m512i slli_epi16(__m512i a) { + return _mm512_slli_epi16(a, imm8); +} +template <int imm8> INTGEMM_AVX512BW static inline __m512i srai_epi16(__m512i a) { + return _mm512_srai_epi16(a, imm8); +} +template <int imm8> INTGEMM_AVX512BW static inline __m512i srai_epi32(__m512i a) { + return _mm512_srai_epi32(a, imm8); +} +template <int imm8> INTGEMM_AVX512BW static inline __m512i srli_epi16(__m512i a) { + return _mm512_srli_epi16(a, imm8); +} +INTGEMM_AVX512BW static inline void storeu_ps(float* mem_addr, __m512 a) { + _mm512_storeu_ps(mem_addr, a); +} +INTGEMM_AVX512BW static inline __m512d sub_pd(__m512d a, __m512d b) { + return _mm512_sub_pd(a, b); +} +INTGEMM_AVX512BW static inline __m512 sub_ps(__m512 a, __m512 b) { + return _mm512_sub_ps(a, b); +} +INTGEMM_AVX512BW static inline __m512i unpacklo_epi8(__m512i a, __m512i b) { + return _mm512_unpacklo_epi8(a, b); +} +INTGEMM_AVX512BW static inline __m512i unpackhi_epi8(__m512i a, __m512i b) { + return _mm512_unpackhi_epi8(a, b); +} +INTGEMM_AVX512BW static inline __m512i unpacklo_epi16(__m512i a, __m512i b) { + return _mm512_unpacklo_epi16(a, b); +} +INTGEMM_AVX512BW static inline __m512i unpackhi_epi16(__m512i a, __m512i b) { + return _mm512_unpackhi_epi16(a, b); +} +INTGEMM_AVX512BW static inline __m512i unpacklo_epi32(__m512i a, __m512i b) { + return _mm512_unpacklo_epi32(a, b); +} +INTGEMM_AVX512BW static inline __m512i unpackhi_epi32(__m512i a, __m512i b) { + return _mm512_unpackhi_epi32(a, b); +} +INTGEMM_AVX512BW static inline __m512i unpacklo_epi64(__m512i a, __m512i b) { + return _mm512_unpacklo_epi64(a, b); +} +INTGEMM_AVX512BW static inline __m512i unpackhi_epi64(__m512i a, __m512i b) { + return _mm512_unpackhi_epi64(a, b); +} +INTGEMM_AVX512BW static inline __m512i xor_si(__m512i a, __m512i b) { + return _mm512_xor_si512(a, b); +} + +#endif + +} diff --git a/third_party/intgemm/intgemm/kernels.h b/third_party/intgemm/intgemm/kernels.h new file mode 100644 index 0000000000..57036f4d31 --- /dev/null +++ b/third_party/intgemm/intgemm/kernels.h @@ -0,0 +1,26 @@ +#pragma once + +#include "intgemm/intgemm_config.h" +#include "intrinsics.h" +#include "types.h" +#include "utils.h" +#include "vec_traits.h" + +#include <cstdlib> + +#define KERNELS_THIS_IS_SSE2 +#include "kernels/implementations.inl" +#undef KERNELS_THIS_IS_SSE2 + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +#define KERNELS_THIS_IS_AVX2 +#include "kernels/implementations.inl" +#undef KERNELS_THIS_IS_AVX2 +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +#define KERNELS_THIS_IS_AVX512BW +#include "kernels/implementations.inl" +#undef KERNELS_THIS_IS_AVX512BW +#endif + diff --git a/third_party/intgemm/intgemm/kernels/implementations.inl b/third_party/intgemm/intgemm/kernels/implementations.inl new file mode 100644 index 0000000000..4f1b39fb39 --- /dev/null +++ b/third_party/intgemm/intgemm/kernels/implementations.inl @@ -0,0 +1,456 @@ +/* This file is included multiple times, once for each backend instruction set. */ + +#if defined(KERNELS_THIS_IS_SSE2) + #define CPU_NAME SSE2 + #define CPU_ATTR INTGEMM_SSE2 +#elif defined(KERNELS_THIS_IS_AVX2) + #define CPU_NAME AVX2 + #define CPU_ATTR INTGEMM_AVX2 +#elif defined(KERNELS_THIS_IS_AVX512BW) + #define CPU_NAME AVX512BW + #define CPU_ATTR INTGEMM_AVX512BW +#else + #error "Only SSE2, AVX2 and AVX512BW are supported" +#endif + +#define vi vector_t<CPUType::CPU_NAME, int> +#define vf vector_t<CPUType::CPU_NAME, float> +#define vd vector_t<CPUType::CPU_NAME, double> + +/* + * Kernels implementations.... + */ +namespace intgemm { +namespace kernels { + +/* + * Write + */ +CPU_ATTR static inline void write(vi input, int8_t* output, Index offset) { + *reinterpret_cast<vi*>(output + offset) = input; +} + +CPU_ATTR static inline void write(vi input, int16_t* output, Index offset) { + *reinterpret_cast<vi*>(output + offset) = input; +} + +CPU_ATTR static inline void write(vi input, int* output, Index offset) { + *reinterpret_cast<vi*>(output + offset) = input; +} + +CPU_ATTR static inline void write(vf input, float* output, Index offset) { + *reinterpret_cast<vf*>(output + offset) = input; +} + +CPU_ATTR static inline void write(vd input, double* output, Index offset) { + *reinterpret_cast<vd*>(output + offset) = input; +} + +/* + * Quantize + */ +CPU_ATTR static inline vi quantize(vf input, vf quant_mult) { + return cvtps_epi32(mul_ps(input, quant_mult)); +} + +/* + * Unquantize + */ +CPU_ATTR static inline vf unquantize(vi input, vf unquant_mult) { + return mul_ps(cvtepi32_ps(input), unquant_mult); +} + +/* + * Add a bias term + */ +CPU_ATTR static inline vi add_bias(vi input, const int8_t* bias_addr, Index bias_offset) { + auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset); + return add_epi8(input, bias_term); +} + +CPU_ATTR static inline vi add_bias(vi input, const int16_t* bias_addr, Index bias_offset) { + auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset); + return add_epi16(input, bias_term); +} + +CPU_ATTR static inline vi add_bias(vi input, const int* bias_addr, Index bias_offset) { + auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset); + return add_epi32(input, bias_term); +} + +CPU_ATTR static inline vf add_bias(vf input, const float* bias_addr, Index bias_offset) { + auto bias_term = *reinterpret_cast<const vf*>(bias_addr + bias_offset); + return add_ps(input, bias_term); +} + +CPU_ATTR static inline vd add_bias(vd input, const double* bias_addr, Index bias_offset) { + auto bias_term = *reinterpret_cast<const vd*>(bias_addr + bias_offset); + return add_pd(input, bias_term); +} + +/* + * ReLU + */ +template <typename Type> +CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> relu(vector_t<CPUType::CPU_NAME, Type> input); + +template <> +CPU_ATTR inline vi relu<int8_t>(vi input) { + static const auto vconst_zero = set1_epi8<vi>(0); +#if defined(KERNELS_THIS_IS_SSE2) + return and_si(input, _mm_cmplt_epi8(vconst_zero, input)); +#elif defined(KERNELS_THIS_IS_AVX2) + return _mm256_max_epi8(input, vconst_zero); +#else + return _mm512_max_epi8(input, vconst_zero); +#endif +} + +template <> +CPU_ATTR inline vi relu<int16_t>(vi input) { + static const auto vconst_zero = set1_epi16<vi>(0); + return max_epi16(input, vconst_zero); +} + +template <> +CPU_ATTR inline vi relu<int>(vi input) { + static const auto vconst_zero = set1_epi32<vi>(0); +#if defined(KERNELS_THIS_IS_SSE2) + return and_si(input, _mm_cmplt_epi32(vconst_zero, input)); +#elif defined(KERNELS_THIS_IS_AVX2) + return _mm256_max_epi32(input, vconst_zero); +#else + return _mm512_max_epi32(input, vconst_zero); +#endif +} + +template <> +CPU_ATTR inline vf relu<float>(vf input) { + static const auto vconst_zero = setzero_ps<vf>(); + return max_ps(input, vconst_zero); +} + +template <> +CPU_ATTR inline vd relu<double>(vd input) { + static const auto vconst_zero = setzero_pd<vd>(); + return max_pd(input, vconst_zero); +} + +/* + * Multiply (elemwise) + */ +template <typename Type> +CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply(vector_t<CPUType::CPU_NAME, Type> a, vector_t<CPUType::CPU_NAME, Type> b); + +template <> +CPU_ATTR inline vi multiply<int8_t>(vi a, vi b) { + auto even = mullo_epi16(a, b); + auto odd = mullo_epi16(srli_epi16<8>(a), srli_epi16<8>(b)); + return or_si(slli_epi16<8>(odd), srli_epi16<8>(slli_epi16<8>(even))); +} + +template <> +CPU_ATTR inline vi multiply<int16_t>(vi a, vi b) { + return mullo_epi16(a, b); +} + +template <> +CPU_ATTR inline vi multiply<int>(vi a, vi b) { +#if defined(KERNELS_THIS_IS_SSE2) + auto even = mul_epu32(a, b); + auto odd = mul_epu32(_mm_srli_si128(a, 4), _mm_srli_si128(b, 4)); + return unpacklo_epi32(_mm_shuffle_epi32(even, 0x8 /* = 0 0 2 0 */), _mm_shuffle_epi32(odd, 0x8 /* = 0 0 2 0 */)); +#elif defined(KERNELS_THIS_IS_AVX2) + return _mm256_mullo_epi32(a, b); +#else + return _mm512_mullo_epi32(a, b); +#endif +} + +template <> +CPU_ATTR inline vf multiply<float>(vf a, vf b) { + return mul_ps(a, b); +} + +template <> +CPU_ATTR inline vd multiply<double>(vd a, vd b) { + return mul_pd(a, b); +} + +/* + * Downcast + */ +CPU_ATTR static inline vi downcast32to8(vi input1, vi input2, vi input3, vi input4) { + auto result = packs_epi16(packs_epi32(input1, input2), packs_epi32(input3, input4)); + +#if defined(KERNELS_THIS_IS_SSE2) + return result; +#elif defined(KERNELS_THIS_IS_AVX2) + return _mm256_shuffle_epi32(_mm256_permute4x64_epi64(result, 0xd8 /* = 0 2 1 3 */), 0xd8 /* = 0 2 1 3 */); +#else + static const auto permutation_indices = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + return _mm512_castps_si512(_mm512_permutexvar_ps(permutation_indices, _mm512_castsi512_ps(result))); +#endif +} + +CPU_ATTR static inline vi downcast32to16(vi input1, vi input2) { + auto result = packs_epi32(input1, input2); + +#if defined(KERNELS_THIS_IS_SSE2) + return result; +#elif defined(KERNELS_THIS_IS_AVX2) + return _mm256_permute4x64_epi64(result, 0xd8 /* = 0 2 1 3 */); +#else + static const auto permutation_indices = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(result))); +#endif +} + +CPU_ATTR static inline vi downcast16to8(vi input1, vi input2) { + auto result = packs_epi16(input1, input2); + +#if defined(KERNELS_THIS_IS_SSE2) + return result; +#elif defined(KERNELS_THIS_IS_AVX2) + return _mm256_permute4x64_epi64(result, 0xd8 /* = 0 2 1 3 */); +#else + static const auto permutation_indices = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(result))); +#endif +} + +/* + * Upcast + */ +CPU_ATTR static inline dvector_t<CPUType::CPU_NAME, int16_t> upcast8to16(vi input) { + static const auto vzero = set1_epi8<vi>(0); + +#if defined(KERNELS_THIS_IS_SSE2) + auto higher_byte = _mm_cmpgt_epi8(vzero, input); +#elif defined(KERNELS_THIS_IS_AVX2) + input = _mm256_permute4x64_epi64(input, 0xd8 /* = 0 2 1 3 */); + auto higher_byte = _mm256_cmpgt_epi8(vzero, input); +#else + static const auto vmax_negative = set1_epi8<vi>(-1 /* 0xff */); + static const auto permutation_indices = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + + input = _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(input))); + auto negatives = _mm512_cmp_epi8_mask(input, vzero, 1 /* _MM_CMPINT_LT */); + auto higher_byte = _mm512_mask_blend_epi8(negatives, vzero, vmax_negative); +#endif + + return { + unpacklo_epi8(input, higher_byte), + unpackhi_epi8(input, higher_byte), + }; +} + +CPU_ATTR static inline dvector_t<CPUType::CPU_NAME, int> upcast16to32(vi input) { + static const auto vzero = set1_epi16<vi>(0); + +#if defined(KERNELS_THIS_IS_SSE2) + auto higher_byte = _mm_cmpgt_epi16(vzero, input); +#elif defined(KERNELS_THIS_IS_AVX2) + input = _mm256_permute4x64_epi64(input, 0xd8 /* = 0 2 1 3 */); + auto higher_byte = _mm256_cmpgt_epi16(vzero, input); +#else + static const auto vmax_negative = set1_epi16<vi>(-1 /* 0xffff */); + static const auto permutation_indices = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + + input = _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(input))); + auto negatives = _mm512_cmp_epi16_mask(input, vzero, 1 /* _MM_CMPINT_LT */); + auto higher_byte = _mm512_mask_blend_epi16(negatives, vzero, vmax_negative); +#endif + + return { + unpacklo_epi16(input, higher_byte), + unpackhi_epi16(input, higher_byte), + }; +} + +CPU_ATTR static inline qvector_t<CPUType::CPU_NAME, int> upcast8to32(vi input) { + auto result16 = upcast8to16(input); + auto result32a = upcast16to32(result16.first); + auto result32b = upcast16to32(result16.second); + + return { + result32a.first, + result32a.second, + result32b.first, + result32b.second, + }; +} + +/* + * Rescale int32 + */ +CPU_ATTR static inline vi rescale(vi input, vf scale) { + return cvtps_epi32(mul_ps(cvtepi32_ps(input), scale)); +} + +/* + * Bitwise not + */ +CPU_ATTR static inline vi bitwise_not(vi v) { + return xor_si(v, set1_epi32<vi>(0xffffffff)); +} + +/* + * Floor + */ +CPU_ATTR static inline vf floor(vf input) { +#if defined(KERNELS_THIS_IS_SSE2) + static const auto vconst_zero = setzero_ps<vf>(); + static const auto vconst_one = set1_ps<vf>(1.f); + + auto result = cvtepi32_ps(cvttps_epi32(input)); + auto negatives = _mm_cmplt_ps(input, vconst_zero); + auto nonintegers = _mm_cmpneq_ps(input, result); + + return sub_ps(result, and_ps(vconst_one, and_ps(negatives, nonintegers))); +#elif defined(KERNELS_THIS_IS_AVX2) + return _mm256_floor_ps(input); +#else + // TODO: It should work but compiler throw the error "incorrect rounding operand" + // return _mm512_roundscale_round_ps(input, 0, _MM_FROUND_FLOOR); + + static const auto vconst_zero = setzero_ps<vf>(); + static const auto vconst_one = set1_ps<vf>(1.f); + + auto result = cvtepi32_ps(cvttps_epi32(input)); + auto negatives = _mm512_cmp_ps_mask(input, vconst_zero, _CMP_LT_OQ); + auto nonintegers = _mm512_cmp_ps_mask(input, result, _CMP_NEQ_OQ); + + return _mm512_mask_blend_ps(_mm512_kand(negatives, nonintegers), result, sub_ps(result, vconst_one)); +#endif +} + +/* + * Calculate approximation of e^x using Taylor series and lookup table + */ +#if defined(KERNELS_THIS_IS_SSE2) +CPU_ATTR static inline vf exp_approx_taylor(vf) { + std::abort(); +} +#else +CPU_ATTR static inline vf exp_approx_taylor(vf x) { + static constexpr int EXP_MIN = -20; + static constexpr int EXP_MAX = 20; + static constexpr float EXP_LOOKUP[EXP_MAX - EXP_MIN + 1] = { + expif(-20), expif(-19), expif(-18), expif(-17), expif(-16), expif(-15), + expif(-14), expif(-13), expif(-12), expif(-11), expif(-10), expif(-9), + expif(-8), expif(-7), expif(-6), expif(-5), expif(-4), expif(-3), expif(-2), + expif(-1), expif(0), expif(1), expif(2), expif(3), expif(4), expif(5), + expif(6), expif(7), expif(8), expif(9), expif(10), expif(11), expif(12), + expif(13), expif(14), expif(15), expif(16), expif(17), expif(18), expif(19), + expif(20), + }; + + static const vf dividers[] = { + set1_ps<vf>(1.f / factorial(7)), + set1_ps<vf>(1.f / factorial(6)), + set1_ps<vf>(1.f / factorial(5)), + set1_ps<vf>(1.f / factorial(4)), + set1_ps<vf>(1.f / factorial(3)), + set1_ps<vf>(1.f / factorial(2)), + set1_ps<vf>(1.f / factorial(1)), + }; + static const auto const_one = set1_ps<vf>(1.f); + static const auto const_min_x = set1_ps<vf>(EXP_MIN); + static const auto const_max_x = set1_ps<vf>(EXP_MAX); + + x = max_ps(x, const_min_x); + x = min_ps(x, const_max_x); + + auto a = floor(x); + auto xa = sub_ps(x, a); + + auto result = mul_ps(dividers[0], xa); + + result = add_ps(result, dividers[1]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[2]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[3]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[4]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[5]); + result = mul_ps(result, xa); + result = add_ps(result, dividers[6]); + result = mul_ps(result, xa); + + result = add_ps(result, const_one); + + auto ea = i32gather_ps<4>(EXP_LOOKUP + EXP_MAX, cvtps_epi32(a)); + return mul_ps(ea, result); +} +#endif + +/* + * Sigmoid + */ +CPU_ATTR static inline vf sigmoid(vf +#ifndef KERNELS_THIS_IS_SSE2 + input +#endif + ) { +#if defined(KERNELS_THIS_IS_SSE2) + std::abort(); // TODO: missing exp_approx_taylor for SSE2 +#elif defined(KERNELS_THIS_IS_AVX2) + static const auto vconst_zero = setzero_ps<vf>(); + static const auto vconst_one = set1_ps<vf>(1.f); + + auto x = input; + auto minus_x = sub_ps(vconst_zero, x); + auto e_x = exp_approx_taylor(x); + auto e_minus_x = exp_approx_taylor(minus_x); + + auto sigmoid_case1 = _mm256_rcp_ps(add_ps(vconst_one, e_minus_x)); + auto sigmoid_case2 = mul_ps(e_x, _mm256_rcp_ps(add_ps(vconst_one, e_x))); + + auto nonnegative_x_mask = _mm256_cmp_ps(vconst_zero, x, _CMP_LT_OS); + return _mm256_blendv_ps(sigmoid_case1, sigmoid_case2, nonnegative_x_mask); +#else + static const auto vconst_zero = setzero_ps<vf>(); + static const auto vconst_one = set1_ps<vf>(1.f); + + auto x = input; + auto minus_x = sub_ps(vconst_zero, x); + auto e_x = exp_approx_taylor(x); + auto e_minus_x = exp_approx_taylor(minus_x); + + auto sigmoid_case1 = _mm512_rcp14_ps(add_ps(vconst_one, e_minus_x)); + auto sigmoid_case2 = mul_ps(e_x, _mm512_rcp14_ps(add_ps(vconst_one, e_x))); + + auto nonnegative_x_mask = _mm512_cmp_ps_mask(vconst_zero, x, _CMP_LT_OS); + return _mm512_mask_blend_ps(nonnegative_x_mask, sigmoid_case1, sigmoid_case2); +#endif +} + +/* + * Tanh + */ +#if defined(KERNELS_THIS_IS_SSE2) +CPU_ATTR static inline vf tanh(vf) { + std::abort(); // TODO: missing exp_approx_taylor for SSE2 +} +#else +CPU_ATTR static inline vf tanh(vf input) { + const static auto vconst_zero = setzero_ps<vf>(); + + auto e_x = exp_approx_taylor(input); + auto e_minus_x = exp_approx_taylor(sub_ps(vconst_zero, input)); + + return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x)); +} +#endif + +} +} + +#undef CPU_NAME +#undef CPU_ATTR +#undef vi +#undef vf +#undef vd diff --git a/third_party/intgemm/intgemm/multiply.h b/third_party/intgemm/intgemm/multiply.h new file mode 100644 index 0000000000..8d411f33da --- /dev/null +++ b/third_party/intgemm/intgemm/multiply.h @@ -0,0 +1,626 @@ +#pragma once + +#include "intgemm/intgemm_config.h" +#include "interleave.h" +#include "intrinsics.h" +#include "vec_traits.h" +#include "callbacks.h" + +namespace intgemm { + +INTGEMM_SSE2 static inline dvector_t<CPUType::SSE2, int> PermuteSummer(__m128i pack0123, __m128i pack4567) { + // No op for 128 bits: already reduced fully. + return { pack0123, pack4567 }; +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +INTGEMM_AVX2 static inline __m256i PermuteSummer(__m256i pack0123, __m256i pack4567) { + // This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f + __m256i rev = _mm256_permute2f128_si256(pack0123, pack4567, 0x21); + // This instruction generates 1f 2f 3f 4f 5s 6s 7s 8s + __m256i blended = _mm256_blend_epi32(pack0123, pack4567, 0xf0); + return _mm256_add_epi32(rev, blended); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ +INTGEMM_AVX512BW static inline __m256i PermuteSummer(__m512i pack0123, __m512i pack4567) { + // Form [0th 128-bit register of pack0123, 0st 128-bit register of pack4567, 2nd 128-bit register of pack0123, 2nd 128-bit register of pack4567] + __m512i mix0 = _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567, (0 << 4) | (1 << 6)); + // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567, 3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567] + __m512i mix1 = _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2)); + __m512i added = _mm512_add_epi32(mix0, mix1); + // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7. + // Fold register over itself. + return _mm256_add_epi32(_mm512_castsi512_si256(added), _mm512_extracti64x4_epi64(added, 1)); +} +#endif + +#ifdef _MSC_VER +#define INTGEMM_OMP_FOR __pragma(omp for) +#define INTGEMM_OMP_PARALLEL __pragma(omp parallel) +#else +#define INTGEMM_OMP_FOR _Pragma("omp for") +#define INTGEMM_OMP_PARALLEL _Pragma("omp parallel") +#endif + +// Quantize function used for SSSE3 and AVX2. +// Separate function for thread to work around gcc 7 bug that doesn't imbue +// target attributes across #pragma omp parallel. +#define INTGEMM_QUANTIZE_THREAD(target) \ +target static void QuantizeThread(const float *input, int8_t *output, float quant_mult, std::size_t count) { \ + FRegister q = set1_ps<FRegister>(quant_mult); \ + INTGEMM_OMP_FOR \ + for (std::size_t i = 0; i < count; i += sizeof(Register)) { \ + *reinterpret_cast<Register*>(output + i) = QuantizeTile8::Consecutive(q, input + i); \ + } \ +} + +#define INTGEMM_QUANTIZE(target) \ +target static void Quantize(const float *const input, int8_t *const output, float quant_mult, Index size) { \ + assert(reinterpret_cast<uintptr_t>(input) % sizeof(Register) == 0); \ + assert(reinterpret_cast<uintptr_t>(output) % sizeof(Register) == 0); \ + const std::size_t kBatch = sizeof(Register); \ + const std::size_t fast_end = size & ~(kBatch - 1); \ + INTGEMM_OMP_PARALLEL \ + { \ + QuantizeThread(input, output, quant_mult, fast_end); \ + } \ + std::size_t overhang = size & (kBatch - 1); \ + if (!overhang) return; \ + FRegister q = set1_ps<FRegister>(quant_mult); \ + /* Each does size(Register) / 32 == kBatch / 4 floats at a time. + * If we're allowed to read one of them, then we can read the whole register. */ \ + const float *inputs[4]; \ + std::size_t i; \ + for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) { \ + inputs[i] = &input[fast_end + i * (kBatch / 4)]; \ + } \ + /* These will be clipped off. */ \ + for (; i < 4; ++i) { \ + inputs[i] = &input[fast_end]; \ + } \ + Register result = QuantizeTile8::Tile(q, inputs[0], inputs[1], inputs[2], inputs[3]); \ + std::memcpy(output + (size & ~(kBatch - 1)), &result, overhang); \ +} + +/* Take 4 registers with 32-bit values to be horizontally added. Reduce them + * to one register with 32-bit values in the pattern 1 2 3 4 1 2 3 4, leaving + * the final addition (which crosses 128-bit lanes) to the caller. + */ +#define INTGEMM_PACK0123(target, Register) \ +target inline Register Pack0123(Register sum0, Register sum1, Register sum2, Register sum3) { \ + Interleave32(sum0, sum1); \ + Register pack01 = add_epi32(sum0, sum1); \ + Interleave32(sum2, sum3); \ + Register pack23 = add_epi32(sum2, sum3); \ + Interleave64(pack01, pack23); \ + return add_epi32(pack01, pack23); \ +} \ + +INTGEMM_PACK0123(INTGEMM_SSE2, __m128i) +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +INTGEMM_PACK0123(INTGEMM_AVX2, __m256i) +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ +INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i) +#endif + +template <typename Callback> +INTGEMM_SSE2 static inline void RunCallback(Callback& callback_impl, dvector_t<CPUType::SSE2, int> total, Index row_idx, Index col_idx, Index rows, Index cols) { + callback_impl.Run(total.first, callbacks::OutputBufferInfo(row_idx, col_idx, rows, cols)); + callback_impl.Run(total.second, callbacks::OutputBufferInfo(row_idx, col_idx + 4, rows, cols)); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template <typename Callback> +INTGEMM_AVX2 static inline void RunCallback(Callback& callback_impl, vector_t<CPUType::AVX2, int> total, Index row_idx, Index col_idx, Index rows, Index cols) { + callback_impl.Run(total, callbacks::OutputBufferInfo(row_idx, col_idx, rows, cols)); +} +#endif + +// 16-bit multiplier for INTGEMM_SSE2, INTGEMM_AVX2, and AVX512. +// C = A * B * unquant_mult +// +// This has been substantially revised from Jacob Devlin's SSE code which is: +// Copyright (c) 2017 Microsoft Corporation + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// A is a row-major quantized matrix (from PrepareA) +// B is a rearranged quantized matrix (from PrepareB) +// C is output in row-major form. +// +// All of A, B, and C must be in aligned to a multiple of the register size: +// INTGEMM_SSE2: 16 bytes +// INTGEMM_AVX2: 32 bytes +// AVX512: 64 bytes. +// +// A_rows can be anything non-negative. +// width must be a multiple of the register size. +// B_cols must be a multiple of 8. +// Multiply16 +#define INTGEMM_MULTIPLY16(Register, target, cpu_type) \ +template <typename Callback> target static void Multiply(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ + assert(width % (sizeof(Register) / sizeof(int16_t)) == 0); \ + assert(B_cols % 8 == 0); \ + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); \ + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \ + const Index simd_width = width / (sizeof(Register) / sizeof(int16_t)); \ + auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ + INTGEMM_OMP_FOR \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \ + const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \ + /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ + for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ + const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \ + /* These will be packed 32-bit integers containing sums for each row of B multiplied by the row of A. \ + Iterate over shared (inner) dimension.*/ \ + Index k = 0; \ + Register a = *(A_row + k); \ + Register sum0 = madd_epi16(a, *(B0_col + k * 8)); \ + Register sum1 = madd_epi16(a, *(B0_col + k * 8 + 1)); \ + Register sum2 = madd_epi16(a, *(B0_col + k * 8 + 2)); \ + Register sum3 = madd_epi16(a, *(B0_col + k * 8 + 3)); \ + Register sum4 = madd_epi16(a, *(B0_col + k * 8 + 4)); \ + Register sum5 = madd_epi16(a, *(B0_col + k * 8 + 5)); \ + Register sum6 = madd_epi16(a, *(B0_col + k * 8 + 6)); \ + Register sum7 = madd_epi16(a, *(B0_col + k * 8 + 7)); \ + for (k = 1; k < simd_width; ++k) { \ + a = *(A_row + k); \ + /* Multiply 16-bit, horizontally add to packed 32-bit integers.*/ \ + Register mult0 = madd_epi16(a, *(B0_col + k * 8)); \ + Register mult1 = madd_epi16(a, *(B0_col + k * 8 + 1)); \ + Register mult2 = madd_epi16(a, *(B0_col + k * 8 + 2)); \ + Register mult3 = madd_epi16(a, *(B0_col + k * 8 + 3)); \ + Register mult4 = madd_epi16(a, *(B0_col + k * 8 + 4)); \ + Register mult5 = madd_epi16(a, *(B0_col + k * 8 + 5)); \ + Register mult6 = madd_epi16(a, *(B0_col + k * 8 + 6)); \ + Register mult7 = madd_epi16(a, *(B0_col + k * 8 + 7)); \ + /* Sum packed 32-bit integers with danger of overflow. TODO: accumulate in 64-bit every so often.*/ \ + sum0 = add_epi32(sum0, mult0); \ + sum1 = add_epi32(sum1, mult1); \ + sum2 = add_epi32(sum2, mult2); \ + sum3 = add_epi32(sum3, mult3); \ + sum4 = add_epi32(sum4, mult4); \ + sum5 = add_epi32(sum5, mult5); \ + sum6 = add_epi32(sum6, mult6); \ + sum7 = add_epi32(sum7, mult7); \ + } \ + /* Reduce sums within 128-bit lanes.*/ \ + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + /*The specific implementation may need to reduce further.*/ \ + auto total = PermuteSummer(pack0123, pack4567); \ + RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \ + } \ + } \ +} \ + +//An int8_prepbias version of the above code, using the add 127 technique +#define INTGEMM_PREPAREBIASFOR8(Register, target, cpu_type) \ + template <class Callback> target static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) { \ + assert(width % (sizeof(Register) / sizeof(int8_t)) == 0); \ + assert(B_cols % 8 == 0); \ + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \ + const Index simd_width = width / (sizeof(Register) / sizeof(int8_t)); \ + auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ + const Register a = set1_epi8<Register>(1); \ + INTGEMM_OMP_FOR \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \ + const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \ + /*const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width);*/ \ + /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \ + Iterate over shared (inner) dimension.*/ \ + Index k = 0; \ + Register sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Register sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Register sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Register sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Register sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Register sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Register sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Register sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is declared here.*/ \ + Register ones = set1_epi16<Register>(1); \ + sum0 = madd_epi16(sum0, ones); \ + sum1 = madd_epi16(sum1, ones); \ + sum2 = madd_epi16(sum2, ones); \ + sum3 = madd_epi16(sum3, ones); \ + sum4 = madd_epi16(sum4, ones); \ + sum5 = madd_epi16(sum5, ones); \ + sum6 = madd_epi16(sum6, ones); \ + sum7 = madd_epi16(sum7, ones); \ + for (k = 1; k < simd_width; ++k) { \ + /*Register a = *(A_row + k);*/ \ + /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ \ + Register mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Register mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Register mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Register mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Register mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Register mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Register mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Register mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + /* Upcast to 32-bit and horizontally add.*/ \ + mult0 = madd_epi16(mult0, ones); \ + mult1 = madd_epi16(mult1, ones); \ + mult2 = madd_epi16(mult2, ones); \ + mult3 = madd_epi16(mult3, ones); \ + mult4 = madd_epi16(mult4, ones); \ + mult5 = madd_epi16(mult5, ones); \ + mult6 = madd_epi16(mult6, ones); \ + mult7 = madd_epi16(mult7, ones); \ + /*Add in 32bit*/ \ + sum0 = add_epi32(sum0, mult0); \ + sum1 = add_epi32(sum1, mult1); \ + sum2 = add_epi32(sum2, mult2); \ + sum3 = add_epi32(sum3, mult3); \ + sum4 = add_epi32(sum4, mult4); \ + sum5 = add_epi32(sum5, mult5); \ + sum6 = add_epi32(sum6, mult6); \ + sum7 = add_epi32(sum7, mult7); \ + \ + } \ + /* Reduce sums within 128-bit lanes.*/ \ + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + /*The specific implementation may need to reduce further.*/ \ + auto total = PermuteSummer(pack0123, pack4567); \ + RunCallback(callback_impl, total, 0, B0_colidx, 1, B_cols); \ + } \ +} \ + +//An int8 version of the above code, using the add 127 technique +#define INTGEMM_MULTIPLY8SHIFT(Register, target, cpu_type) \ + template <class Callback> target static void Multiply8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ + assert(width % (sizeof(Register) / sizeof(int8_t)) == 0); \ + assert(B_cols % 8 == 0); \ + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); \ + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \ + const Index simd_width = width / (sizeof(Register) / sizeof(int8_t)); \ + auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ + INTGEMM_OMP_FOR \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \ + const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \ + /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ + for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ + const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \ + /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \ + Iterate over shared (inner) dimension.*/ \ + Index k = 0; \ + Register a = *(A_row + k); \ + Register sum0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Register sum1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Register sum2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Register sum3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Register sum4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Register sum5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Register sum6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Register sum7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is declared here.*/ \ + Register ones = set1_epi16<Register>(1); \ + sum0 = madd_epi16(sum0, ones); \ + sum1 = madd_epi16(sum1, ones); \ + sum2 = madd_epi16(sum2, ones); \ + sum3 = madd_epi16(sum3, ones); \ + sum4 = madd_epi16(sum4, ones); \ + sum5 = madd_epi16(sum5, ones); \ + sum6 = madd_epi16(sum6, ones); \ + sum7 = madd_epi16(sum7, ones); \ + for (k = 1; k < simd_width; ++k) { \ + a = *(A_row + k); \ + /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ \ + Register mult0 = maddubs_epi16(a, *(B0_col + k * 8)); \ + Register mult1 = maddubs_epi16(a, *(B0_col + k * 8 + 1)); \ + Register mult2 = maddubs_epi16(a, *(B0_col + k * 8 + 2)); \ + Register mult3 = maddubs_epi16(a, *(B0_col + k * 8 + 3)); \ + Register mult4 = maddubs_epi16(a, *(B0_col + k * 8 + 4)); \ + Register mult5 = maddubs_epi16(a, *(B0_col + k * 8 + 5)); \ + Register mult6 = maddubs_epi16(a, *(B0_col + k * 8 + 6)); \ + Register mult7 = maddubs_epi16(a, *(B0_col + k * 8 + 7)); \ + /* Upcast to 32-bit and horizontally add.*/ \ + mult0 = madd_epi16(mult0, ones); \ + mult1 = madd_epi16(mult1, ones); \ + mult2 = madd_epi16(mult2, ones); \ + mult3 = madd_epi16(mult3, ones); \ + mult4 = madd_epi16(mult4, ones); \ + mult5 = madd_epi16(mult5, ones); \ + mult6 = madd_epi16(mult6, ones); \ + mult7 = madd_epi16(mult7, ones); \ + /*Add in 32bit*/ \ + sum0 = add_epi32(sum0, mult0); \ + sum1 = add_epi32(sum1, mult1); \ + sum2 = add_epi32(sum2, mult2); \ + sum3 = add_epi32(sum3, mult3); \ + sum4 = add_epi32(sum4, mult4); \ + sum5 = add_epi32(sum5, mult5); \ + sum6 = add_epi32(sum6, mult6); \ + sum7 = add_epi32(sum7, mult7); \ + \ + } \ + /* Reduce sums within 128-bit lanes.*/ \ + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + /*The specific implementation may need to reduce further.*/ \ + auto total = PermuteSummer(pack0123, pack4567); \ + RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \ + } \ + } \ +} \ + +/* 8-bit matrix multiply used by AVX and AVX2. + * These have two peculiar properties: + * 1. The sign instructions don't exist in AVX512. + * 2. 16 registers means gcc's register allocation failed so I wrote it in my + * own asm. + * 3. They support 3-argument vpsignb and vpmaddubsw. + * + * Fun fact: AVX introduced the three-argument vpsignb and vpmaddubsw but only + * for 128-bit, despite the primary change in AVX being the addition of + * 256-bit. We had to wait for INTGEMM_AVX2 to get 256-bit versions of vpsignb and + * vpmaddubsw. That's why this code is generic over 128-bit or 256-bit. + */ +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +INTGEMM_AVX2 inline static void InnerINTGEMM_AVX2( + __m256i a, const __m256i *b, + __m256i &sum0, __m256i &sum1, __m256i &sum2, __m256i &sum3, + __m256i &sum4, __m256i &sum5, __m256i &sum6, __m256i &sum7) { + // Annoyingly the only 8-bit multiply is signed * unsigned (maddubs). + // So we take the sign bits off of a and apply them each b in a * b. + // + // We have only 16 YMM registers but we want to store: + // 1 for a (or |a|) + // 8 temporaries for applying sign to each column of B. + // 8 sums. +#if defined(__GNUC__) && !defined(__clang__) + // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=94663 + // gcc's register allocator does: + // 1 for a, do all the sign application, then overwrite with |a| + // 8 temporaries + // 7 sums in registers + 1 on the stack + // + // But it's possible to complete an operation early, freeing up its + // temporary register for reuse. But completing an operation early + // requires us to have |a| for vpmaddubsw while completing the later + // operation needs a again to apply sign. + // + // So we do two columns, 0 and 1, early. This allows b0_b6 and b1_b7 + // to be reused by columns 6 and 7, respectively. And there's enough + // registers to store both a and |a|. + // + // These are the temporary variables used to process each column of b. + // We let the compiler choose which register number is which, but force + // it to allocate all registers. + __m256i absa; + __m256i b0_b6, b1_b7, b2, b3, b4, b5; + // Maybe this will tell gcc that we're accessing 8 registers starting + // at B_live. Though I doubt it because we're passing the address as a + // register. + typedef struct { __m256i x[8]; } B_range; + asm( + // Copy the first 6 columns of b to registers. We assume B has + // been rearranged so that these 8 columns are consecutive. + // vpsignb does not take a memory address as its second argument, + // so this can't be inlined into vsignb. + "vmovdqa (%[B]), %[b0_b6]\n" + "vmovdqa %c[size](%[B]), %[b1_b7]\n" + // These multiplies are executed by the assembler, not by the CPU + // at run time. + // I would have liked to just initialize b2 etc above but that + // would make it an input argument "+x" instead of "=&x". And +x + // counts as two operands for purposes of gcc's annoying 30-operand + // limit. + "vmovdqa 2*%c[size](%[B]), %[b2]\n" + "vmovdqa 3*%c[size](%[B]), %[b3]\n" + "vmovdqa 4*%c[size](%[B]), %[b4]\n" + "vmovdqa 5*%c[size](%[B]), %[b5]\n" + // Store the absolute value of a in absa. + "vpabsb %[a], %[absa]\n" + // If a byte of a is negative, negate the corresponding byte in + // b0_b6 etc. + "vpsignb %[a], %[b0_b6], %[b0_b6]\n" + "vpsignb %[a], %[b1_b7], %[b1_b7]\n" + // Multiply signed * unsigned then horizontally add to form packed + // 16-bit integers: + // b0[0] * |a|[0] + b0[1] * |a|[1], b0[2] * |a|[2] + b0[3] * |a|[3], ... + "vpmaddubsw %[b0_b6], %[absa], %[b0_b6]\n" + "vpmaddubsw %[b1_b7], %[absa], %[b1_b7]\n" + // vpmaddubsw has latency 5 so work on some other sign bits while + // we're at it. + "vpsignb %[a], %[b2], %[b2]\n" + "vpsignb %[a], %[b3], %[b3]\n" + "vpsignb %[a], %[b4], %[b4]\n" + "vpsignb %[a], %[b5], %[b5]\n" + // Perform a 16-bit add with saturation to accumlate sums. + "vpaddsw %[b0_b6], %[sum0], %[sum0]\n" + // Now we can reuse b0_b6 for b6 + "vmovdqa 6*%c[size](%[B]), %[b0_b6]\n" + "vpaddsw %[b1_b7], %[sum1], %[sum1]\n" + // Now we can reuse b1_b7 for b7 + "vmovdqa 7*%c[size](%[B]), %[b1_b7]\n" + // More crunching while the load happens. + "vpmaddubsw %[b2], %[absa], %[b2]\n" + "vpmaddubsw %[b3], %[absa], %[b3]\n" + "vpmaddubsw %[b4], %[absa], %[b4]\n" + "vpsignb %[a], %[b0_b6], %[b0_b6]\n" + "vpsignb %[a], %[b1_b7], %[b1_b7]\n" + "vpmaddubsw %[b5], %[absa], %[b5]\n" + "vpmaddubsw %[b0_b6], %[absa], %[b0_b6]\n" + "vpmaddubsw %[b1_b7], %[absa], %[b1_b7]\n" + "vpaddsw %[b2], %[sum2], %[sum2]\n" + "vpaddsw %[b3], %[sum3], %[sum3]\n" + "vpaddsw %[b4], %[sum4], %[sum4]\n" + "vpaddsw %[b5], %[sum5], %[sum5]\n" + "vpaddsw %[b0_b6], %[sum6], %[sum6]\n" + "vpaddsw %[b1_b7], %[sum7], %[sum7]\n" + : [sum0] "+x" (sum0), + [sum1] "+x" (sum1), + [sum2] "+x" (sum2), + [sum3] "+x" (sum3), + [sum4] "+x" (sum4), + [sum5] "+x" (sum5), + [sum6] "+x" (sum6), + [sum7] "+x" (sum7), + [b0_b6] "=&x" (b0_b6), + [b1_b7] "=&x" (b1_b7), + [b2] "=&x" (b2), + [b3] "=&x" (b3), + [b4] "=&x" (b4), + [b5] "=&x" (b5), + [absa] "=&x" (absa) + : + // I would like to use m here but that non-deterministically + // chooses %(eax) or -256$(eax) and there's no way to add to that + // memory address: + // https://gcc.gnu.org/ml/gcc-help/2011-04/msg00518.html + // + [B] "r" (reinterpret_cast<const B_range*>(b)), + [a] "x" (a), + [size] "i" (sizeof(__m256i)) + ); +#else + // https://bugs.llvm.org/show_bug.cgi?id=41482 + // clang has a bug: target attribute avx2 doesn't allow inline assembly with + // +x for YMM registers. For example, this will not compile with default + // arguments: + // __attribute__ ((target ("avx2"))) void Foo(__m256i sum0) { + // asm("" : [sum0] "+x" (sum0)); + // } + // but it will compile with -mavx2. + // However, clang does allow intrinsics and has a better register allocator + // than gcc. So here we just use intrinsics. + __m256i a_positive = abs_epi8(a); + sum0 = adds_epi16(sum0, maddubs_epi16(a_positive, sign_epi8(b[0], a))); + sum1 = adds_epi16(sum1, maddubs_epi16(a_positive, sign_epi8(b[1], a))); + sum2 = adds_epi16(sum2, maddubs_epi16(a_positive, sign_epi8(b[2], a))); + sum3 = adds_epi16(sum3, maddubs_epi16(a_positive, sign_epi8(b[3], a))); + sum4 = adds_epi16(sum4, maddubs_epi16(a_positive, sign_epi8(b[4], a))); + sum5 = adds_epi16(sum5, maddubs_epi16(a_positive, sign_epi8(b[5], a))); + sum6 = adds_epi16(sum6, maddubs_epi16(a_positive, sign_epi8(b[6], a))); + sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a))); +#endif +} +#endif + +// For INTGEMM_SSSE3 without AVX +INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( + __m128i a, const __m128i *b, + __m128i &sum0, __m128i &sum1, __m128i &sum2, __m128i &sum3, + __m128i &sum4, __m128i &sum5, __m128i &sum6, __m128i &sum7) { + __m128i a_positive = abs_epi8(a); + sum0 = adds_epi16(sum0, maddubs_epi16(a_positive, sign_epi8(b[0], a))); + sum1 = adds_epi16(sum1, maddubs_epi16(a_positive, sign_epi8(b[1], a))); + sum2 = adds_epi16(sum2, maddubs_epi16(a_positive, sign_epi8(b[2], a))); + sum3 = adds_epi16(sum3, maddubs_epi16(a_positive, sign_epi8(b[3], a))); + sum4 = adds_epi16(sum4, maddubs_epi16(a_positive, sign_epi8(b[4], a))); + sum5 = adds_epi16(sum5, maddubs_epi16(a_positive, sign_epi8(b[5], a))); + sum6 = adds_epi16(sum6, maddubs_epi16(a_positive, sign_epi8(b[6], a))); + sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a))); +} +//INTGEMM_AVX2 or INTGEMM_SSSE3 multiply +#define INTGEMM_MULTIPLY8(Register, target, cpu_type) \ + template <typename Callback> target static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { \ + assert(width % sizeof(Register) == 0); \ + assert(B_cols % 8 == 0); \ + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); \ + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \ + const Index simd_width = width / sizeof(Register); \ + auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ + INTGEMM_OMP_FOR \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \ + const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \ + /*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ + for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ + /*Iterate over shared (inner) dimension.*/ \ + const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width); \ + const Register *A_end = A_live + simd_width; \ + const Register *B_live = B0_col; \ + /* Rather than initializing as zeros and adding, just initialize the first.*/ \ + Register a = *(A_live++); \ + Register a_positive = abs_epi8(a); \ + /* These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.*/ \ + Register sum0 = maddubs_epi16(a_positive, sign_epi8(B_live[0], a)); \ + Register sum1 = maddubs_epi16(a_positive, sign_epi8(B_live[1], a)); \ + Register sum2 = maddubs_epi16(a_positive, sign_epi8(B_live[2], a)); \ + Register sum3 = maddubs_epi16(a_positive, sign_epi8(B_live[3], a)); \ + Register sum4 = maddubs_epi16(a_positive, sign_epi8(B_live[4], a)); \ + Register sum5 = maddubs_epi16(a_positive, sign_epi8(B_live[5], a)); \ + Register sum6 = maddubs_epi16(a_positive, sign_epi8(B_live[6], a)); \ + Register sum7 = maddubs_epi16(a_positive, sign_epi8(B_live[7], a)); \ + B_live += 8; \ + /* Use A as the loop variable so the add can be done where gcc likes it for branch prediction.*/ \ + for (; A_live != A_end; ++A_live, B_live += 8) { \ + Inner##target(*A_live, B_live, sum0, sum1, sum2, sum3, sum4, sum5, sum6, sum7); \ + } \ + /* Convert 16-bit to 32-bit and add, not caring what parts are added. + * Implementations: + * 1. https://github.com/tesseract-ocr/tesseract/blob/master/src/arch/intsimdmatrixavx2.cpp#L67 under Apache license: + * This does a multiply by 1 and horizontal add: + * _mm512_madd_epi16(sum, _mm512_set1_epi16(1)) + * Current fastest. + * + * 2. Signed extension and fold halves: + * sum = _mm512_add_epi32( + * _mm512_cvtepi16_epi32(_mm512_castsi512_si256(sum)), + * _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(sum, 1))); + * + * 3. Sign extend by abuse of bitshift, then add. + * sum = _mm512_add_epi32( + * _mm512_srai_epi32(_mm512_slli_epi32(sum, 16), 16), + * _mm512_srai_epi32(sum, 16)); + */ \ + Register ones = set1_epi16<Register>(1); \ + sum0 = madd_epi16(sum0, ones); \ + sum1 = madd_epi16(sum1, ones); \ + sum2 = madd_epi16(sum2, ones); \ + sum3 = madd_epi16(sum3, ones); \ + sum4 = madd_epi16(sum4, ones); \ + sum5 = madd_epi16(sum5, ones); \ + sum6 = madd_epi16(sum6, ones); \ + sum7 = madd_epi16(sum7, ones); \ + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); \ + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); \ + auto total = PermuteSummer(pack0123, pack4567); \ + RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \ + } \ + } \ +} + +/* Wrap a multiply call in OMP parallelism. Here it launches threads then + * inside the implementation there is a pragma omp for. In gcc >= 8 these + * could have been the same but older compilers don't imbue target attributes + * on the hidden function created by pragma omp parallel. + * + * Also, gcc 7 is unable to deduce the function pointer type (for ChooseCPU) if + * I use typename Backend::Integer directly in the arguments. As a workaround, + * have a default template argument Integer then use that so it's resolved. + */ +template <class Callback, class Backend, class Integer = typename Backend::Integer> static inline void OMPParallelWrap(const Integer *A, const Integer *B, Index A_rows, Index width, Index B_cols, Callback callback) { +#pragma omp parallel + Backend::template Multiply<Callback>(A, B, A_rows, width, B_cols, callback); +} +template <class Callback, class Backend> static inline void OMPParallelWrap8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { +#pragma omp parallel + Backend::template Multiply8Shift<Callback>(A, B, A_rows, width, B_cols, callback); +} + +} // namespace intgemm diff --git a/third_party/intgemm/intgemm/sse2_gemm.h b/third_party/intgemm/intgemm/sse2_gemm.h new file mode 100644 index 0000000000..cd855a67a9 --- /dev/null +++ b/third_party/intgemm/intgemm/sse2_gemm.h @@ -0,0 +1,84 @@ +#pragma once + +#include "kernels.h" +#include "multiply.h" +#include "types.h" + +#include <cstdint> + +// 8 bit is in ssse3_gemm.h + +namespace intgemm { +namespace SSE2 { + +INTGEMM_SSE2 inline __m128i QuantizerGrab(const float *input, const __m128 quant_mult_reg) { + return kernels::quantize(loadu_ps<__m128>(input), quant_mult_reg); +} + +INTGEMM_SELECT_COL_B(INTGEMM_SSE2, __m128i) + +class QuantizeTile16 { + public: + INTGEMM_SSE2 static inline Register Consecutive(__m128 mult_reg, const float *input) { + return Tile(mult_reg, input, input + 4); + } + + INTGEMM_SSE2 static inline Register ConsecutiveWithWrapping(__m128 mult_reg, const float *input, Index cols_left, Index cols, Index row_step) { + return Tile(mult_reg, + input, + input + 4 + (cols_left <= 4 ? cols * (row_step - 1) : 0)); + } + + INTGEMM_SSE2 static inline Register ForReshape(__m128 mult_reg, const float *input, int) { + return Consecutive(mult_reg, input); + } + + private: + INTGEMM_SSE2 static inline Register Tile(__m128 mult_reg, const float *input0, const float *input1) { + __m128i g0 = kernels::quantize(loadu_ps<__m128>(input0), mult_reg); + __m128i g1 = kernels::quantize(loadu_ps<__m128>(input1), mult_reg); + return _mm_packs_epi32(g0, g1); + } +}; + +// This should be pure SSE2 (and below). +struct Kernels16 { + typedef int16_t Integer; + + // Currently A is prepared by quantization but this could theoretically change. + INTGEMM_SSE2 static inline void PrepareA(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) { + Quantize(input, output, quant_mult, rows * cols); + } + + INTGEMM_SSE2 static void Quantize(const float *input, int16_t *output, float quant_mult, Index size) { + assert(size % 8 == 0); + assert(reinterpret_cast<uintptr_t>(input) % 16 == 0); + assert(reinterpret_cast<uintptr_t>(output) % 16 == 0); + FRegister q = set1_ps<FRegister>(quant_mult); + const float *end = input + size; + for (; input != end; input += 8, output += 8) { + *reinterpret_cast<__m128i*>(output) = QuantizeTile16::Consecutive(q, input); + } + } + + // Tile size for B; B must be a multiple of this block size. + static const Index kBTileRow = 8; + static const Index kBTileCol = 8; + + INTGEMM_PREPARE_B_16(INTGEMM_SSE2, QuantizeTile16) + INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_SSE2, int16_t) + INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_SSE2, QuantizeTile16, int16_t) + + INTGEMM_SSE2 static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { + //TODO #DEFINE + SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end); + } + INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, CPUType::SSE2) + + constexpr static const char *const kName = "16-bit SSE2"; + + static const CPUType kUses = CPUType::SSE2; +}; + +} // namespace SSE2 +} // namespace intgemm diff --git a/third_party/intgemm/intgemm/ssse3_gemm.h b/third_party/intgemm/intgemm/ssse3_gemm.h new file mode 100644 index 0000000000..db403bd06b --- /dev/null +++ b/third_party/intgemm/intgemm/ssse3_gemm.h @@ -0,0 +1,154 @@ +#pragma once + +#include "interleave.h" +#include "kernels.h" +#include "multiply.h" +#include "types.h" + +#include <cstdint> +#include <cstring> + +// 16-bit is in sse2_gemm.h + +namespace intgemm { +namespace SSSE3 { + +INTGEMM_SSSE3 inline __m128i QuantizerGrab(const float *input, const __m128 quant_mult_reg) { + return kernels::quantize(loadu_ps<__m128>(input), quant_mult_reg); +} + +INTGEMM_SELECT_COL_B(INTGEMM_SSSE3, __m128i) + +class QuantizeTile8 { + public: + INTGEMM_SSSE3 static inline Register ForReshape(FRegister mult_reg, const float *input, Index cols) { + // Skip a row. + return Tile(mult_reg, input, input + 4, input + 2 * cols, input + 2 * cols + 4); + } + + INTGEMM_SSSE3 static inline Register Consecutive(FRegister mult_reg, const float *input) { + return Tile(mult_reg, input, input + 4, input + 8, input + 12); + } + + INTGEMM_SSSE3 static inline Register ConsecutiveU(FRegister mult_reg, const float *input) { + return TileU(mult_reg, input, input + 4, input + 8, input + 12); + } + + INTGEMM_SSSE3 static inline Register ConsecutiveWithWrapping(FRegister mult_reg, const float *input, Index cols_left, Index cols, Index row_step) { + const float* inputs[4]; + for (Index i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) { + while (cols_left < sizeof(Register) / sizeof(float)) { + input += cols * (row_step - 1); + cols_left += cols; + } + inputs[i] = input; + input += sizeof(Register) / sizeof(float); + cols_left -= sizeof(Register) / sizeof(float); + } + return Tile(mult_reg, inputs[0], inputs[1], inputs[2], inputs[3]); + } + + // Quantize 16xfloat into 16xint8_t + INTGEMM_SSSE3 static inline __m128i Tile(FRegister mult_reg, const float *input0, const float *input1, const float *input2, const float *input3) { + const __m128i neg128 = _mm_set1_epi8(-128); + __m128i g0 = QuantizerGrab(input0, mult_reg); + __m128i g1 = QuantizerGrab(input1, mult_reg); + __m128i g2 = QuantizerGrab(input2, mult_reg); + __m128i g3 = QuantizerGrab(input3, mult_reg); + __m128i packed0 = _mm_packs_epi32(g0, g1); + __m128i packed1 = _mm_packs_epi32(g2, g3); + __m128i packed = _mm_packs_epi16(packed0, packed1); + /* Ban -128. + * Don't use the SSE4.1 instruction _mm_max_epi8(packed, neg127). Instead, + * use SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8. + * The first generates 0xff for fields -128. + * The second subtracts 0xff from -128 which has the effect of converting + * to -127. + */ + // packed = _mm_max_epi8(packed, neg127); + __m128i evils = _mm_cmpeq_epi8(packed, neg128); + return _mm_sub_epi8(packed, evils); + // No permute needed. packs is in order for SSE. + } + + private: + INTGEMM_SSSE3 static inline __m128i TileU(FRegister mult_reg, const float *input0, const float *input1, const float *input2, const float *input3) { + const __m128i neg128 = _mm_set1_epi8(-128); + const __m128i pos127 = _mm_set1_epi8(127); + __m128i g0 = QuantizerGrab(input0, mult_reg); + __m128i g1 = QuantizerGrab(input1, mult_reg); + __m128i g2 = QuantizerGrab(input2, mult_reg); + __m128i g3 = QuantizerGrab(input3, mult_reg); + __m128i packed0 = _mm_packs_epi32(g0, g1); + __m128i packed1 = _mm_packs_epi32(g2, g3); + __m128i packed = _mm_packs_epi16(packed0, packed1); + /* Ban -128. + * Don't use the SSE4.1 instruction _mm_max_epi8(packed, neg127). Instead, + * use SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8. + * The first generates 0xff for fields -128. + * The second subtracts 0xff from -128 which has the effect of converting + * to -127. + */ + // packed = _mm_max_epi8(packed, neg127); + __m128i evils = _mm_cmpeq_epi8(packed, neg128); + return _mm_add_epi8(_mm_sub_epi8(packed, evils), pos127); + // No permute needed. packs is in order for SSE. + } +}; + +// pmaddubsw (the 8-bit multiply) is SSSE3, so pedantically that's the version we need. +struct Kernels8 { + typedef int8_t Integer; + + // Currently A is prepared by quantization but this could theoretically change. + INTGEMM_SSSE3 static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { + Quantize(input, output, quant_mult, rows * cols); + } + + private: + INTGEMM_QUANTIZE_THREAD(INTGEMM_SSSE3) + public: + INTGEMM_QUANTIZE(INTGEMM_SSSE3) + + // Version with unsigned int + 127 + // Currently A is prepared by quantization but this could theoretically change. + INTGEMM_SSSE3 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) { + QuantizeU(input, output, quant_mult, rows * cols); + } + + INTGEMM_SSSE3 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) { + assert(size % 16 == 0); + assert(reinterpret_cast<uintptr_t>(input) % 16 == 0); + assert(reinterpret_cast<uintptr_t>(output) % 16 == 0); + FRegister q = set1_ps<FRegister>(quant_mult); + const float *end = input + size; + for (; input != end; input += 16, output += 16) { + *reinterpret_cast<__m128i*>(output) = QuantizeTile8::ConsecutiveU(q, input); + } + } + + // Tile size for B; B must be a multiple of this block size. + static const Index kBTileRow = 16; + static const Index kBTileCol = 8; + + INTGEMM_PREPARE_B_8(INTGEMM_SSSE3, SSSE3::QuantizeTile8) + INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_SSSE3, int8_t) + INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_SSSE3, QuantizeTile8, int8_t) + + INTGEMM_SSSE3 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { + SSSE3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end); + } + + INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::SSE2) + + INTGEMM_MULTIPLY8SHIFT(__m128i, INTGEMM_SSSE3, CPUType::SSE2) + + INTGEMM_PREPAREBIASFOR8(__m128i, INTGEMM_SSSE3, CPUType::SSE2) + + constexpr static const char *const kName = "8-bit SSSE3"; + + static const CPUType kUses = CPUType::SSSE3; +}; + +} // namespace SSSE3 +} // namespace intgemm diff --git a/third_party/intgemm/intgemm/stats.h b/third_party/intgemm/intgemm/stats.h new file mode 100644 index 0000000000..9573c4b9ee --- /dev/null +++ b/third_party/intgemm/intgemm/stats.h @@ -0,0 +1,76 @@ +#pragma once + +#include <cmath> +#include "intrinsics.h" + +#ifdef _OPENMP +#include <omp.h> +#endif + +namespace intgemm { + +/* Horizontal max and sums. TODO make a template argument? */ + +INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) { + // Fold to just using the first 64 bits. + __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2); + a = _mm_max_ps(a, second_half); + // Fold to just using the first 32 bits. + second_half = _mm_shuffle_ps(a, a, 1); + a = _mm_max_ps(a, second_half); + // This casting compiles to nothing. + return *reinterpret_cast<float*>(&a); +} +INTGEMM_SSE2 static inline float AddFloat32(__m128 a) { + // Fold to just using the first 64 bits. + __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2); + a = _mm_add_ps(a, second_half); + // Fold to just using the first 32 bits. + second_half = _mm_shuffle_ps(a, a, 1); + a = _mm_add_ps(a, second_half); + // This casting compiles to nothing. + return *reinterpret_cast<float*>(&a); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +INTGEMM_AVX2 static inline float MaxFloat32(__m256 a) { + return MaxFloat32(max_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1))); +} +INTGEMM_AVX2 static inline float AddFloat32(__m256 a) { + return AddFloat32(add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1))); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +// Find the maximum float. +INTGEMM_AVX512F static inline float MaxFloat32(__m512 a) { + // _mm512_extractf32x8_ps is AVX512DQ but we don't care about masking. + // So cast to pd, do AVX512F _mm512_extractf64x4_pd, then cast to ps. + __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1)); + return MaxFloat32(max_ps(_mm512_castps512_ps256(a), upper)); +} +INTGEMM_AVX512F static inline float AddFloat32(__m512 a) { + __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1)); + return AddFloat32(add_ps(_mm512_castps512_ps256(a), upper)); +} +#endif + +constexpr int32_t kFloatAbsoluteMask = 0x7fffffff; + +} // namespace intgemm + +#define INTGEMM_THIS_IS_SSE2 +#include "stats.inl" +#undef INTGEMM_THIS_IS_SSE2 + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +#define INTGEMM_THIS_IS_AVX2 +#include "stats.inl" +#undef INTGEMM_THIS_IS_AVX2 +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +#define INTGEMM_THIS_IS_AVX512DQ +#include "stats.inl" +#undef INTGEMM_THIS_IS_AVX512DQ +#endif diff --git a/third_party/intgemm/intgemm/stats.inl b/third_party/intgemm/intgemm/stats.inl new file mode 100644 index 0000000000..68a5b8e150 --- /dev/null +++ b/third_party/intgemm/intgemm/stats.inl @@ -0,0 +1,98 @@ +/* This file is included multiple times, once per architecture. */ +#if defined(INTGEMM_THIS_IS_AVX512DQ) +#define INTGEMM_ARCH AVX512BW +#define INTGEMM_TARGET INTGEMM_AVX512DQ +#elif defined(INTGEMM_THIS_IS_AVX2) +#define INTGEMM_ARCH AVX2 +#define INTGEMM_TARGET INTGEMM_AVX2 +#elif defined(INTGEMM_THIS_IS_SSE2) +#define INTGEMM_ARCH SSE2 +#define INTGEMM_TARGET INTGEMM_SSE2 +#else +#error Included with unexpected architecture +#endif + +namespace intgemm { +namespace INTGEMM_ARCH { + +/* Compute the maximum absolute value over floats aligned to register size. + * Do not call this function directly; it's a subroutine of MaxAbsolute. + */ +INTGEMM_TARGET static inline float MaxAbsoluteThread(const FRegister *begin, const FRegister *end) { + FRegister highest = setzero_ps<FRegister>(); + const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask)); +#pragma omp for + for (const FRegister *i = begin; i < end; ++i) { + FRegister reg = and_ps(abs_mask, *i); + highest = max_ps(highest, reg); + } + return MaxFloat32(highest); +} + +/* Compute the maximum absolute value of an array of floats. + * begin_float must be aligned to a multiple of the register size. +*/ +INTGEMM_TARGET static inline float MaxAbsolute(const float *begin_float, const float *end_float) { + assert(reinterpret_cast<uintptr_t>(begin_float) % sizeof(FRegister) == 0); + const float *end_reg = end_float - (reinterpret_cast<uintptr_t>(end_float) % sizeof(FRegister)) / sizeof(float); + float ret = 0.0; +#pragma omp parallel reduction(max:ret) num_threads(std::max<int>(1, std::min<int>(omp_get_max_threads(), (end_float - begin_float) / 16384))) + { + float shard_max = MaxAbsoluteThread( + reinterpret_cast<const FRegister*>(begin_float), + reinterpret_cast<const FRegister*>(end_reg)); + ret = std::max(ret, shard_max); + } + /* Overhang. The beginning was aligned so if there's any overhang we're + * allowed to read the next full register. Then mask that to 0. */ +#if defined(INTGEMM_THIS_IS_AVX512DQ) + if (end_float != end_reg) { + const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask)); + __mmask16 mask = (1 << (end_float - end_reg)) - 1; + FRegister masked = _mm512_maskz_and_ps(mask, abs_mask, *reinterpret_cast<const FRegister*>(end_reg)); + ret = std::max(ret, MaxFloat32(masked)); + } +#else + for (const float *i = end_reg; i < end_float; ++i) { + ret = std::max(ret, std::fabs(*i)); + } +#endif + return ret; +} + +/* Computes the euclidean norm and returns the mean and the standard deviation. Optionally it can be the mean and standard deviation in absolute terms. */ +INTGEMM_TARGET static inline MeanStd VectorMeanStd(const float *begin_float, const float *end_float, bool absolute) { + assert(end_float > begin_float); + assert((end_float - begin_float) % (sizeof(FRegister) / sizeof(float)) == 0); + size_t num_items = end_float - begin_float; + const FRegister *begin = reinterpret_cast<const FRegister*>(begin_float); + const FRegister *end = reinterpret_cast<const FRegister*>(end_float); + FRegister squares = set1_ps<FRegister>(0); + FRegister sums = set1_ps<FRegister>(0); + if (absolute) { + const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask)); + for (; begin != end; begin++) { + FRegister vec = and_ps(abs_mask, *begin); + squares = add_ps(squares, mul_ps(vec, vec)); + sums = add_ps(sums, vec); + } + } else { + for (; begin != end; begin++) { + FRegister vec = *begin; + squares = add_ps(squares, mul_ps(vec, vec)); + sums = add_ps(sums, vec); + } + } + float squares_sum = AddFloat32(squares); + float normal_sums = AddFloat32(sums); + MeanStd ret; + ret.mean = normal_sums/num_items; + ret.stddev = std::sqrt((squares_sum/num_items) - (ret.mean*ret.mean)); + return ret; +} + +} // namespace INTGEMM_ARCH +} // namespace intgemm + +#undef INTGEMM_ARCH +#undef INTGEMM_TARGET diff --git a/third_party/intgemm/intgemm/types.h b/third_party/intgemm/intgemm/types.h new file mode 100644 index 0000000000..44fb4e2293 --- /dev/null +++ b/third_party/intgemm/intgemm/types.h @@ -0,0 +1,118 @@ +#pragma once +#include "intgemm/intgemm_config.h" + +#include <exception> +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +#include <immintrin.h> +#endif +#include <emmintrin.h> + +// clang-cl bug doesn't include these headers when pretending to be MSVC +// https://github.com/llvm/llvm-project/blob/e9a294449575a1e1a0daca470f64914695dc9adc/clang/lib/Headers/immintrin.h#L69-L72 +#if defined(_MSC_VER) && defined(__clang__) +#include <avxintrin.h> +#include <avx2intrin.h> +#include <smmintrin.h> +#include <avx512fintrin.h> +#include <avx512dqintrin.h> +#include <avx512bwintrin.h> +#include <avx512vnniintrin.h> +#endif + +#if (defined(_MSC_VER) && !defined(__clang__)) || defined(__INTEL_COMPILER) +/* Real MSVC does not appear to have target attributes but is also fine with + * just using intrinsics anywhere. clang-cl pretending to be MSVC requires + * target attributes, so it's excluded from the above. + * + * The Intel compiler has a bug whereby constructors with target attributes do + * not link. Like this program doesn't compile with icpc: + * class Foo { + * public: + * __attribute__ ((target ("avx2"))) Foo() {} + * }; + * int main() { Foo a; } + * + * It appears to be erroneously activating function multiversioning when only + * one version of a constructor with target attributes is defined. Normal + * methods with one target attribute work fine. The Intel compiler also allows + * intrinsics without any target attributes so we just leave them blank. + */ + #define INTGEMM_SSE2 + #define INTGEMM_SSSE3 + #define INTGEMM_AVX2 + #define INTGEMM_AVX512F + #define INTGEMM_AVX512BW + #define INTGEMM_AVX512DQ + #define INTGEMM_AVX512VNNI +#else + /* gcc and clang take lists of all the flavors */ + #define INTGEMM_SSE2 __attribute__ ((target ("sse2"))) + #define INTGEMM_SSSE3 __attribute__ ((target ("ssse3"))) + #define INTGEMM_AVX2 __attribute__ ((target ("avx2"))) + #define INTGEMM_AVX512F __attribute__ ((target ("avx512f"))) + #define INTGEMM_AVX512BW __attribute__ ((target ("avx512f,avx512bw,avx512dq"))) + #define INTGEMM_AVX512DQ __attribute__ ((target ("avx512f,avx512bw,avx512dq"))) + #define INTGEMM_AVX512VNNI __attribute__ ((target ("avx512f,avx512bw,avx512dq,avx512vnni"))) +#endif +namespace intgemm { + +// This will be thrown if a CPU isn't supported by the routines (16-bit without SSE2 or 8-bit without SSSE3). +class UnsupportedCPU : public std::exception { + public: + UnsupportedCPU() {} + + ~UnsupportedCPU() throw() {} + + const char *what() const throw() override { + return "Integer matrix multiplication has not been efficiently implemented for your CPU."; + } +}; + +typedef unsigned int Index; + +// If you want to detect the CPU and dispatch yourself, here's what to use: +enum class CPUType { + UNSUPPORTED = 0, + SSE2 = 1, + SSSE3 = 2, + AVX2 = 3, + AVX512BW = 4, + AVX512VNNI = 5 +}; + +// Running CPU type. This is defined in intgemm.cc (as the dispatcher). +extern const CPUType kCPU; + +struct MeanStd { + float mean; + float stddev; +}; + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI +namespace AVX512VNNI { +typedef __m512i Register; +typedef __m512 FRegister; +} // namespace AVX512VNNI +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +namespace AVX512BW { +typedef __m512i Register; +typedef __m512 FRegister; +} // namespace AVX512BW +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +namespace AVX2 { +typedef __m256i Register; +typedef __m256 FRegister; +} // namespace AVX2 +#endif +namespace SSSE3 { +typedef __m128i Register; +typedef __m128 FRegister; +} // namespace SSSE3 +namespace SSE2 { +typedef __m128i Register; +typedef __m128 FRegister; +} // namespace SSE2 + +} // namespace intgemm diff --git a/third_party/intgemm/intgemm/utils.h b/third_party/intgemm/intgemm/utils.h new file mode 100644 index 0000000000..a520ea0c0f --- /dev/null +++ b/third_party/intgemm/intgemm/utils.h @@ -0,0 +1,82 @@ +#pragma once + +#include <tuple> + +namespace intgemm { + +/* + * Sequence of unsigned integers + * + * Examples: + * sequence<1, 2, 3>() + * sequence_pushback<4, sequence<1, 2, 3>>() = sequence<1, 2, 3, 4>() + * sequence_popfront<sequence<1, 2, 3>>() = sequence<2, 3>() + * make_sequence<3>() = sequence<0, 1, 2>() + */ +template <unsigned... Indices> +struct sequence { using type = sequence; }; + +template <unsigned I, typename Sequence> +struct sequence_pushback; + +template <unsigned I, unsigned... Indices> +struct sequence_pushback<I, sequence<Indices...>> : sequence<Indices..., I> {}; + +template <typename Sequence> +struct sequence_popfront; + +template <unsigned FirstIndex, unsigned... RestIndices> +struct sequence_popfront<sequence<FirstIndex, RestIndices...>> : sequence<RestIndices...> {}; + +namespace { // anonymous namespace +template <unsigned N> +struct make_sequence_impl : sequence_pushback<N - 1, typename make_sequence_impl<N - 1>::type> {}; +template <> +struct make_sequence_impl<0> : sequence<> {}; +} // anonymous namespace + +template <unsigned N> +using make_sequence = typename make_sequence_impl<N>::type; + +/* + * Make a subtuple + */ +template <typename Tuple, unsigned... Indices> +using subtuple_t = typename std::tuple<typename std::tuple_element<Indices, Tuple>::type...>; + +template <typename Tuple, unsigned... Indices> +constexpr subtuple_t<Tuple, Indices...> make_subtuple(const Tuple& tuple, sequence<Indices...>) { + return std::make_tuple(std::get<Indices>(tuple)...); +} + +/* + * Factorial + */ +static constexpr unsigned long long factorial(unsigned n) { + return n <= 1 ? 1 : n * factorial(n - 1); +} + +/* + * e^n, where n is integer + */ +static constexpr double expi_nonnegative(unsigned n) { + return n == 0 ? 1.0 : (n == 1 ? 2.718281828459045 : expi_nonnegative(n / 2) * expi_nonnegative((n + 1) / 2)); +} + +static constexpr double expi(int n) { + return (n >= 0 ? expi_nonnegative(n) : 1.0 / expi_nonnegative(-n)); +} + +// Version that returns float. +static constexpr float expif(int n) { + return static_cast<float>(expi(n)); +} + +/* + * Round up + */ +static constexpr Index round_up(Index value, Index factor) { + return (value + factor - 1) / factor * factor; +} + +} diff --git a/third_party/intgemm/intgemm/vec_traits.h b/third_party/intgemm/intgemm/vec_traits.h new file mode 100644 index 0000000000..948dae1f21 --- /dev/null +++ b/third_party/intgemm/intgemm/vec_traits.h @@ -0,0 +1,57 @@ +#pragma once + +#include "types.h" + +namespace intgemm { + +/* + * Vector traits + */ +template <CPUType CPUType_, typename ElemType_> struct vector_s; +template <> struct vector_s<CPUType::SSE2, int8_t> { using type = __m128i; }; +template <> struct vector_s<CPUType::SSE2, int16_t> { using type = __m128i; }; +template <> struct vector_s<CPUType::SSE2, int> { using type = __m128i; }; +template <> struct vector_s<CPUType::SSE2, float> { using type = __m128; }; +template <> struct vector_s<CPUType::SSE2, double> { using type = __m128d; }; +template <> struct vector_s<CPUType::SSSE3, int8_t> { using type = __m128i; }; +template <> struct vector_s<CPUType::SSSE3, int16_t> { using type = __m128i; }; +template <> struct vector_s<CPUType::SSSE3, int> { using type = __m128i; }; +template <> struct vector_s<CPUType::SSSE3, float> { using type = __m128; }; +template <> struct vector_s<CPUType::SSSE3, double> { using type = __m128d; }; +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template <> struct vector_s<CPUType::AVX2, int8_t> { using type = __m256i; }; +template <> struct vector_s<CPUType::AVX2, int16_t> { using type = __m256i; }; +template <> struct vector_s<CPUType::AVX2, int> { using type = __m256i; }; +template <> struct vector_s<CPUType::AVX2, float> { using type = __m256; }; +template <> struct vector_s<CPUType::AVX2, double> { using type = __m256d; }; +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template <> struct vector_s<CPUType::AVX512BW, int8_t> { using type = __m512i; }; +template <> struct vector_s<CPUType::AVX512BW, int16_t> { using type = __m512i; }; +template <> struct vector_s<CPUType::AVX512BW, int> { using type = __m512i; }; +template <> struct vector_s<CPUType::AVX512BW, float> { using type = __m512; }; +template <> struct vector_s<CPUType::AVX512BW, double> { using type = __m512d; }; +#endif + +template <CPUType CPUType_, typename ElemType_> +using vector_t = typename vector_s<CPUType_, ElemType_>::type; + +template <CPUType CPUType_, typename ElemType_> +struct dvector_t { + using type = vector_t<CPUType_, ElemType_>; + + type first; + type second; +}; + +template <CPUType CPUType_, typename ElemType_> +struct qvector_t { + using type = vector_t<CPUType_, ElemType_>; + + type first; + type second; + type third; + type fourth; +}; + +} diff --git a/third_party/intgemm/test/3rd_party/LICENSE_1_0.txt b/third_party/intgemm/test/3rd_party/LICENSE_1_0.txt new file mode 100644 index 0000000000..7925d62e6b --- /dev/null +++ b/third_party/intgemm/test/3rd_party/LICENSE_1_0.txt @@ -0,0 +1,24 @@ +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + diff --git a/third_party/intgemm/test/3rd_party/catch.hpp b/third_party/intgemm/test/3rd_party/catch.hpp new file mode 100644 index 0000000000..1850fff125 --- /dev/null +++ b/third_party/intgemm/test/3rd_party/catch.hpp @@ -0,0 +1,14934 @@ +/* + * Catch v2.7.0 + * Generated: 2019-03-07 21:34:30.252164 + * ---------------------------------------------------------- + * This file has been merged from multiple headers. Please don't edit it directly + * Copyright (c) 2019 Two Blue Cubes Ltd. All rights reserved. + * + * Distributed under the Boost Software License, Version 1.0. (See accompanying + * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + */ +#ifndef TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED +#define TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED +// start catch.hpp + + +#define CATCH_VERSION_MAJOR 2 +#define CATCH_VERSION_MINOR 7 +#define CATCH_VERSION_PATCH 0 + +#ifdef __clang__ +# pragma clang system_header +#elif defined __GNUC__ +# pragma GCC system_header +#endif + +// start catch_suppress_warnings.h + +#ifdef __clang__ +# ifdef __ICC // icpc defines the __clang__ macro +# pragma warning(push) +# pragma warning(disable: 161 1682) +# else // __ICC +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wpadded" +# pragma clang diagnostic ignored "-Wswitch-enum" +# pragma clang diagnostic ignored "-Wcovered-switch-default" +# endif +#elif defined __GNUC__ + // Because REQUIREs trigger GCC's -Wparentheses, and because still + // supported version of g++ have only buggy support for _Pragmas, + // Wparentheses have to be suppressed globally. +# pragma GCC diagnostic ignored "-Wparentheses" // See #674 for details + +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wunused-variable" +# pragma GCC diagnostic ignored "-Wpadded" +#endif +// end catch_suppress_warnings.h +#if defined(CATCH_CONFIG_MAIN) || defined(CATCH_CONFIG_RUNNER) +# define CATCH_IMPL +# define CATCH_CONFIG_ALL_PARTS +#endif + +// In the impl file, we want to have access to all parts of the headers +// Can also be used to sanely support PCHs +#if defined(CATCH_CONFIG_ALL_PARTS) +# define CATCH_CONFIG_EXTERNAL_INTERFACES +# if defined(CATCH_CONFIG_DISABLE_MATCHERS) +# undef CATCH_CONFIG_DISABLE_MATCHERS +# endif +# if !defined(CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER) +# define CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER +# endif +#endif + +#if !defined(CATCH_CONFIG_IMPL_ONLY) +// start catch_platform.h + +#ifdef __APPLE__ +# include <TargetConditionals.h> +# if TARGET_OS_OSX == 1 +# define CATCH_PLATFORM_MAC +# elif TARGET_OS_IPHONE == 1 +# define CATCH_PLATFORM_IPHONE +# endif + +#elif defined(linux) || defined(__linux) || defined(__linux__) +# define CATCH_PLATFORM_LINUX + +#elif defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) || defined(__MINGW32__) +# define CATCH_PLATFORM_WINDOWS +#endif + +// end catch_platform.h + +#ifdef CATCH_IMPL +# ifndef CLARA_CONFIG_MAIN +# define CLARA_CONFIG_MAIN_NOT_DEFINED +# define CLARA_CONFIG_MAIN +# endif +#endif + +// start catch_user_interfaces.h + +namespace Catch { + unsigned int rngSeed(); +} + +// end catch_user_interfaces.h +// start catch_tag_alias_autoregistrar.h + +// start catch_common.h + +// start catch_compiler_capabilities.h + +// Detect a number of compiler features - by compiler +// The following features are defined: +// +// CATCH_CONFIG_COUNTER : is the __COUNTER__ macro supported? +// CATCH_CONFIG_WINDOWS_SEH : is Windows SEH supported? +// CATCH_CONFIG_POSIX_SIGNALS : are POSIX signals supported? +// CATCH_CONFIG_DISABLE_EXCEPTIONS : Are exceptions enabled? +// **************** +// Note to maintainers: if new toggles are added please document them +// in configuration.md, too +// **************** + +// In general each macro has a _NO_<feature name> form +// (e.g. CATCH_CONFIG_NO_POSIX_SIGNALS) which disables the feature. +// Many features, at point of detection, define an _INTERNAL_ macro, so they +// can be combined, en-mass, with the _NO_ forms later. + +#ifdef __cplusplus + +# if (__cplusplus >= 201402L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 201402L) +# define CATCH_CPP14_OR_GREATER +# endif + +# if (__cplusplus >= 201703L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +# define CATCH_CPP17_OR_GREATER +# endif + +#endif + +#if defined(CATCH_CPP17_OR_GREATER) +# define CATCH_INTERNAL_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS +#endif + +#ifdef __clang__ + +# define CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + _Pragma( "clang diagnostic push" ) \ + _Pragma( "clang diagnostic ignored \"-Wexit-time-destructors\"" ) \ + _Pragma( "clang diagnostic ignored \"-Wglobal-constructors\"") +# define CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + _Pragma( "clang diagnostic pop" ) + +# define CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \ + _Pragma( "clang diagnostic push" ) \ + _Pragma( "clang diagnostic ignored \"-Wparentheses\"" ) +# define CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \ + _Pragma( "clang diagnostic pop" ) + +# define CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS \ + _Pragma( "clang diagnostic push" ) \ + _Pragma( "clang diagnostic ignored \"-Wunused-variable\"" ) +# define CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS \ + _Pragma( "clang diagnostic pop" ) + +#endif // __clang__ + +//////////////////////////////////////////////////////////////////////////////// +// Assume that non-Windows platforms support posix signals by default +#if !defined(CATCH_PLATFORM_WINDOWS) + #define CATCH_INTERNAL_CONFIG_POSIX_SIGNALS +#endif + +//////////////////////////////////////////////////////////////////////////////// +// We know some environments not to support full POSIX signals +#if defined(__CYGWIN__) || defined(__QNX__) || defined(__EMSCRIPTEN__) || defined(__DJGPP__) + #define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS +#endif + +#ifdef __OS400__ +# define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS +# define CATCH_CONFIG_COLOUR_NONE +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Android somehow still does not support std::to_string +#if defined(__ANDROID__) +# define CATCH_INTERNAL_CONFIG_NO_CPP11_TO_STRING +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Not all Windows environments support SEH properly +#if defined(__MINGW32__) +# define CATCH_INTERNAL_CONFIG_NO_WINDOWS_SEH +#endif + +//////////////////////////////////////////////////////////////////////////////// +// PS4 +#if defined(__ORBIS__) +# define CATCH_INTERNAL_CONFIG_NO_NEW_CAPTURE +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Cygwin +#ifdef __CYGWIN__ + +// Required for some versions of Cygwin to declare gettimeofday +// see: http://stackoverflow.com/questions/36901803/gettimeofday-not-declared-in-this-scope-cygwin +# define _BSD_SOURCE +// some versions of cygwin (most) do not support std::to_string. Use the libstd check. +// https://gcc.gnu.org/onlinedocs/gcc-4.8.2/libstdc++/api/a01053_source.html line 2812-2813 +# if !((__cplusplus >= 201103L) && defined(_GLIBCXX_USE_C99) \ + && !defined(_GLIBCXX_HAVE_BROKEN_VSWPRINTF)) + +# define CATCH_INTERNAL_CONFIG_NO_CPP11_TO_STRING + +# endif +#endif // __CYGWIN__ + +//////////////////////////////////////////////////////////////////////////////// +// Visual C++ +#ifdef _MSC_VER + +# if _MSC_VER >= 1900 // Visual Studio 2015 or newer +# define CATCH_INTERNAL_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS +# endif + +// Universal Windows platform does not support SEH +// Or console colours (or console at all...) +# if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) +# define CATCH_CONFIG_COLOUR_NONE +# else +# define CATCH_INTERNAL_CONFIG_WINDOWS_SEH +# endif + +// MSVC traditional preprocessor needs some workaround for __VA_ARGS__ +// _MSVC_TRADITIONAL == 0 means new conformant preprocessor +// _MSVC_TRADITIONAL == 1 means old traditional non-conformant preprocessor +# if !defined(_MSVC_TRADITIONAL) || (defined(_MSVC_TRADITIONAL) && _MSVC_TRADITIONAL) +# define CATCH_INTERNAL_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +# endif + +#endif // _MSC_VER + +//////////////////////////////////////////////////////////////////////////////// +// Check if we are compiled with -fno-exceptions or equivalent +#if defined(__EXCEPTIONS) || defined(__cpp_exceptions) || defined(_CPPUNWIND) +# define CATCH_INTERNAL_CONFIG_EXCEPTIONS_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////// +// DJGPP +#ifdef __DJGPP__ +# define CATCH_INTERNAL_CONFIG_NO_WCHAR +#endif // __DJGPP__ + +//////////////////////////////////////////////////////////////////////////////// +// Embarcadero C++Build +#if defined(__BORLANDC__) + #define CATCH_INTERNAL_CONFIG_POLYFILL_ISNAN +#endif + +//////////////////////////////////////////////////////////////////////////////// + +// Use of __COUNTER__ is suppressed during code analysis in +// CLion/AppCode 2017.2.x and former, because __COUNTER__ is not properly +// handled by it. +// Otherwise all supported compilers support COUNTER macro, +// but user still might want to turn it off +#if ( !defined(__JETBRAINS_IDE__) || __JETBRAINS_IDE__ >= 20170300L ) + #define CATCH_INTERNAL_CONFIG_COUNTER +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Check if string_view is available and usable +// The check is split apart to work around v140 (VS2015) preprocessor issue... +#if defined(__has_include) +#if __has_include(<string_view>) && defined(CATCH_CPP17_OR_GREATER) +# define CATCH_INTERNAL_CONFIG_CPP17_STRING_VIEW +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Check if optional is available and usable +#if defined(__has_include) +# if __has_include(<optional>) && defined(CATCH_CPP17_OR_GREATER) +# define CATCH_INTERNAL_CONFIG_CPP17_OPTIONAL +# endif // __has_include(<optional>) && defined(CATCH_CPP17_OR_GREATER) +#endif // __has_include + +//////////////////////////////////////////////////////////////////////////////// +// Check if variant is available and usable +#if defined(__has_include) +# if __has_include(<variant>) && defined(CATCH_CPP17_OR_GREATER) +# if defined(__clang__) && (__clang_major__ < 8) + // work around clang bug with libstdc++ https://bugs.llvm.org/show_bug.cgi?id=31852 + // fix should be in clang 8, workaround in libstdc++ 8.2 +# include <ciso646> +# if defined(__GLIBCXX__) && defined(_GLIBCXX_RELEASE) && (_GLIBCXX_RELEASE < 9) +# define CATCH_CONFIG_NO_CPP17_VARIANT +# else +# define CATCH_INTERNAL_CONFIG_CPP17_VARIANT +# endif // defined(__GLIBCXX__) && defined(_GLIBCXX_RELEASE) && (_GLIBCXX_RELEASE < 9) +# else +# define CATCH_INTERNAL_CONFIG_CPP17_VARIANT +# endif // defined(__clang__) && (__clang_major__ < 8) +# endif // __has_include(<variant>) && defined(CATCH_CPP17_OR_GREATER) +#endif // __has_include + +#if defined(CATCH_INTERNAL_CONFIG_COUNTER) && !defined(CATCH_CONFIG_NO_COUNTER) && !defined(CATCH_CONFIG_COUNTER) +# define CATCH_CONFIG_COUNTER +#endif +#if defined(CATCH_INTERNAL_CONFIG_WINDOWS_SEH) && !defined(CATCH_CONFIG_NO_WINDOWS_SEH) && !defined(CATCH_CONFIG_WINDOWS_SEH) && !defined(CATCH_INTERNAL_CONFIG_NO_WINDOWS_SEH) +# define CATCH_CONFIG_WINDOWS_SEH +#endif +// This is set by default, because we assume that unix compilers are posix-signal-compatible by default. +#if defined(CATCH_INTERNAL_CONFIG_POSIX_SIGNALS) && !defined(CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_POSIX_SIGNALS) +# define CATCH_CONFIG_POSIX_SIGNALS +#endif +// This is set by default, because we assume that compilers with no wchar_t support are just rare exceptions. +#if !defined(CATCH_INTERNAL_CONFIG_NO_WCHAR) && !defined(CATCH_CONFIG_NO_WCHAR) && !defined(CATCH_CONFIG_WCHAR) +# define CATCH_CONFIG_WCHAR +#endif + +#if !defined(CATCH_INTERNAL_CONFIG_NO_CPP11_TO_STRING) && !defined(CATCH_CONFIG_NO_CPP11_TO_STRING) && !defined(CATCH_CONFIG_CPP11_TO_STRING) +# define CATCH_CONFIG_CPP11_TO_STRING +#endif + +#if defined(CATCH_INTERNAL_CONFIG_CPP17_OPTIONAL) && !defined(CATCH_CONFIG_NO_CPP17_OPTIONAL) && !defined(CATCH_CONFIG_CPP17_OPTIONAL) +# define CATCH_CONFIG_CPP17_OPTIONAL +#endif + +#if defined(CATCH_INTERNAL_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS) && !defined(CATCH_CONFIG_NO_CPP17_UNCAUGHT_EXCEPTIONS) && !defined(CATCH_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS) +# define CATCH_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS +#endif + +#if defined(CATCH_INTERNAL_CONFIG_CPP17_STRING_VIEW) && !defined(CATCH_CONFIG_NO_CPP17_STRING_VIEW) && !defined(CATCH_CONFIG_CPP17_STRING_VIEW) +# define CATCH_CONFIG_CPP17_STRING_VIEW +#endif + +#if defined(CATCH_INTERNAL_CONFIG_CPP17_VARIANT) && !defined(CATCH_CONFIG_NO_CPP17_VARIANT) && !defined(CATCH_CONFIG_CPP17_VARIANT) +# define CATCH_CONFIG_CPP17_VARIANT +#endif + +#if defined(CATCH_CONFIG_EXPERIMENTAL_REDIRECT) +# define CATCH_INTERNAL_CONFIG_NEW_CAPTURE +#endif + +#if defined(CATCH_INTERNAL_CONFIG_NEW_CAPTURE) && !defined(CATCH_INTERNAL_CONFIG_NO_NEW_CAPTURE) && !defined(CATCH_CONFIG_NO_NEW_CAPTURE) && !defined(CATCH_CONFIG_NEW_CAPTURE) +# define CATCH_CONFIG_NEW_CAPTURE +#endif + +#if !defined(CATCH_INTERNAL_CONFIG_EXCEPTIONS_ENABLED) && !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) +# define CATCH_CONFIG_DISABLE_EXCEPTIONS +#endif + +#if defined(CATCH_INTERNAL_CONFIG_POLYFILL_ISNAN) && !defined(CATCH_CONFIG_NO_POLYFILL_ISNAN) && !defined(CATCH_CONFIG_POLYFILL_ISNAN) +# define CATCH_CONFIG_POLYFILL_ISNAN +#endif + +#if !defined(CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS) +# define CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS +# define CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS +#endif +#if !defined(CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS) +# define CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS +# define CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS +#endif +#if !defined(CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS) +# define CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS +# define CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS +#endif + +#if defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) +#define CATCH_TRY if ((true)) +#define CATCH_CATCH_ALL if ((false)) +#define CATCH_CATCH_ANON(type) if ((false)) +#else +#define CATCH_TRY try +#define CATCH_CATCH_ALL catch (...) +#define CATCH_CATCH_ANON(type) catch (type) +#endif + +#if defined(CATCH_INTERNAL_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR) && !defined(CATCH_CONFIG_NO_TRADITIONAL_MSVC_PREPROCESSOR) && !defined(CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR) +#define CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#endif + +// end catch_compiler_capabilities.h +#define INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line ) name##line +#define INTERNAL_CATCH_UNIQUE_NAME_LINE( name, line ) INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line ) +#ifdef CATCH_CONFIG_COUNTER +# define INTERNAL_CATCH_UNIQUE_NAME( name ) INTERNAL_CATCH_UNIQUE_NAME_LINE( name, __COUNTER__ ) +#else +# define INTERNAL_CATCH_UNIQUE_NAME( name ) INTERNAL_CATCH_UNIQUE_NAME_LINE( name, __LINE__ ) +#endif + +#include <iosfwd> +#include <string> +#include <cstdint> + +// We need a dummy global operator<< so we can bring it into Catch namespace later +struct Catch_global_namespace_dummy {}; +std::ostream& operator<<(std::ostream&, Catch_global_namespace_dummy); + +namespace Catch { + + struct CaseSensitive { enum Choice { + Yes, + No + }; }; + + class NonCopyable { + NonCopyable( NonCopyable const& ) = delete; + NonCopyable( NonCopyable && ) = delete; + NonCopyable& operator = ( NonCopyable const& ) = delete; + NonCopyable& operator = ( NonCopyable && ) = delete; + + protected: + NonCopyable(); + virtual ~NonCopyable(); + }; + + struct SourceLineInfo { + + SourceLineInfo() = delete; + SourceLineInfo( char const* _file, std::size_t _line ) noexcept + : file( _file ), + line( _line ) + {} + + SourceLineInfo( SourceLineInfo const& other ) = default; + SourceLineInfo& operator = ( SourceLineInfo const& ) = default; + SourceLineInfo( SourceLineInfo&& ) noexcept = default; + SourceLineInfo& operator = ( SourceLineInfo&& ) noexcept = default; + + bool empty() const noexcept; + bool operator == ( SourceLineInfo const& other ) const noexcept; + bool operator < ( SourceLineInfo const& other ) const noexcept; + + char const* file; + std::size_t line; + }; + + std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ); + + // Bring in operator<< from global namespace into Catch namespace + // This is necessary because the overload of operator<< above makes + // lookup stop at namespace Catch + using ::operator<<; + + // Use this in variadic streaming macros to allow + // >> +StreamEndStop + // as well as + // >> stuff +StreamEndStop + struct StreamEndStop { + std::string operator+() const; + }; + template<typename T> + T const& operator + ( T const& value, StreamEndStop ) { + return value; + } +} + +#define CATCH_INTERNAL_LINEINFO \ + ::Catch::SourceLineInfo( __FILE__, static_cast<std::size_t>( __LINE__ ) ) + +// end catch_common.h +namespace Catch { + + struct RegistrarForTagAliases { + RegistrarForTagAliases( char const* alias, char const* tag, SourceLineInfo const& lineInfo ); + }; + +} // end namespace Catch + +#define CATCH_REGISTER_TAG_ALIAS( alias, spec ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::RegistrarForTagAliases INTERNAL_CATCH_UNIQUE_NAME( AutoRegisterTagAlias )( alias, spec, CATCH_INTERNAL_LINEINFO ); } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS + +// end catch_tag_alias_autoregistrar.h +// start catch_test_registry.h + +// start catch_interfaces_testcase.h + +#include <vector> + +namespace Catch { + + class TestSpec; + + struct ITestInvoker { + virtual void invoke () const = 0; + virtual ~ITestInvoker(); + }; + + class TestCase; + struct IConfig; + + struct ITestCaseRegistry { + virtual ~ITestCaseRegistry(); + virtual std::vector<TestCase> const& getAllTests() const = 0; + virtual std::vector<TestCase> const& getAllTestsSorted( IConfig const& config ) const = 0; + }; + + bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ); + std::vector<TestCase> filterTests( std::vector<TestCase> const& testCases, TestSpec const& testSpec, IConfig const& config ); + std::vector<TestCase> const& getAllTestCasesSorted( IConfig const& config ); + +} + +// end catch_interfaces_testcase.h +// start catch_stringref.h + +#include <cstddef> +#include <string> +#include <iosfwd> + +namespace Catch { + + /// A non-owning string class (similar to the forthcoming std::string_view) + /// Note that, because a StringRef may be a substring of another string, + /// it may not be null terminated. c_str() must return a null terminated + /// string, however, and so the StringRef will internally take ownership + /// (taking a copy), if necessary. In theory this ownership is not externally + /// visible - but it does mean (substring) StringRefs should not be shared between + /// threads. + class StringRef { + public: + using size_type = std::size_t; + + private: + friend struct StringRefTestAccess; + + char const* m_start; + size_type m_size; + + char* m_data = nullptr; + + void takeOwnership(); + + static constexpr char const* const s_empty = ""; + + public: // construction/ assignment + StringRef() noexcept + : StringRef( s_empty, 0 ) + {} + + StringRef( StringRef const& other ) noexcept + : m_start( other.m_start ), + m_size( other.m_size ) + {} + + StringRef( StringRef&& other ) noexcept + : m_start( other.m_start ), + m_size( other.m_size ), + m_data( other.m_data ) + { + other.m_data = nullptr; + } + + StringRef( char const* rawChars ) noexcept; + + StringRef( char const* rawChars, size_type size ) noexcept + : m_start( rawChars ), + m_size( size ) + {} + + StringRef( std::string const& stdString ) noexcept + : m_start( stdString.c_str() ), + m_size( stdString.size() ) + {} + + ~StringRef() noexcept { + delete[] m_data; + } + + auto operator = ( StringRef const &other ) noexcept -> StringRef& { + delete[] m_data; + m_data = nullptr; + m_start = other.m_start; + m_size = other.m_size; + return *this; + } + + operator std::string() const; + + void swap( StringRef& other ) noexcept; + + public: // operators + auto operator == ( StringRef const& other ) const noexcept -> bool; + auto operator != ( StringRef const& other ) const noexcept -> bool; + + auto operator[] ( size_type index ) const noexcept -> char; + + public: // named queries + auto empty() const noexcept -> bool { + return m_size == 0; + } + auto size() const noexcept -> size_type { + return m_size; + } + + auto numberOfCharacters() const noexcept -> size_type; + auto c_str() const -> char const*; + + public: // substrings and searches + auto substr( size_type start, size_type size ) const noexcept -> StringRef; + + // Returns the current start pointer. + // Note that the pointer can change when if the StringRef is a substring + auto currentData() const noexcept -> char const*; + + private: // ownership queries - may not be consistent between calls + auto isOwned() const noexcept -> bool; + auto isSubstring() const noexcept -> bool; + }; + + auto operator + ( StringRef const& lhs, StringRef const& rhs ) -> std::string; + auto operator + ( StringRef const& lhs, char const* rhs ) -> std::string; + auto operator + ( char const* lhs, StringRef const& rhs ) -> std::string; + + auto operator += ( std::string& lhs, StringRef const& sr ) -> std::string&; + auto operator << ( std::ostream& os, StringRef const& sr ) -> std::ostream&; + + inline auto operator "" _sr( char const* rawChars, std::size_t size ) noexcept -> StringRef { + return StringRef( rawChars, size ); + } + +} // namespace Catch + +inline auto operator "" _catch_sr( char const* rawChars, std::size_t size ) noexcept -> Catch::StringRef { + return Catch::StringRef( rawChars, size ); +} + +// end catch_stringref.h +// start catch_type_traits.hpp + + +#include <type_traits> + +namespace Catch{ + +#ifdef CATCH_CPP17_OR_GREATER + template <typename...> + inline constexpr auto is_unique = std::true_type{}; + + template <typename T, typename... Rest> + inline constexpr auto is_unique<T, Rest...> = std::bool_constant< + (!std::is_same_v<T, Rest> && ...) && is_unique<Rest...> + >{}; +#else + +template <typename...> +struct is_unique : std::true_type{}; + +template <typename T0, typename T1, typename... Rest> +struct is_unique<T0, T1, Rest...> : std::integral_constant +<bool, + !std::is_same<T0, T1>::value + && is_unique<T0, Rest...>::value + && is_unique<T1, Rest...>::value +>{}; + +#endif +} + +// end catch_type_traits.hpp +// start catch_preprocessor.hpp + + +#define CATCH_RECURSION_LEVEL0(...) __VA_ARGS__ +#define CATCH_RECURSION_LEVEL1(...) CATCH_RECURSION_LEVEL0(CATCH_RECURSION_LEVEL0(CATCH_RECURSION_LEVEL0(__VA_ARGS__))) +#define CATCH_RECURSION_LEVEL2(...) CATCH_RECURSION_LEVEL1(CATCH_RECURSION_LEVEL1(CATCH_RECURSION_LEVEL1(__VA_ARGS__))) +#define CATCH_RECURSION_LEVEL3(...) CATCH_RECURSION_LEVEL2(CATCH_RECURSION_LEVEL2(CATCH_RECURSION_LEVEL2(__VA_ARGS__))) +#define CATCH_RECURSION_LEVEL4(...) CATCH_RECURSION_LEVEL3(CATCH_RECURSION_LEVEL3(CATCH_RECURSION_LEVEL3(__VA_ARGS__))) +#define CATCH_RECURSION_LEVEL5(...) CATCH_RECURSION_LEVEL4(CATCH_RECURSION_LEVEL4(CATCH_RECURSION_LEVEL4(__VA_ARGS__))) + +#ifdef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define INTERNAL_CATCH_EXPAND_VARGS(...) __VA_ARGS__ +// MSVC needs more evaluations +#define CATCH_RECURSION_LEVEL6(...) CATCH_RECURSION_LEVEL5(CATCH_RECURSION_LEVEL5(CATCH_RECURSION_LEVEL5(__VA_ARGS__))) +#define CATCH_RECURSE(...) CATCH_RECURSION_LEVEL6(CATCH_RECURSION_LEVEL6(__VA_ARGS__)) +#else +#define CATCH_RECURSE(...) CATCH_RECURSION_LEVEL5(__VA_ARGS__) +#endif + +#define CATCH_REC_END(...) +#define CATCH_REC_OUT + +#define CATCH_EMPTY() +#define CATCH_DEFER(id) id CATCH_EMPTY() + +#define CATCH_REC_GET_END2() 0, CATCH_REC_END +#define CATCH_REC_GET_END1(...) CATCH_REC_GET_END2 +#define CATCH_REC_GET_END(...) CATCH_REC_GET_END1 +#define CATCH_REC_NEXT0(test, next, ...) next CATCH_REC_OUT +#define CATCH_REC_NEXT1(test, next) CATCH_DEFER ( CATCH_REC_NEXT0 ) ( test, next, 0) +#define CATCH_REC_NEXT(test, next) CATCH_REC_NEXT1(CATCH_REC_GET_END test, next) + +#define CATCH_REC_LIST0(f, x, peek, ...) , f(x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1) ) ( f, peek, __VA_ARGS__ ) +#define CATCH_REC_LIST1(f, x, peek, ...) , f(x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST0) ) ( f, peek, __VA_ARGS__ ) +#define CATCH_REC_LIST2(f, x, peek, ...) f(x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1) ) ( f, peek, __VA_ARGS__ ) + +#define CATCH_REC_LIST0_UD(f, userdata, x, peek, ...) , f(userdata, x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1_UD) ) ( f, userdata, peek, __VA_ARGS__ ) +#define CATCH_REC_LIST1_UD(f, userdata, x, peek, ...) , f(userdata, x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST0_UD) ) ( f, userdata, peek, __VA_ARGS__ ) +#define CATCH_REC_LIST2_UD(f, userdata, x, peek, ...) f(userdata, x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1_UD) ) ( f, userdata, peek, __VA_ARGS__ ) + +// Applies the function macro `f` to each of the remaining parameters, inserts commas between the results, +// and passes userdata as the first parameter to each invocation, +// e.g. CATCH_REC_LIST_UD(f, x, a, b, c) evaluates to f(x, a), f(x, b), f(x, c) +#define CATCH_REC_LIST_UD(f, userdata, ...) CATCH_RECURSE(CATCH_REC_LIST2_UD(f, userdata, __VA_ARGS__, ()()(), ()()(), ()()(), 0)) + +#define CATCH_REC_LIST(f, ...) CATCH_RECURSE(CATCH_REC_LIST2(f, __VA_ARGS__, ()()(), ()()(), ()()(), 0)) + +#define INTERNAL_CATCH_EXPAND1(param) INTERNAL_CATCH_EXPAND2(param) +#define INTERNAL_CATCH_EXPAND2(...) INTERNAL_CATCH_NO## __VA_ARGS__ +#define INTERNAL_CATCH_DEF(...) INTERNAL_CATCH_DEF __VA_ARGS__ +#define INTERNAL_CATCH_NOINTERNAL_CATCH_DEF +#define INTERNAL_CATCH_STRINGIZE(...) INTERNAL_CATCH_STRINGIZE2(__VA_ARGS__) +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define INTERNAL_CATCH_STRINGIZE2(...) #__VA_ARGS__ +#define INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS(param) INTERNAL_CATCH_STRINGIZE(INTERNAL_CATCH_REMOVE_PARENS(param)) +#else +// MSVC is adding extra space and needs another indirection to expand INTERNAL_CATCH_NOINTERNAL_CATCH_DEF +#define INTERNAL_CATCH_STRINGIZE2(...) INTERNAL_CATCH_STRINGIZE3(__VA_ARGS__) +#define INTERNAL_CATCH_STRINGIZE3(...) #__VA_ARGS__ +#define INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS(param) (INTERNAL_CATCH_STRINGIZE(INTERNAL_CATCH_REMOVE_PARENS(param)) + 1) +#endif + +#define INTERNAL_CATCH_REMOVE_PARENS(...) INTERNAL_CATCH_EXPAND1(INTERNAL_CATCH_DEF __VA_ARGS__) + +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME2(Name, ...) INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME3(Name, __VA_ARGS__) +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME3(Name,...) Name " - " #__VA_ARGS__ +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME(Name,...) INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME2(Name, INTERNAL_CATCH_REMOVE_PARENS(__VA_ARGS__)) +#else +// MSVC is adding extra space and needs more calls to properly remove () +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME3(Name,...) Name " -" #__VA_ARGS__ +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME1(Name, ...) INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME2(Name, __VA_ARGS__) +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME(Name, ...) INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME1(Name, INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_REMOVE_PARENS(__VA_ARGS__))) +#endif + +#define INTERNAL_CATCH_MAKE_TYPE_LIST(types) Catch::TypeList<INTERNAL_CATCH_REMOVE_PARENS(types)> + +#define INTERNAL_CATCH_MAKE_TYPE_LISTS_FROM_TYPES(types)\ + CATCH_REC_LIST(INTERNAL_CATCH_MAKE_TYPE_LIST,INTERNAL_CATCH_REMOVE_PARENS(types)) + +// end catch_preprocessor.hpp +// start catch_meta.hpp + + +#include <type_traits> + +namespace Catch { +template< typename... > +struct TypeList {}; + +template< typename... > +struct append; + +template< template<typename...> class L1 + , typename...E1 + , template<typename...> class L2 + , typename...E2 +> +struct append< L1<E1...>, L2<E2...> > { + using type = L1<E1..., E2...>; +}; + +template< template<typename...> class L1 + , typename...E1 + , template<typename...> class L2 + , typename...E2 + , typename...Rest +> +struct append< L1<E1...>, L2<E2...>, Rest...> { + using type = typename append< L1<E1..., E2...>, Rest... >::type; +}; + +template< template<typename...> class + , typename... +> +struct rewrap; + +template< template<typename...> class Container + , template<typename...> class List + , typename...elems +> +struct rewrap<Container, List<elems...>> { + using type = TypeList< Container< elems... > >; +}; + +template< template<typename...> class Container + , template<typename...> class List + , class...Elems + , typename...Elements> + struct rewrap<Container, List<Elems...>, Elements...> { + using type = typename append<TypeList<Container<Elems...>>, typename rewrap<Container, Elements...>::type>::type; +}; + +template< template<typename...> class...Containers > +struct combine { + template< typename...Types > + struct with_types { + template< template <typename...> class Final > + struct into { + using type = typename append<Final<>, typename rewrap<Containers, Types...>::type...>::type; + }; + }; +}; + +template<typename T> +struct always_false : std::false_type {}; + +} // namespace Catch + +// end catch_meta.hpp +namespace Catch { + +template<typename C> +class TestInvokerAsMethod : public ITestInvoker { + void (C::*m_testAsMethod)(); +public: + TestInvokerAsMethod( void (C::*testAsMethod)() ) noexcept : m_testAsMethod( testAsMethod ) {} + + void invoke() const override { + C obj; + (obj.*m_testAsMethod)(); + } +}; + +auto makeTestInvoker( void(*testAsFunction)() ) noexcept -> ITestInvoker*; + +template<typename C> +auto makeTestInvoker( void (C::*testAsMethod)() ) noexcept -> ITestInvoker* { + return new(std::nothrow) TestInvokerAsMethod<C>( testAsMethod ); +} + +struct NameAndTags { + NameAndTags( StringRef const& name_ = StringRef(), StringRef const& tags_ = StringRef() ) noexcept; + StringRef name; + StringRef tags; +}; + +struct AutoReg : NonCopyable { + AutoReg( ITestInvoker* invoker, SourceLineInfo const& lineInfo, StringRef const& classOrMethod, NameAndTags const& nameAndTags ) noexcept; + ~AutoReg(); +}; + +} // end namespace Catch + +#if defined(CATCH_CONFIG_DISABLE) + #define INTERNAL_CATCH_TESTCASE_NO_REGISTRATION( TestName, ... ) \ + static void TestName() + #define INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION( TestName, ClassName, ... ) \ + namespace{ \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName) { \ + void test(); \ + }; \ + } \ + void TestName::test() + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION( TestName, ... ) \ + template<typename TestType> \ + static void TestName() + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION( TestName, ClassName, ... ) \ + namespace{ \ + template<typename TestType> \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName <TestType>) { \ + void test(); \ + }; \ + } \ + template<typename TestType> \ + void TestName::test() +#endif + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TESTCASE2( TestName, ... ) \ + static void TestName(); \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( Catch::makeTestInvoker( &TestName ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ __VA_ARGS__ } ); } /* NOLINT */ \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + static void TestName() + #define INTERNAL_CATCH_TESTCASE( ... ) \ + INTERNAL_CATCH_TESTCASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), __VA_ARGS__ ) + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_METHOD_AS_TEST_CASE( QualifiedMethod, ... ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( Catch::makeTestInvoker( &QualifiedMethod ), CATCH_INTERNAL_LINEINFO, "&" #QualifiedMethod, Catch::NameAndTags{ __VA_ARGS__ } ); } /* NOLINT */ \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TEST_CASE_METHOD2( TestName, ClassName, ... )\ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName) { \ + void test(); \ + }; \ + Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar ) ( Catch::makeTestInvoker( &TestName::test ), CATCH_INTERNAL_LINEINFO, #ClassName, Catch::NameAndTags{ __VA_ARGS__ } ); /* NOLINT */ \ + } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + void TestName::test() + #define INTERNAL_CATCH_TEST_CASE_METHOD( ClassName, ... ) \ + INTERNAL_CATCH_TEST_CASE_METHOD2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), ClassName, __VA_ARGS__ ) + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_REGISTER_TESTCASE( Function, ... ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( Catch::makeTestInvoker( Function ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ __VA_ARGS__ } ); /* NOLINT */ \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_2(TestName, TestFunc, Name, Tags, ... )\ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + template<typename TestType> \ + static void TestFunc();\ + namespace {\ + template<typename...Types> \ + struct TestName{\ + template<typename...Ts> \ + TestName(Ts...names){\ + CATCH_INTERNAL_CHECK_UNIQUE_TYPES(CATCH_REC_LIST(INTERNAL_CATCH_REMOVE_PARENS, __VA_ARGS__)) \ + using expander = int[];\ + (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestFunc<Types> ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ names, Tags } ), 0)... };/* NOLINT */ \ + }\ + };\ + INTERNAL_CATCH_TEMPLATE_REGISTRY_INITIATE(TestName, Name, __VA_ARGS__) \ + }\ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + template<typename TestType> \ + static void TestFunc() + +#if defined(CATCH_CPP17_OR_GREATER) +#define CATCH_INTERNAL_CHECK_UNIQUE_TYPES(...) static_assert(Catch::is_unique<__VA_ARGS__>,"Duplicate type detected in declaration of template test case"); +#else +#define CATCH_INTERNAL_CHECK_UNIQUE_TYPES(...) static_assert(Catch::is_unique<__VA_ARGS__>::value,"Duplicate type detected in declaration of template test case"); +#endif + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE(Name, Tags, ...) \ + INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, __VA_ARGS__ ) +#else + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE(Name, Tags, ...) \ + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, __VA_ARGS__ ) ) +#endif + + #define INTERNAL_CATCH_TEMPLATE_REGISTRY_INITIATE(TestName, Name, ...)\ + static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){\ + TestName<CATCH_REC_LIST(INTERNAL_CATCH_REMOVE_PARENS, __VA_ARGS__)>(CATCH_REC_LIST_UD(INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME,Name, __VA_ARGS__));\ + return 0;\ + }(); + + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(TestName, TestFuncName, Name, Tags, TmplTypes, TypesList) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + template<typename TestType> static void TestFuncName(); \ + namespace { \ + template<typename... Types> \ + struct TestName { \ + TestName() { \ + CATCH_INTERNAL_CHECK_UNIQUE_TYPES(Types...) \ + int index = 0; \ + using expander = int[]; \ + constexpr char const* tmpl_types[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TmplTypes))};\ + constexpr char const* types_list[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TypesList))};\ + constexpr auto num_types = sizeof(types_list) / sizeof(types_list[0]);\ + (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestFuncName<Types> ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ Name " - " + std::string(tmpl_types[index / num_types]) + "<" + std::string(types_list[index % num_types]) + ">", Tags } ), index++, 0)... };/* NOLINT */\ + } \ + }; \ + static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){ \ + using TestInit = Catch::combine<INTERNAL_CATCH_REMOVE_PARENS(TmplTypes)> \ + ::with_types<INTERNAL_CATCH_MAKE_TYPE_LISTS_FROM_TYPES(TypesList)>::into<TestName>::type; \ + TestInit(); \ + return 0; \ + }(); \ + } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + template<typename TestType> \ + static void TestFuncName() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE(Name, Tags, ...)\ + INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ),Name,Tags,__VA_ARGS__) +#else + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE(Name, Tags, ...)\ + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, __VA_ARGS__ ) ) +#endif + + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( TestNameClass, TestName, ClassName, Name, Tags, ... ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ \ + template<typename TestType> \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName <TestType>) { \ + void test();\ + };\ + template<typename...Types> \ + struct TestNameClass{\ + template<typename...Ts> \ + TestNameClass(Ts...names){\ + CATCH_INTERNAL_CHECK_UNIQUE_TYPES(CATCH_REC_LIST(INTERNAL_CATCH_REMOVE_PARENS, __VA_ARGS__)) \ + using expander = int[];\ + (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestName<Types>::test ), CATCH_INTERNAL_LINEINFO, #ClassName, Catch::NameAndTags{ names, Tags } ), 0)... };/* NOLINT */ \ + }\ + };\ + INTERNAL_CATCH_TEMPLATE_REGISTRY_INITIATE(TestNameClass, Name, __VA_ARGS__)\ + }\ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS\ + template<typename TestType> \ + void TestName<TestType>::test() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( ClassName, Name, Tags,... ) \ + INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, __VA_ARGS__ ) +#else + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( ClassName, Name, Tags,... ) \ + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, __VA_ARGS__ ) ) +#endif + + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2(TestNameClass, TestName, ClassName, Name, Tags, TmplTypes, TypesList)\ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + template<typename TestType> \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName <TestType>) { \ + void test();\ + };\ + namespace {\ + template<typename...Types>\ + struct TestNameClass{\ + TestNameClass(){\ + CATCH_INTERNAL_CHECK_UNIQUE_TYPES(Types...)\ + int index = 0;\ + using expander = int[];\ + constexpr char const* tmpl_types[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TmplTypes))};\ + constexpr char const* types_list[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TypesList))};\ + constexpr auto num_types = sizeof(types_list) / sizeof(types_list[0]);\ + (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestName<Types>::test ), CATCH_INTERNAL_LINEINFO, #ClassName, Catch::NameAndTags{ Name " - " + std::string(tmpl_types[index / num_types]) + "<" + std::string(types_list[index % num_types]) + ">", Tags } ), index++, 0)... };/* NOLINT */ \ + }\ + };\ + static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){\ + using TestInit = Catch::combine<INTERNAL_CATCH_REMOVE_PARENS(TmplTypes)>\ + ::with_types<INTERNAL_CATCH_MAKE_TYPE_LISTS_FROM_TYPES(TypesList)>::into<TestNameClass>::type;\ + TestInit();\ + return 0;\ + }(); \ + }\ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + template<typename TestType> \ + void TestName<TestType>::test() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( ClassName, Name, Tags, ... )\ + INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, __VA_ARGS__ ) +#else + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( ClassName, Name, Tags, ... )\ + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, __VA_ARGS__ ) ) +#endif + +// end catch_test_registry.h +// start catch_capture.hpp + +// start catch_assertionhandler.h + +// start catch_assertioninfo.h + +// start catch_result_type.h + +namespace Catch { + + // ResultWas::OfType enum + struct ResultWas { enum OfType { + Unknown = -1, + Ok = 0, + Info = 1, + Warning = 2, + + FailureBit = 0x10, + + ExpressionFailed = FailureBit | 1, + ExplicitFailure = FailureBit | 2, + + Exception = 0x100 | FailureBit, + + ThrewException = Exception | 1, + DidntThrowException = Exception | 2, + + FatalErrorCondition = 0x200 | FailureBit + + }; }; + + bool isOk( ResultWas::OfType resultType ); + bool isJustInfo( int flags ); + + // ResultDisposition::Flags enum + struct ResultDisposition { enum Flags { + Normal = 0x01, + + ContinueOnFailure = 0x02, // Failures fail test, but execution continues + FalseTest = 0x04, // Prefix expression with ! + SuppressFail = 0x08 // Failures are reported but do not fail the test + }; }; + + ResultDisposition::Flags operator | ( ResultDisposition::Flags lhs, ResultDisposition::Flags rhs ); + + bool shouldContinueOnFailure( int flags ); + inline bool isFalseTest( int flags ) { return ( flags & ResultDisposition::FalseTest ) != 0; } + bool shouldSuppressFailure( int flags ); + +} // end namespace Catch + +// end catch_result_type.h +namespace Catch { + + struct AssertionInfo + { + StringRef macroName; + SourceLineInfo lineInfo; + StringRef capturedExpression; + ResultDisposition::Flags resultDisposition; + + // We want to delete this constructor but a compiler bug in 4.8 means + // the struct is then treated as non-aggregate + //AssertionInfo() = delete; + }; + +} // end namespace Catch + +// end catch_assertioninfo.h +// start catch_decomposer.h + +// start catch_tostring.h + +#include <vector> +#include <cstddef> +#include <type_traits> +#include <string> +// start catch_stream.h + +#include <iosfwd> +#include <cstddef> +#include <ostream> + +namespace Catch { + + std::ostream& cout(); + std::ostream& cerr(); + std::ostream& clog(); + + class StringRef; + + struct IStream { + virtual ~IStream(); + virtual std::ostream& stream() const = 0; + }; + + auto makeStream( StringRef const &filename ) -> IStream const*; + + class ReusableStringStream { + std::size_t m_index; + std::ostream* m_oss; + public: + ReusableStringStream(); + ~ReusableStringStream(); + + auto str() const -> std::string; + + template<typename T> + auto operator << ( T const& value ) -> ReusableStringStream& { + *m_oss << value; + return *this; + } + auto get() -> std::ostream& { return *m_oss; } + }; +} + +// end catch_stream.h + +#ifdef CATCH_CONFIG_CPP17_STRING_VIEW +#include <string_view> +#endif + +#ifdef __OBJC__ +// start catch_objc_arc.hpp + +#import <Foundation/Foundation.h> + +#ifdef __has_feature +#define CATCH_ARC_ENABLED __has_feature(objc_arc) +#else +#define CATCH_ARC_ENABLED 0 +#endif + +void arcSafeRelease( NSObject* obj ); +id performOptionalSelector( id obj, SEL sel ); + +#if !CATCH_ARC_ENABLED +inline void arcSafeRelease( NSObject* obj ) { + [obj release]; +} +inline id performOptionalSelector( id obj, SEL sel ) { + if( [obj respondsToSelector: sel] ) + return [obj performSelector: sel]; + return nil; +} +#define CATCH_UNSAFE_UNRETAINED +#define CATCH_ARC_STRONG +#else +inline void arcSafeRelease( NSObject* ){} +inline id performOptionalSelector( id obj, SEL sel ) { +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Warc-performSelector-leaks" +#endif + if( [obj respondsToSelector: sel] ) + return [obj performSelector: sel]; +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + return nil; +} +#define CATCH_UNSAFE_UNRETAINED __unsafe_unretained +#define CATCH_ARC_STRONG __strong +#endif + +// end catch_objc_arc.hpp +#endif + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable:4180) // We attempt to stream a function (address) by const&, which MSVC complains about but is harmless +#endif + +namespace Catch { + namespace Detail { + + extern const std::string unprintableString; + + std::string rawMemoryToString( const void *object, std::size_t size ); + + template<typename T> + std::string rawMemoryToString( const T& object ) { + return rawMemoryToString( &object, sizeof(object) ); + } + + template<typename T> + class IsStreamInsertable { + template<typename SS, typename TT> + static auto test(int) + -> decltype(std::declval<SS&>() << std::declval<TT>(), std::true_type()); + + template<typename, typename> + static auto test(...)->std::false_type; + + public: + static const bool value = decltype(test<std::ostream, const T&>(0))::value; + }; + + template<typename E> + std::string convertUnknownEnumToString( E e ); + + template<typename T> + typename std::enable_if< + !std::is_enum<T>::value && !std::is_base_of<std::exception, T>::value, + std::string>::type convertUnstreamable( T const& ) { + return Detail::unprintableString; + } + template<typename T> + typename std::enable_if< + !std::is_enum<T>::value && std::is_base_of<std::exception, T>::value, + std::string>::type convertUnstreamable(T const& ex) { + return ex.what(); + } + + template<typename T> + typename std::enable_if< + std::is_enum<T>::value + , std::string>::type convertUnstreamable( T const& value ) { + return convertUnknownEnumToString( value ); + } + +#if defined(_MANAGED) + //! Convert a CLR string to a utf8 std::string + template<typename T> + std::string clrReferenceToString( T^ ref ) { + if (ref == nullptr) + return std::string("null"); + auto bytes = System::Text::Encoding::UTF8->GetBytes(ref->ToString()); + cli::pin_ptr<System::Byte> p = &bytes[0]; + return std::string(reinterpret_cast<char const *>(p), bytes->Length); + } +#endif + + } // namespace Detail + + // If we decide for C++14, change these to enable_if_ts + template <typename T, typename = void> + struct StringMaker { + template <typename Fake = T> + static + typename std::enable_if<::Catch::Detail::IsStreamInsertable<Fake>::value, std::string>::type + convert(const Fake& value) { + ReusableStringStream rss; + // NB: call using the function-like syntax to avoid ambiguity with + // user-defined templated operator<< under clang. + rss.operator<<(value); + return rss.str(); + } + + template <typename Fake = T> + static + typename std::enable_if<!::Catch::Detail::IsStreamInsertable<Fake>::value, std::string>::type + convert( const Fake& value ) { +#if !defined(CATCH_CONFIG_FALLBACK_STRINGIFIER) + return Detail::convertUnstreamable(value); +#else + return CATCH_CONFIG_FALLBACK_STRINGIFIER(value); +#endif + } + }; + + namespace Detail { + + // This function dispatches all stringification requests inside of Catch. + // Should be preferably called fully qualified, like ::Catch::Detail::stringify + template <typename T> + std::string stringify(const T& e) { + return ::Catch::StringMaker<typename std::remove_cv<typename std::remove_reference<T>::type>::type>::convert(e); + } + + template<typename E> + std::string convertUnknownEnumToString( E e ) { + return ::Catch::Detail::stringify(static_cast<typename std::underlying_type<E>::type>(e)); + } + +#if defined(_MANAGED) + template <typename T> + std::string stringify( T^ e ) { + return ::Catch::StringMaker<T^>::convert(e); + } +#endif + + } // namespace Detail + + // Some predefined specializations + + template<> + struct StringMaker<std::string> { + static std::string convert(const std::string& str); + }; + +#ifdef CATCH_CONFIG_CPP17_STRING_VIEW + template<> + struct StringMaker<std::string_view> { + static std::string convert(std::string_view str); + }; +#endif + + template<> + struct StringMaker<char const *> { + static std::string convert(char const * str); + }; + template<> + struct StringMaker<char *> { + static std::string convert(char * str); + }; + +#ifdef CATCH_CONFIG_WCHAR + template<> + struct StringMaker<std::wstring> { + static std::string convert(const std::wstring& wstr); + }; + +# ifdef CATCH_CONFIG_CPP17_STRING_VIEW + template<> + struct StringMaker<std::wstring_view> { + static std::string convert(std::wstring_view str); + }; +# endif + + template<> + struct StringMaker<wchar_t const *> { + static std::string convert(wchar_t const * str); + }; + template<> + struct StringMaker<wchar_t *> { + static std::string convert(wchar_t * str); + }; +#endif + + // TBD: Should we use `strnlen` to ensure that we don't go out of the buffer, + // while keeping string semantics? + template<int SZ> + struct StringMaker<char[SZ]> { + static std::string convert(char const* str) { + return ::Catch::Detail::stringify(std::string{ str }); + } + }; + template<int SZ> + struct StringMaker<signed char[SZ]> { + static std::string convert(signed char const* str) { + return ::Catch::Detail::stringify(std::string{ reinterpret_cast<char const *>(str) }); + } + }; + template<int SZ> + struct StringMaker<unsigned char[SZ]> { + static std::string convert(unsigned char const* str) { + return ::Catch::Detail::stringify(std::string{ reinterpret_cast<char const *>(str) }); + } + }; + + template<> + struct StringMaker<int> { + static std::string convert(int value); + }; + template<> + struct StringMaker<long> { + static std::string convert(long value); + }; + template<> + struct StringMaker<long long> { + static std::string convert(long long value); + }; + template<> + struct StringMaker<unsigned int> { + static std::string convert(unsigned int value); + }; + template<> + struct StringMaker<unsigned long> { + static std::string convert(unsigned long value); + }; + template<> + struct StringMaker<unsigned long long> { + static std::string convert(unsigned long long value); + }; + + template<> + struct StringMaker<bool> { + static std::string convert(bool b); + }; + + template<> + struct StringMaker<char> { + static std::string convert(char c); + }; + template<> + struct StringMaker<signed char> { + static std::string convert(signed char c); + }; + template<> + struct StringMaker<unsigned char> { + static std::string convert(unsigned char c); + }; + + template<> + struct StringMaker<std::nullptr_t> { + static std::string convert(std::nullptr_t); + }; + + template<> + struct StringMaker<float> { + static std::string convert(float value); + }; + template<> + struct StringMaker<double> { + static std::string convert(double value); + }; + + template <typename T> + struct StringMaker<T*> { + template <typename U> + static std::string convert(U* p) { + if (p) { + return ::Catch::Detail::rawMemoryToString(p); + } else { + return "nullptr"; + } + } + }; + + template <typename R, typename C> + struct StringMaker<R C::*> { + static std::string convert(R C::* p) { + if (p) { + return ::Catch::Detail::rawMemoryToString(p); + } else { + return "nullptr"; + } + } + }; + +#if defined(_MANAGED) + template <typename T> + struct StringMaker<T^> { + static std::string convert( T^ ref ) { + return ::Catch::Detail::clrReferenceToString(ref); + } + }; +#endif + + namespace Detail { + template<typename InputIterator> + std::string rangeToString(InputIterator first, InputIterator last) { + ReusableStringStream rss; + rss << "{ "; + if (first != last) { + rss << ::Catch::Detail::stringify(*first); + for (++first; first != last; ++first) + rss << ", " << ::Catch::Detail::stringify(*first); + } + rss << " }"; + return rss.str(); + } + } + +#ifdef __OBJC__ + template<> + struct StringMaker<NSString*> { + static std::string convert(NSString * nsstring) { + if (!nsstring) + return "nil"; + return std::string("@") + [nsstring UTF8String]; + } + }; + template<> + struct StringMaker<NSObject*> { + static std::string convert(NSObject* nsObject) { + return ::Catch::Detail::stringify([nsObject description]); + } + + }; + namespace Detail { + inline std::string stringify( NSString* nsstring ) { + return StringMaker<NSString*>::convert( nsstring ); + } + + } // namespace Detail +#endif // __OBJC__ + +} // namespace Catch + +////////////////////////////////////////////////////// +// Separate std-lib types stringification, so it can be selectively enabled +// This means that we do not bring in + +#if defined(CATCH_CONFIG_ENABLE_ALL_STRINGMAKERS) +# define CATCH_CONFIG_ENABLE_PAIR_STRINGMAKER +# define CATCH_CONFIG_ENABLE_TUPLE_STRINGMAKER +# define CATCH_CONFIG_ENABLE_VARIANT_STRINGMAKER +# define CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER +# define CATCH_CONFIG_ENABLE_OPTIONAL_STRINGMAKER +#endif + +// Separate std::pair specialization +#if defined(CATCH_CONFIG_ENABLE_PAIR_STRINGMAKER) +#include <utility> +namespace Catch { + template<typename T1, typename T2> + struct StringMaker<std::pair<T1, T2> > { + static std::string convert(const std::pair<T1, T2>& pair) { + ReusableStringStream rss; + rss << "{ " + << ::Catch::Detail::stringify(pair.first) + << ", " + << ::Catch::Detail::stringify(pair.second) + << " }"; + return rss.str(); + } + }; +} +#endif // CATCH_CONFIG_ENABLE_PAIR_STRINGMAKER + +#if defined(CATCH_CONFIG_ENABLE_OPTIONAL_STRINGMAKER) && defined(CATCH_CONFIG_CPP17_OPTIONAL) +#include <optional> +namespace Catch { + template<typename T> + struct StringMaker<std::optional<T> > { + static std::string convert(const std::optional<T>& optional) { + ReusableStringStream rss; + if (optional.has_value()) { + rss << ::Catch::Detail::stringify(*optional); + } else { + rss << "{ }"; + } + return rss.str(); + } + }; +} +#endif // CATCH_CONFIG_ENABLE_OPTIONAL_STRINGMAKER + +// Separate std::tuple specialization +#if defined(CATCH_CONFIG_ENABLE_TUPLE_STRINGMAKER) +#include <tuple> +namespace Catch { + namespace Detail { + template< + typename Tuple, + std::size_t N = 0, + bool = (N < std::tuple_size<Tuple>::value) + > + struct TupleElementPrinter { + static void print(const Tuple& tuple, std::ostream& os) { + os << (N ? ", " : " ") + << ::Catch::Detail::stringify(std::get<N>(tuple)); + TupleElementPrinter<Tuple, N + 1>::print(tuple, os); + } + }; + + template< + typename Tuple, + std::size_t N + > + struct TupleElementPrinter<Tuple, N, false> { + static void print(const Tuple&, std::ostream&) {} + }; + + } + + template<typename ...Types> + struct StringMaker<std::tuple<Types...>> { + static std::string convert(const std::tuple<Types...>& tuple) { + ReusableStringStream rss; + rss << '{'; + Detail::TupleElementPrinter<std::tuple<Types...>>::print(tuple, rss.get()); + rss << " }"; + return rss.str(); + } + }; +} +#endif // CATCH_CONFIG_ENABLE_TUPLE_STRINGMAKER + +#if defined(CATCH_CONFIG_ENABLE_VARIANT_STRINGMAKER) && defined(CATCH_CONFIG_CPP17_VARIANT) +#include <variant> +namespace Catch { + template<> + struct StringMaker<std::monostate> { + static std::string convert(const std::monostate&) { + return "{ }"; + } + }; + + template<typename... Elements> + struct StringMaker<std::variant<Elements...>> { + static std::string convert(const std::variant<Elements...>& variant) { + if (variant.valueless_by_exception()) { + return "{valueless variant}"; + } else { + return std::visit( + [](const auto& value) { + return ::Catch::Detail::stringify(value); + }, + variant + ); + } + } + }; +} +#endif // CATCH_CONFIG_ENABLE_VARIANT_STRINGMAKER + +namespace Catch { + struct not_this_one {}; // Tag type for detecting which begin/ end are being selected + + // Import begin/ end from std here so they are considered alongside the fallback (...) overloads in this namespace + using std::begin; + using std::end; + + not_this_one begin( ... ); + not_this_one end( ... ); + + template <typename T> + struct is_range { + static const bool value = + !std::is_same<decltype(begin(std::declval<T>())), not_this_one>::value && + !std::is_same<decltype(end(std::declval<T>())), not_this_one>::value; + }; + +#if defined(_MANAGED) // Managed types are never ranges + template <typename T> + struct is_range<T^> { + static const bool value = false; + }; +#endif + + template<typename Range> + std::string rangeToString( Range const& range ) { + return ::Catch::Detail::rangeToString( begin( range ), end( range ) ); + } + + // Handle vector<bool> specially + template<typename Allocator> + std::string rangeToString( std::vector<bool, Allocator> const& v ) { + ReusableStringStream rss; + rss << "{ "; + bool first = true; + for( bool b : v ) { + if( first ) + first = false; + else + rss << ", "; + rss << ::Catch::Detail::stringify( b ); + } + rss << " }"; + return rss.str(); + } + + template<typename R> + struct StringMaker<R, typename std::enable_if<is_range<R>::value && !::Catch::Detail::IsStreamInsertable<R>::value>::type> { + static std::string convert( R const& range ) { + return rangeToString( range ); + } + }; + + template <typename T, int SZ> + struct StringMaker<T[SZ]> { + static std::string convert(T const(&arr)[SZ]) { + return rangeToString(arr); + } + }; + +} // namespace Catch + +// Separate std::chrono::duration specialization +#if defined(CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER) +#include <ctime> +#include <ratio> +#include <chrono> + +namespace Catch { + +template <class Ratio> +struct ratio_string { + static std::string symbol(); +}; + +template <class Ratio> +std::string ratio_string<Ratio>::symbol() { + Catch::ReusableStringStream rss; + rss << '[' << Ratio::num << '/' + << Ratio::den << ']'; + return rss.str(); +} +template <> +struct ratio_string<std::atto> { + static std::string symbol(); +}; +template <> +struct ratio_string<std::femto> { + static std::string symbol(); +}; +template <> +struct ratio_string<std::pico> { + static std::string symbol(); +}; +template <> +struct ratio_string<std::nano> { + static std::string symbol(); +}; +template <> +struct ratio_string<std::micro> { + static std::string symbol(); +}; +template <> +struct ratio_string<std::milli> { + static std::string symbol(); +}; + + //////////// + // std::chrono::duration specializations + template<typename Value, typename Ratio> + struct StringMaker<std::chrono::duration<Value, Ratio>> { + static std::string convert(std::chrono::duration<Value, Ratio> const& duration) { + ReusableStringStream rss; + rss << duration.count() << ' ' << ratio_string<Ratio>::symbol() << 's'; + return rss.str(); + } + }; + template<typename Value> + struct StringMaker<std::chrono::duration<Value, std::ratio<1>>> { + static std::string convert(std::chrono::duration<Value, std::ratio<1>> const& duration) { + ReusableStringStream rss; + rss << duration.count() << " s"; + return rss.str(); + } + }; + template<typename Value> + struct StringMaker<std::chrono::duration<Value, std::ratio<60>>> { + static std::string convert(std::chrono::duration<Value, std::ratio<60>> const& duration) { + ReusableStringStream rss; + rss << duration.count() << " m"; + return rss.str(); + } + }; + template<typename Value> + struct StringMaker<std::chrono::duration<Value, std::ratio<3600>>> { + static std::string convert(std::chrono::duration<Value, std::ratio<3600>> const& duration) { + ReusableStringStream rss; + rss << duration.count() << " h"; + return rss.str(); + } + }; + + //////////// + // std::chrono::time_point specialization + // Generic time_point cannot be specialized, only std::chrono::time_point<system_clock> + template<typename Clock, typename Duration> + struct StringMaker<std::chrono::time_point<Clock, Duration>> { + static std::string convert(std::chrono::time_point<Clock, Duration> const& time_point) { + return ::Catch::Detail::stringify(time_point.time_since_epoch()) + " since epoch"; + } + }; + // std::chrono::time_point<system_clock> specialization + template<typename Duration> + struct StringMaker<std::chrono::time_point<std::chrono::system_clock, Duration>> { + static std::string convert(std::chrono::time_point<std::chrono::system_clock, Duration> const& time_point) { + auto converted = std::chrono::system_clock::to_time_t(time_point); + +#ifdef _MSC_VER + std::tm timeInfo = {}; + gmtime_s(&timeInfo, &converted); +#else + std::tm* timeInfo = std::gmtime(&converted); +#endif + + auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); + char timeStamp[timeStampSize]; + const char * const fmt = "%Y-%m-%dT%H:%M:%SZ"; + +#ifdef _MSC_VER + std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); +#else + std::strftime(timeStamp, timeStampSize, fmt, timeInfo); +#endif + return std::string(timeStamp); + } + }; +} +#endif // CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +// end catch_tostring.h +#include <iosfwd> + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable:4389) // '==' : signed/unsigned mismatch +#pragma warning(disable:4018) // more "signed/unsigned mismatch" +#pragma warning(disable:4312) // Converting int to T* using reinterpret_cast (issue on x64 platform) +#pragma warning(disable:4180) // qualifier applied to function type has no meaning +#pragma warning(disable:4800) // Forcing result to true or false +#endif + +namespace Catch { + + struct ITransientExpression { + auto isBinaryExpression() const -> bool { return m_isBinaryExpression; } + auto getResult() const -> bool { return m_result; } + virtual void streamReconstructedExpression( std::ostream &os ) const = 0; + + ITransientExpression( bool isBinaryExpression, bool result ) + : m_isBinaryExpression( isBinaryExpression ), + m_result( result ) + {} + + // We don't actually need a virtual destructor, but many static analysers + // complain if it's not here :-( + virtual ~ITransientExpression(); + + bool m_isBinaryExpression; + bool m_result; + + }; + + void formatReconstructedExpression( std::ostream &os, std::string const& lhs, StringRef op, std::string const& rhs ); + + template<typename LhsT, typename RhsT> + class BinaryExpr : public ITransientExpression { + LhsT m_lhs; + StringRef m_op; + RhsT m_rhs; + + void streamReconstructedExpression( std::ostream &os ) const override { + formatReconstructedExpression + ( os, Catch::Detail::stringify( m_lhs ), m_op, Catch::Detail::stringify( m_rhs ) ); + } + + public: + BinaryExpr( bool comparisonResult, LhsT lhs, StringRef op, RhsT rhs ) + : ITransientExpression{ true, comparisonResult }, + m_lhs( lhs ), + m_op( op ), + m_rhs( rhs ) + {} + + template<typename T> + auto operator && ( T ) const -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<T>::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template<typename T> + auto operator || ( T ) const -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<T>::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template<typename T> + auto operator == ( T ) const -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<T>::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template<typename T> + auto operator != ( T ) const -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<T>::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template<typename T> + auto operator > ( T ) const -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<T>::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template<typename T> + auto operator < ( T ) const -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<T>::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template<typename T> + auto operator >= ( T ) const -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<T>::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template<typename T> + auto operator <= ( T ) const -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<T>::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + }; + + template<typename LhsT> + class UnaryExpr : public ITransientExpression { + LhsT m_lhs; + + void streamReconstructedExpression( std::ostream &os ) const override { + os << Catch::Detail::stringify( m_lhs ); + } + + public: + explicit UnaryExpr( LhsT lhs ) + : ITransientExpression{ false, static_cast<bool>(lhs) }, + m_lhs( lhs ) + {} + }; + + // Specialised comparison functions to handle equality comparisons between ints and pointers (NULL deduces as an int) + template<typename LhsT, typename RhsT> + auto compareEqual( LhsT const& lhs, RhsT const& rhs ) -> bool { return static_cast<bool>(lhs == rhs); } + template<typename T> + auto compareEqual( T* const& lhs, int rhs ) -> bool { return lhs == reinterpret_cast<void const*>( rhs ); } + template<typename T> + auto compareEqual( T* const& lhs, long rhs ) -> bool { return lhs == reinterpret_cast<void const*>( rhs ); } + template<typename T> + auto compareEqual( int lhs, T* const& rhs ) -> bool { return reinterpret_cast<void const*>( lhs ) == rhs; } + template<typename T> + auto compareEqual( long lhs, T* const& rhs ) -> bool { return reinterpret_cast<void const*>( lhs ) == rhs; } + + template<typename LhsT, typename RhsT> + auto compareNotEqual( LhsT const& lhs, RhsT&& rhs ) -> bool { return static_cast<bool>(lhs != rhs); } + template<typename T> + auto compareNotEqual( T* const& lhs, int rhs ) -> bool { return lhs != reinterpret_cast<void const*>( rhs ); } + template<typename T> + auto compareNotEqual( T* const& lhs, long rhs ) -> bool { return lhs != reinterpret_cast<void const*>( rhs ); } + template<typename T> + auto compareNotEqual( int lhs, T* const& rhs ) -> bool { return reinterpret_cast<void const*>( lhs ) != rhs; } + template<typename T> + auto compareNotEqual( long lhs, T* const& rhs ) -> bool { return reinterpret_cast<void const*>( lhs ) != rhs; } + + template<typename LhsT> + class ExprLhs { + LhsT m_lhs; + public: + explicit ExprLhs( LhsT lhs ) : m_lhs( lhs ) {} + + template<typename RhsT> + auto operator == ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const { + return { compareEqual( m_lhs, rhs ), m_lhs, "==", rhs }; + } + auto operator == ( bool rhs ) -> BinaryExpr<LhsT, bool> const { + return { m_lhs == rhs, m_lhs, "==", rhs }; + } + + template<typename RhsT> + auto operator != ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const { + return { compareNotEqual( m_lhs, rhs ), m_lhs, "!=", rhs }; + } + auto operator != ( bool rhs ) -> BinaryExpr<LhsT, bool> const { + return { m_lhs != rhs, m_lhs, "!=", rhs }; + } + + template<typename RhsT> + auto operator > ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const { + return { static_cast<bool>(m_lhs > rhs), m_lhs, ">", rhs }; + } + template<typename RhsT> + auto operator < ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const { + return { static_cast<bool>(m_lhs < rhs), m_lhs, "<", rhs }; + } + template<typename RhsT> + auto operator >= ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const { + return { static_cast<bool>(m_lhs >= rhs), m_lhs, ">=", rhs }; + } + template<typename RhsT> + auto operator <= ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const { + return { static_cast<bool>(m_lhs <= rhs), m_lhs, "<=", rhs }; + } + + template<typename RhsT> + auto operator && ( RhsT const& ) -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<RhsT>::value, + "operator&& is not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template<typename RhsT> + auto operator || ( RhsT const& ) -> BinaryExpr<LhsT, RhsT const&> const { + static_assert(always_false<RhsT>::value, + "operator|| is not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + auto makeUnaryExpr() const -> UnaryExpr<LhsT> { + return UnaryExpr<LhsT>{ m_lhs }; + } + }; + + void handleExpression( ITransientExpression const& expr ); + + template<typename T> + void handleExpression( ExprLhs<T> const& expr ) { + handleExpression( expr.makeUnaryExpr() ); + } + + struct Decomposer { + template<typename T> + auto operator <= ( T const& lhs ) -> ExprLhs<T const&> { + return ExprLhs<T const&>{ lhs }; + } + + auto operator <=( bool value ) -> ExprLhs<bool> { + return ExprLhs<bool>{ value }; + } + }; + +} // end namespace Catch + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +// end catch_decomposer.h +// start catch_interfaces_capture.h + +#include <string> + +namespace Catch { + + class AssertionResult; + struct AssertionInfo; + struct SectionInfo; + struct SectionEndInfo; + struct MessageInfo; + struct MessageBuilder; + struct Counts; + struct BenchmarkInfo; + struct BenchmarkStats; + struct AssertionReaction; + struct SourceLineInfo; + + struct ITransientExpression; + struct IGeneratorTracker; + + struct IResultCapture { + + virtual ~IResultCapture(); + + virtual bool sectionStarted( SectionInfo const& sectionInfo, + Counts& assertions ) = 0; + virtual void sectionEnded( SectionEndInfo const& endInfo ) = 0; + virtual void sectionEndedEarly( SectionEndInfo const& endInfo ) = 0; + + virtual auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& = 0; + + virtual void benchmarkStarting( BenchmarkInfo const& info ) = 0; + virtual void benchmarkEnded( BenchmarkStats const& stats ) = 0; + + virtual void pushScopedMessage( MessageInfo const& message ) = 0; + virtual void popScopedMessage( MessageInfo const& message ) = 0; + + virtual void emplaceUnscopedMessage( MessageBuilder const& builder ) = 0; + + virtual void handleFatalErrorCondition( StringRef message ) = 0; + + virtual void handleExpr + ( AssertionInfo const& info, + ITransientExpression const& expr, + AssertionReaction& reaction ) = 0; + virtual void handleMessage + ( AssertionInfo const& info, + ResultWas::OfType resultType, + StringRef const& message, + AssertionReaction& reaction ) = 0; + virtual void handleUnexpectedExceptionNotThrown + ( AssertionInfo const& info, + AssertionReaction& reaction ) = 0; + virtual void handleUnexpectedInflightException + ( AssertionInfo const& info, + std::string const& message, + AssertionReaction& reaction ) = 0; + virtual void handleIncomplete + ( AssertionInfo const& info ) = 0; + virtual void handleNonExpr + ( AssertionInfo const &info, + ResultWas::OfType resultType, + AssertionReaction &reaction ) = 0; + + virtual bool lastAssertionPassed() = 0; + virtual void assertionPassed() = 0; + + // Deprecated, do not use: + virtual std::string getCurrentTestName() const = 0; + virtual const AssertionResult* getLastResult() const = 0; + virtual void exceptionEarlyReported() = 0; + }; + + IResultCapture& getResultCapture(); +} + +// end catch_interfaces_capture.h +namespace Catch { + + struct TestFailureException{}; + struct AssertionResultData; + struct IResultCapture; + class RunContext; + + class LazyExpression { + friend class AssertionHandler; + friend struct AssertionStats; + friend class RunContext; + + ITransientExpression const* m_transientExpression = nullptr; + bool m_isNegated; + public: + LazyExpression( bool isNegated ); + LazyExpression( LazyExpression const& other ); + LazyExpression& operator = ( LazyExpression const& ) = delete; + + explicit operator bool() const; + + friend auto operator << ( std::ostream& os, LazyExpression const& lazyExpr ) -> std::ostream&; + }; + + struct AssertionReaction { + bool shouldDebugBreak = false; + bool shouldThrow = false; + }; + + class AssertionHandler { + AssertionInfo m_assertionInfo; + AssertionReaction m_reaction; + bool m_completed = false; + IResultCapture& m_resultCapture; + + public: + AssertionHandler + ( StringRef const& macroName, + SourceLineInfo const& lineInfo, + StringRef capturedExpression, + ResultDisposition::Flags resultDisposition ); + ~AssertionHandler() { + if ( !m_completed ) { + m_resultCapture.handleIncomplete( m_assertionInfo ); + } + } + + template<typename T> + void handleExpr( ExprLhs<T> const& expr ) { + handleExpr( expr.makeUnaryExpr() ); + } + void handleExpr( ITransientExpression const& expr ); + + void handleMessage(ResultWas::OfType resultType, StringRef const& message); + + void handleExceptionThrownAsExpected(); + void handleUnexpectedExceptionNotThrown(); + void handleExceptionNotThrownAsExpected(); + void handleThrowingCallSkipped(); + void handleUnexpectedInflightException(); + + void complete(); + void setCompleted(); + + // query + auto allowThrows() const -> bool; + }; + + void handleExceptionMatchExpr( AssertionHandler& handler, std::string const& str, StringRef const& matcherString ); + +} // namespace Catch + +// end catch_assertionhandler.h +// start catch_message.h + +#include <string> +#include <vector> + +namespace Catch { + + struct MessageInfo { + MessageInfo( StringRef const& _macroName, + SourceLineInfo const& _lineInfo, + ResultWas::OfType _type ); + + StringRef macroName; + std::string message; + SourceLineInfo lineInfo; + ResultWas::OfType type; + unsigned int sequence; + + bool operator == ( MessageInfo const& other ) const; + bool operator < ( MessageInfo const& other ) const; + private: + static unsigned int globalCount; + }; + + struct MessageStream { + + template<typename T> + MessageStream& operator << ( T const& value ) { + m_stream << value; + return *this; + } + + ReusableStringStream m_stream; + }; + + struct MessageBuilder : MessageStream { + MessageBuilder( StringRef const& macroName, + SourceLineInfo const& lineInfo, + ResultWas::OfType type ); + + template<typename T> + MessageBuilder& operator << ( T const& value ) { + m_stream << value; + return *this; + } + + MessageInfo m_info; + }; + + class ScopedMessage { + public: + explicit ScopedMessage( MessageBuilder const& builder ); + ScopedMessage( ScopedMessage& duplicate ) = delete; + ScopedMessage( ScopedMessage&& old ); + ~ScopedMessage(); + + MessageInfo m_info; + bool m_moved; + }; + + class Capturer { + std::vector<MessageInfo> m_messages; + IResultCapture& m_resultCapture = getResultCapture(); + size_t m_captured = 0; + public: + Capturer( StringRef macroName, SourceLineInfo const& lineInfo, ResultWas::OfType resultType, StringRef names ); + ~Capturer(); + + void captureValue( size_t index, std::string const& value ); + + template<typename T> + void captureValues( size_t index, T const& value ) { + captureValue( index, Catch::Detail::stringify( value ) ); + } + + template<typename T, typename... Ts> + void captureValues( size_t index, T const& value, Ts const&... values ) { + captureValue( index, Catch::Detail::stringify(value) ); + captureValues( index+1, values... ); + } + }; + +} // end namespace Catch + +// end catch_message.h +#if !defined(CATCH_CONFIG_DISABLE) + +#if !defined(CATCH_CONFIG_DISABLE_STRINGIFICATION) + #define CATCH_INTERNAL_STRINGIFY(...) #__VA_ARGS__ +#else + #define CATCH_INTERNAL_STRINGIFY(...) "Disabled by CATCH_CONFIG_DISABLE_STRINGIFICATION" +#endif + +#if defined(CATCH_CONFIG_FAST_COMPILE) || defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + +/////////////////////////////////////////////////////////////////////////////// +// Another way to speed-up compilation is to omit local try-catch for REQUIRE* +// macros. +#define INTERNAL_CATCH_TRY +#define INTERNAL_CATCH_CATCH( capturer ) + +#else // CATCH_CONFIG_FAST_COMPILE + +#define INTERNAL_CATCH_TRY try +#define INTERNAL_CATCH_CATCH( handler ) catch(...) { handler.handleUnexpectedInflightException(); } + +#endif + +#define INTERNAL_CATCH_REACT( handler ) handler.complete(); + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_TEST( macroName, resultDisposition, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__), resultDisposition ); \ + INTERNAL_CATCH_TRY { \ + CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \ + catchAssertionHandler.handleExpr( Catch::Decomposer() <= __VA_ARGS__ ); \ + CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \ + } INTERNAL_CATCH_CATCH( catchAssertionHandler ) \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( (void)0, (false) && static_cast<bool>( !!(__VA_ARGS__) ) ) // the expression here is never evaluated at runtime but it forces the compiler to give it a look + // The double negation silences MSVC's C4800 warning, the static_cast forces short-circuit evaluation if the type has overloaded &&. + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_IF( macroName, resultDisposition, ... ) \ + INTERNAL_CATCH_TEST( macroName, resultDisposition, __VA_ARGS__ ); \ + if( Catch::getResultCapture().lastAssertionPassed() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_ELSE( macroName, resultDisposition, ... ) \ + INTERNAL_CATCH_TEST( macroName, resultDisposition, __VA_ARGS__ ); \ + if( !Catch::getResultCapture().lastAssertionPassed() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_NO_THROW( macroName, resultDisposition, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__), resultDisposition ); \ + try { \ + static_cast<void>(__VA_ARGS__); \ + catchAssertionHandler.handleExceptionNotThrownAsExpected(); \ + } \ + catch( ... ) { \ + catchAssertionHandler.handleUnexpectedInflightException(); \ + } \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS( macroName, resultDisposition, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__), resultDisposition); \ + if( catchAssertionHandler.allowThrows() ) \ + try { \ + static_cast<void>(__VA_ARGS__); \ + catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \ + } \ + catch( ... ) { \ + catchAssertionHandler.handleExceptionThrownAsExpected(); \ + } \ + else \ + catchAssertionHandler.handleThrowingCallSkipped(); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS_AS( macroName, exceptionType, resultDisposition, expr ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(expr) ", " CATCH_INTERNAL_STRINGIFY(exceptionType), resultDisposition ); \ + if( catchAssertionHandler.allowThrows() ) \ + try { \ + static_cast<void>(expr); \ + catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \ + } \ + catch( exceptionType const& ) { \ + catchAssertionHandler.handleExceptionThrownAsExpected(); \ + } \ + catch( ... ) { \ + catchAssertionHandler.handleUnexpectedInflightException(); \ + } \ + else \ + catchAssertionHandler.handleThrowingCallSkipped(); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_MSG( macroName, messageType, resultDisposition, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, Catch::StringRef(), resultDisposition ); \ + catchAssertionHandler.handleMessage( messageType, ( Catch::MessageStream() << __VA_ARGS__ + ::Catch::StreamEndStop() ).m_stream.str() ); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_CAPTURE( varName, macroName, ... ) \ + auto varName = Catch::Capturer( macroName, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info, #__VA_ARGS__ ); \ + varName.captureValues( 0, __VA_ARGS__ ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_INFO( macroName, log ) \ + Catch::ScopedMessage INTERNAL_CATCH_UNIQUE_NAME( scopedMessage )( Catch::MessageBuilder( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info ) << log ); + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_UNSCOPED_INFO( macroName, log ) \ + Catch::getResultCapture().emplaceUnscopedMessage( Catch::MessageBuilder( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info ) << log ) + +/////////////////////////////////////////////////////////////////////////////// +// Although this is matcher-based, it can be used with just a string +#define INTERNAL_CATCH_THROWS_STR_MATCHES( macroName, resultDisposition, matcher, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__) ", " CATCH_INTERNAL_STRINGIFY(matcher), resultDisposition ); \ + if( catchAssertionHandler.allowThrows() ) \ + try { \ + static_cast<void>(__VA_ARGS__); \ + catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \ + } \ + catch( ... ) { \ + Catch::handleExceptionMatchExpr( catchAssertionHandler, matcher, #matcher##_catch_sr ); \ + } \ + else \ + catchAssertionHandler.handleThrowingCallSkipped(); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +#endif // CATCH_CONFIG_DISABLE + +// end catch_capture.hpp +// start catch_section.h + +// start catch_section_info.h + +// start catch_totals.h + +#include <cstddef> + +namespace Catch { + + struct Counts { + Counts operator - ( Counts const& other ) const; + Counts& operator += ( Counts const& other ); + + std::size_t total() const; + bool allPassed() const; + bool allOk() const; + + std::size_t passed = 0; + std::size_t failed = 0; + std::size_t failedButOk = 0; + }; + + struct Totals { + + Totals operator - ( Totals const& other ) const; + Totals& operator += ( Totals const& other ); + + Totals delta( Totals const& prevTotals ) const; + + int error = 0; + Counts assertions; + Counts testCases; + }; +} + +// end catch_totals.h +#include <string> + +namespace Catch { + + struct SectionInfo { + SectionInfo + ( SourceLineInfo const& _lineInfo, + std::string const& _name ); + + // Deprecated + SectionInfo + ( SourceLineInfo const& _lineInfo, + std::string const& _name, + std::string const& ) : SectionInfo( _lineInfo, _name ) {} + + std::string name; + std::string description; // !Deprecated: this will always be empty + SourceLineInfo lineInfo; + }; + + struct SectionEndInfo { + SectionInfo sectionInfo; + Counts prevAssertions; + double durationInSeconds; + }; + +} // end namespace Catch + +// end catch_section_info.h +// start catch_timer.h + +#include <cstdint> + +namespace Catch { + + auto getCurrentNanosecondsSinceEpoch() -> uint64_t; + auto getEstimatedClockResolution() -> uint64_t; + + class Timer { + uint64_t m_nanoseconds = 0; + public: + void start(); + auto getElapsedNanoseconds() const -> uint64_t; + auto getElapsedMicroseconds() const -> uint64_t; + auto getElapsedMilliseconds() const -> unsigned int; + auto getElapsedSeconds() const -> double; + }; + +} // namespace Catch + +// end catch_timer.h +#include <string> + +namespace Catch { + + class Section : NonCopyable { + public: + Section( SectionInfo const& info ); + ~Section(); + + // This indicates whether the section should be executed or not + explicit operator bool() const; + + private: + SectionInfo m_info; + + std::string m_name; + Counts m_assertions; + bool m_sectionIncluded; + Timer m_timer; + }; + +} // end namespace Catch + +#define INTERNAL_CATCH_SECTION( ... ) \ + CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS \ + if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, __VA_ARGS__ ) ) \ + CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS + +#define INTERNAL_CATCH_DYNAMIC_SECTION( ... ) \ + CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS \ + if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, (Catch::ReusableStringStream() << __VA_ARGS__).str() ) ) \ + CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS + +// end catch_section.h +// start catch_benchmark.h + +#include <cstdint> +#include <string> + +namespace Catch { + + class BenchmarkLooper { + + std::string m_name; + std::size_t m_count = 0; + std::size_t m_iterationsToRun = 1; + uint64_t m_resolution; + Timer m_timer; + + static auto getResolution() -> uint64_t; + public: + // Keep most of this inline as it's on the code path that is being timed + BenchmarkLooper( StringRef name ) + : m_name( name ), + m_resolution( getResolution() ) + { + reportStart(); + m_timer.start(); + } + + explicit operator bool() { + if( m_count < m_iterationsToRun ) + return true; + return needsMoreIterations(); + } + + void increment() { + ++m_count; + } + + void reportStart(); + auto needsMoreIterations() -> bool; + }; + +} // end namespace Catch + +#define BENCHMARK( name ) \ + for( Catch::BenchmarkLooper looper( name ); looper; looper.increment() ) + +// end catch_benchmark.h +// start catch_interfaces_exception.h + +// start catch_interfaces_registry_hub.h + +#include <string> +#include <memory> + +namespace Catch { + + class TestCase; + struct ITestCaseRegistry; + struct IExceptionTranslatorRegistry; + struct IExceptionTranslator; + struct IReporterRegistry; + struct IReporterFactory; + struct ITagAliasRegistry; + class StartupExceptionRegistry; + + using IReporterFactoryPtr = std::shared_ptr<IReporterFactory>; + + struct IRegistryHub { + virtual ~IRegistryHub(); + + virtual IReporterRegistry const& getReporterRegistry() const = 0; + virtual ITestCaseRegistry const& getTestCaseRegistry() const = 0; + virtual ITagAliasRegistry const& getTagAliasRegistry() const = 0; + + virtual IExceptionTranslatorRegistry const& getExceptionTranslatorRegistry() const = 0; + + virtual StartupExceptionRegistry const& getStartupExceptionRegistry() const = 0; + }; + + struct IMutableRegistryHub { + virtual ~IMutableRegistryHub(); + virtual void registerReporter( std::string const& name, IReporterFactoryPtr const& factory ) = 0; + virtual void registerListener( IReporterFactoryPtr const& factory ) = 0; + virtual void registerTest( TestCase const& testInfo ) = 0; + virtual void registerTranslator( const IExceptionTranslator* translator ) = 0; + virtual void registerTagAlias( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) = 0; + virtual void registerStartupException() noexcept = 0; + }; + + IRegistryHub const& getRegistryHub(); + IMutableRegistryHub& getMutableRegistryHub(); + void cleanUp(); + std::string translateActiveException(); + +} + +// end catch_interfaces_registry_hub.h +#if defined(CATCH_CONFIG_DISABLE) + #define INTERNAL_CATCH_TRANSLATE_EXCEPTION_NO_REG( translatorName, signature) \ + static std::string translatorName( signature ) +#endif + +#include <exception> +#include <string> +#include <vector> + +namespace Catch { + using exceptionTranslateFunction = std::string(*)(); + + struct IExceptionTranslator; + using ExceptionTranslators = std::vector<std::unique_ptr<IExceptionTranslator const>>; + + struct IExceptionTranslator { + virtual ~IExceptionTranslator(); + virtual std::string translate( ExceptionTranslators::const_iterator it, ExceptionTranslators::const_iterator itEnd ) const = 0; + }; + + struct IExceptionTranslatorRegistry { + virtual ~IExceptionTranslatorRegistry(); + + virtual std::string translateActiveException() const = 0; + }; + + class ExceptionTranslatorRegistrar { + template<typename T> + class ExceptionTranslator : public IExceptionTranslator { + public: + + ExceptionTranslator( std::string(*translateFunction)( T& ) ) + : m_translateFunction( translateFunction ) + {} + + std::string translate( ExceptionTranslators::const_iterator it, ExceptionTranslators::const_iterator itEnd ) const override { + try { + if( it == itEnd ) + std::rethrow_exception(std::current_exception()); + else + return (*it)->translate( it+1, itEnd ); + } + catch( T& ex ) { + return m_translateFunction( ex ); + } + } + + protected: + std::string(*m_translateFunction)( T& ); + }; + + public: + template<typename T> + ExceptionTranslatorRegistrar( std::string(*translateFunction)( T& ) ) { + getMutableRegistryHub().registerTranslator + ( new ExceptionTranslator<T>( translateFunction ) ); + } + }; +} + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_TRANSLATE_EXCEPTION2( translatorName, signature ) \ + static std::string translatorName( signature ); \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::ExceptionTranslatorRegistrar INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionRegistrar )( &translatorName ); } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + static std::string translatorName( signature ) + +#define INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION2( INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator ), signature ) + +// end catch_interfaces_exception.h +// start catch_approx.h + +#include <type_traits> + +namespace Catch { +namespace Detail { + + class Approx { + private: + bool equalityComparisonImpl(double other) const; + // Validates the new margin (margin >= 0) + // out-of-line to avoid including stdexcept in the header + void setMargin(double margin); + // Validates the new epsilon (0 < epsilon < 1) + // out-of-line to avoid including stdexcept in the header + void setEpsilon(double epsilon); + + public: + explicit Approx ( double value ); + + static Approx custom(); + + Approx operator-() const; + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + Approx operator()( T const& value ) { + Approx approx( static_cast<double>(value) ); + approx.m_epsilon = m_epsilon; + approx.m_margin = m_margin; + approx.m_scale = m_scale; + return approx; + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + explicit Approx( T const& value ): Approx(static_cast<double>(value)) + {} + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + friend bool operator == ( const T& lhs, Approx const& rhs ) { + auto lhs_v = static_cast<double>(lhs); + return rhs.equalityComparisonImpl(lhs_v); + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + friend bool operator == ( Approx const& lhs, const T& rhs ) { + return operator==( rhs, lhs ); + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + friend bool operator != ( T const& lhs, Approx const& rhs ) { + return !operator==( lhs, rhs ); + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + friend bool operator != ( Approx const& lhs, T const& rhs ) { + return !operator==( rhs, lhs ); + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + friend bool operator <= ( T const& lhs, Approx const& rhs ) { + return static_cast<double>(lhs) < rhs.m_value || lhs == rhs; + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + friend bool operator <= ( Approx const& lhs, T const& rhs ) { + return lhs.m_value < static_cast<double>(rhs) || lhs == rhs; + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + friend bool operator >= ( T const& lhs, Approx const& rhs ) { + return static_cast<double>(lhs) > rhs.m_value || lhs == rhs; + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + friend bool operator >= ( Approx const& lhs, T const& rhs ) { + return lhs.m_value > static_cast<double>(rhs) || lhs == rhs; + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + Approx& epsilon( T const& newEpsilon ) { + double epsilonAsDouble = static_cast<double>(newEpsilon); + setEpsilon(epsilonAsDouble); + return *this; + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + Approx& margin( T const& newMargin ) { + double marginAsDouble = static_cast<double>(newMargin); + setMargin(marginAsDouble); + return *this; + } + + template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type> + Approx& scale( T const& newScale ) { + m_scale = static_cast<double>(newScale); + return *this; + } + + std::string toString() const; + + private: + double m_epsilon; + double m_margin; + double m_scale; + double m_value; + }; +} // end namespace Detail + +namespace literals { + Detail::Approx operator "" _a(long double val); + Detail::Approx operator "" _a(unsigned long long val); +} // end namespace literals + +template<> +struct StringMaker<Catch::Detail::Approx> { + static std::string convert(Catch::Detail::Approx const& value); +}; + +} // end namespace Catch + +// end catch_approx.h +// start catch_string_manip.h + +#include <string> +#include <iosfwd> + +namespace Catch { + + bool startsWith( std::string const& s, std::string const& prefix ); + bool startsWith( std::string const& s, char prefix ); + bool endsWith( std::string const& s, std::string const& suffix ); + bool endsWith( std::string const& s, char suffix ); + bool contains( std::string const& s, std::string const& infix ); + void toLowerInPlace( std::string& s ); + std::string toLower( std::string const& s ); + std::string trim( std::string const& str ); + bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis ); + + struct pluralise { + pluralise( std::size_t count, std::string const& label ); + + friend std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser ); + + std::size_t m_count; + std::string m_label; + }; +} + +// end catch_string_manip.h +#ifndef CATCH_CONFIG_DISABLE_MATCHERS +// start catch_capture_matchers.h + +// start catch_matchers.h + +#include <string> +#include <vector> + +namespace Catch { +namespace Matchers { + namespace Impl { + + template<typename ArgT> struct MatchAllOf; + template<typename ArgT> struct MatchAnyOf; + template<typename ArgT> struct MatchNotOf; + + class MatcherUntypedBase { + public: + MatcherUntypedBase() = default; + MatcherUntypedBase ( MatcherUntypedBase const& ) = default; + MatcherUntypedBase& operator = ( MatcherUntypedBase const& ) = delete; + std::string toString() const; + + protected: + virtual ~MatcherUntypedBase(); + virtual std::string describe() const = 0; + mutable std::string m_cachedToString; + }; + +#ifdef __clang__ +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wnon-virtual-dtor" +#endif + + template<typename ObjectT> + struct MatcherMethod { + virtual bool match( ObjectT const& arg ) const = 0; + }; + +#ifdef __clang__ +# pragma clang diagnostic pop +#endif + + template<typename T> + struct MatcherBase : MatcherUntypedBase, MatcherMethod<T> { + + MatchAllOf<T> operator && ( MatcherBase const& other ) const; + MatchAnyOf<T> operator || ( MatcherBase const& other ) const; + MatchNotOf<T> operator ! () const; + }; + + template<typename ArgT> + struct MatchAllOf : MatcherBase<ArgT> { + bool match( ArgT const& arg ) const override { + for( auto matcher : m_matchers ) { + if (!matcher->match(arg)) + return false; + } + return true; + } + std::string describe() const override { + std::string description; + description.reserve( 4 + m_matchers.size()*32 ); + description += "( "; + bool first = true; + for( auto matcher : m_matchers ) { + if( first ) + first = false; + else + description += " and "; + description += matcher->toString(); + } + description += " )"; + return description; + } + + MatchAllOf<ArgT>& operator && ( MatcherBase<ArgT> const& other ) { + m_matchers.push_back( &other ); + return *this; + } + + std::vector<MatcherBase<ArgT> const*> m_matchers; + }; + template<typename ArgT> + struct MatchAnyOf : MatcherBase<ArgT> { + + bool match( ArgT const& arg ) const override { + for( auto matcher : m_matchers ) { + if (matcher->match(arg)) + return true; + } + return false; + } + std::string describe() const override { + std::string description; + description.reserve( 4 + m_matchers.size()*32 ); + description += "( "; + bool first = true; + for( auto matcher : m_matchers ) { + if( first ) + first = false; + else + description += " or "; + description += matcher->toString(); + } + description += " )"; + return description; + } + + MatchAnyOf<ArgT>& operator || ( MatcherBase<ArgT> const& other ) { + m_matchers.push_back( &other ); + return *this; + } + + std::vector<MatcherBase<ArgT> const*> m_matchers; + }; + + template<typename ArgT> + struct MatchNotOf : MatcherBase<ArgT> { + + MatchNotOf( MatcherBase<ArgT> const& underlyingMatcher ) : m_underlyingMatcher( underlyingMatcher ) {} + + bool match( ArgT const& arg ) const override { + return !m_underlyingMatcher.match( arg ); + } + + std::string describe() const override { + return "not " + m_underlyingMatcher.toString(); + } + MatcherBase<ArgT> const& m_underlyingMatcher; + }; + + template<typename T> + MatchAllOf<T> MatcherBase<T>::operator && ( MatcherBase const& other ) const { + return MatchAllOf<T>() && *this && other; + } + template<typename T> + MatchAnyOf<T> MatcherBase<T>::operator || ( MatcherBase const& other ) const { + return MatchAnyOf<T>() || *this || other; + } + template<typename T> + MatchNotOf<T> MatcherBase<T>::operator ! () const { + return MatchNotOf<T>( *this ); + } + + } // namespace Impl + +} // namespace Matchers + +using namespace Matchers; +using Matchers::Impl::MatcherBase; + +} // namespace Catch + +// end catch_matchers.h +// start catch_matchers_floating.h + +#include <type_traits> +#include <cmath> + +namespace Catch { +namespace Matchers { + + namespace Floating { + + enum class FloatingPointKind : uint8_t; + + struct WithinAbsMatcher : MatcherBase<double> { + WithinAbsMatcher(double target, double margin); + bool match(double const& matchee) const override; + std::string describe() const override; + private: + double m_target; + double m_margin; + }; + + struct WithinUlpsMatcher : MatcherBase<double> { + WithinUlpsMatcher(double target, int ulps, FloatingPointKind baseType); + bool match(double const& matchee) const override; + std::string describe() const override; + private: + double m_target; + int m_ulps; + FloatingPointKind m_type; + }; + + } // namespace Floating + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + Floating::WithinUlpsMatcher WithinULP(double target, int maxUlpDiff); + Floating::WithinUlpsMatcher WithinULP(float target, int maxUlpDiff); + Floating::WithinAbsMatcher WithinAbs(double target, double margin); + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_floating.h +// start catch_matchers_generic.hpp + +#include <functional> +#include <string> + +namespace Catch { +namespace Matchers { +namespace Generic { + +namespace Detail { + std::string finalizeDescription(const std::string& desc); +} + +template <typename T> +class PredicateMatcher : public MatcherBase<T> { + std::function<bool(T const&)> m_predicate; + std::string m_description; +public: + + PredicateMatcher(std::function<bool(T const&)> const& elem, std::string const& descr) + :m_predicate(std::move(elem)), + m_description(Detail::finalizeDescription(descr)) + {} + + bool match( T const& item ) const override { + return m_predicate(item); + } + + std::string describe() const override { + return m_description; + } +}; + +} // namespace Generic + + // The following functions create the actual matcher objects. + // The user has to explicitly specify type to the function, because + // infering std::function<bool(T const&)> is hard (but possible) and + // requires a lot of TMP. + template<typename T> + Generic::PredicateMatcher<T> Predicate(std::function<bool(T const&)> const& predicate, std::string const& description = "") { + return Generic::PredicateMatcher<T>(predicate, description); + } + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_generic.hpp +// start catch_matchers_string.h + +#include <string> + +namespace Catch { +namespace Matchers { + + namespace StdString { + + struct CasedString + { + CasedString( std::string const& str, CaseSensitive::Choice caseSensitivity ); + std::string adjustString( std::string const& str ) const; + std::string caseSensitivitySuffix() const; + + CaseSensitive::Choice m_caseSensitivity; + std::string m_str; + }; + + struct StringMatcherBase : MatcherBase<std::string> { + StringMatcherBase( std::string const& operation, CasedString const& comparator ); + std::string describe() const override; + + CasedString m_comparator; + std::string m_operation; + }; + + struct EqualsMatcher : StringMatcherBase { + EqualsMatcher( CasedString const& comparator ); + bool match( std::string const& source ) const override; + }; + struct ContainsMatcher : StringMatcherBase { + ContainsMatcher( CasedString const& comparator ); + bool match( std::string const& source ) const override; + }; + struct StartsWithMatcher : StringMatcherBase { + StartsWithMatcher( CasedString const& comparator ); + bool match( std::string const& source ) const override; + }; + struct EndsWithMatcher : StringMatcherBase { + EndsWithMatcher( CasedString const& comparator ); + bool match( std::string const& source ) const override; + }; + + struct RegexMatcher : MatcherBase<std::string> { + RegexMatcher( std::string regex, CaseSensitive::Choice caseSensitivity ); + bool match( std::string const& matchee ) const override; + std::string describe() const override; + + private: + std::string m_regex; + CaseSensitive::Choice m_caseSensitivity; + }; + + } // namespace StdString + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + + StdString::EqualsMatcher Equals( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::ContainsMatcher Contains( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::EndsWithMatcher EndsWith( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::StartsWithMatcher StartsWith( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::RegexMatcher Matches( std::string const& regex, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_string.h +// start catch_matchers_vector.h + +#include <algorithm> + +namespace Catch { +namespace Matchers { + + namespace Vector { + namespace Detail { + template <typename InputIterator, typename T> + size_t count(InputIterator first, InputIterator last, T const& item) { + size_t cnt = 0; + for (; first != last; ++first) { + if (*first == item) { + ++cnt; + } + } + return cnt; + } + template <typename InputIterator, typename T> + bool contains(InputIterator first, InputIterator last, T const& item) { + for (; first != last; ++first) { + if (*first == item) { + return true; + } + } + return false; + } + } + + template<typename T> + struct ContainsElementMatcher : MatcherBase<std::vector<T>> { + + ContainsElementMatcher(T const &comparator) : m_comparator( comparator) {} + + bool match(std::vector<T> const &v) const override { + for (auto const& el : v) { + if (el == m_comparator) { + return true; + } + } + return false; + } + + std::string describe() const override { + return "Contains: " + ::Catch::Detail::stringify( m_comparator ); + } + + T const& m_comparator; + }; + + template<typename T> + struct ContainsMatcher : MatcherBase<std::vector<T>> { + + ContainsMatcher(std::vector<T> const &comparator) : m_comparator( comparator ) {} + + bool match(std::vector<T> const &v) const override { + // !TBD: see note in EqualsMatcher + if (m_comparator.size() > v.size()) + return false; + for (auto const& comparator : m_comparator) { + auto present = false; + for (const auto& el : v) { + if (el == comparator) { + present = true; + break; + } + } + if (!present) { + return false; + } + } + return true; + } + std::string describe() const override { + return "Contains: " + ::Catch::Detail::stringify( m_comparator ); + } + + std::vector<T> const& m_comparator; + }; + + template<typename T> + struct EqualsMatcher : MatcherBase<std::vector<T>> { + + EqualsMatcher(std::vector<T> const &comparator) : m_comparator( comparator ) {} + + bool match(std::vector<T> const &v) const override { + // !TBD: This currently works if all elements can be compared using != + // - a more general approach would be via a compare template that defaults + // to using !=. but could be specialised for, e.g. std::vector<T> etc + // - then just call that directly + if (m_comparator.size() != v.size()) + return false; + for (std::size_t i = 0; i < v.size(); ++i) + if (m_comparator[i] != v[i]) + return false; + return true; + } + std::string describe() const override { + return "Equals: " + ::Catch::Detail::stringify( m_comparator ); + } + std::vector<T> const& m_comparator; + }; + + template<typename T> + struct UnorderedEqualsMatcher : MatcherBase<std::vector<T>> { + UnorderedEqualsMatcher(std::vector<T> const& target) : m_target(target) {} + bool match(std::vector<T> const& vec) const override { + // Note: This is a reimplementation of std::is_permutation, + // because I don't want to include <algorithm> inside the common path + if (m_target.size() != vec.size()) { + return false; + } + auto lfirst = m_target.begin(), llast = m_target.end(); + auto rfirst = vec.begin(), rlast = vec.end(); + // Cut common prefix to optimize checking of permuted parts + while (lfirst != llast && *lfirst == *rfirst) { + ++lfirst; ++rfirst; + } + if (lfirst == llast) { + return true; + } + + for (auto mid = lfirst; mid != llast; ++mid) { + // Skip already counted items + if (Detail::contains(lfirst, mid, *mid)) { + continue; + } + size_t num_vec = Detail::count(rfirst, rlast, *mid); + if (num_vec == 0 || Detail::count(lfirst, llast, *mid) != num_vec) { + return false; + } + } + + return true; + } + + std::string describe() const override { + return "UnorderedEquals: " + ::Catch::Detail::stringify(m_target); + } + private: + std::vector<T> const& m_target; + }; + + } // namespace Vector + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + + template<typename T> + Vector::ContainsMatcher<T> Contains( std::vector<T> const& comparator ) { + return Vector::ContainsMatcher<T>( comparator ); + } + + template<typename T> + Vector::ContainsElementMatcher<T> VectorContains( T const& comparator ) { + return Vector::ContainsElementMatcher<T>( comparator ); + } + + template<typename T> + Vector::EqualsMatcher<T> Equals( std::vector<T> const& comparator ) { + return Vector::EqualsMatcher<T>( comparator ); + } + + template<typename T> + Vector::UnorderedEqualsMatcher<T> UnorderedEquals(std::vector<T> const& target) { + return Vector::UnorderedEqualsMatcher<T>(target); + } + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_vector.h +namespace Catch { + + template<typename ArgT, typename MatcherT> + class MatchExpr : public ITransientExpression { + ArgT const& m_arg; + MatcherT m_matcher; + StringRef m_matcherString; + public: + MatchExpr( ArgT const& arg, MatcherT const& matcher, StringRef const& matcherString ) + : ITransientExpression{ true, matcher.match( arg ) }, + m_arg( arg ), + m_matcher( matcher ), + m_matcherString( matcherString ) + {} + + void streamReconstructedExpression( std::ostream &os ) const override { + auto matcherAsString = m_matcher.toString(); + os << Catch::Detail::stringify( m_arg ) << ' '; + if( matcherAsString == Detail::unprintableString ) + os << m_matcherString; + else + os << matcherAsString; + } + }; + + using StringMatcher = Matchers::Impl::MatcherBase<std::string>; + + void handleExceptionMatchExpr( AssertionHandler& handler, StringMatcher const& matcher, StringRef const& matcherString ); + + template<typename ArgT, typename MatcherT> + auto makeMatchExpr( ArgT const& arg, MatcherT const& matcher, StringRef const& matcherString ) -> MatchExpr<ArgT, MatcherT> { + return MatchExpr<ArgT, MatcherT>( arg, matcher, matcherString ); + } + +} // namespace Catch + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CHECK_THAT( macroName, matcher, resultDisposition, arg ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(arg) ", " CATCH_INTERNAL_STRINGIFY(matcher), resultDisposition ); \ + INTERNAL_CATCH_TRY { \ + catchAssertionHandler.handleExpr( Catch::makeMatchExpr( arg, matcher, #matcher##_catch_sr ) ); \ + } INTERNAL_CATCH_CATCH( catchAssertionHandler ) \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS_MATCHES( macroName, exceptionType, resultDisposition, matcher, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__) ", " CATCH_INTERNAL_STRINGIFY(exceptionType) ", " CATCH_INTERNAL_STRINGIFY(matcher), resultDisposition ); \ + if( catchAssertionHandler.allowThrows() ) \ + try { \ + static_cast<void>(__VA_ARGS__ ); \ + catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \ + } \ + catch( exceptionType const& ex ) { \ + catchAssertionHandler.handleExpr( Catch::makeMatchExpr( ex, matcher, #matcher##_catch_sr ) ); \ + } \ + catch( ... ) { \ + catchAssertionHandler.handleUnexpectedInflightException(); \ + } \ + else \ + catchAssertionHandler.handleThrowingCallSkipped(); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +// end catch_capture_matchers.h +#endif +// start catch_generators.hpp + +// start catch_interfaces_generatortracker.h + + +#include <memory> + +namespace Catch { + + namespace Generators { + class GeneratorUntypedBase { + public: + GeneratorUntypedBase() = default; + virtual ~GeneratorUntypedBase(); + // Attempts to move the generator to the next element + // + // Returns true iff the move succeeded (and a valid element + // can be retrieved). + virtual bool next() = 0; + }; + using GeneratorBasePtr = std::unique_ptr<GeneratorUntypedBase>; + + } // namespace Generators + + struct IGeneratorTracker { + virtual ~IGeneratorTracker(); + virtual auto hasGenerator() const -> bool = 0; + virtual auto getGenerator() const -> Generators::GeneratorBasePtr const& = 0; + virtual void setGenerator( Generators::GeneratorBasePtr&& generator ) = 0; + }; + +} // namespace Catch + +// end catch_interfaces_generatortracker.h +// start catch_enforce.h + +#include <stdexcept> + +namespace Catch { +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + template <typename Ex> + [[noreturn]] + void throw_exception(Ex const& e) { + throw e; + } +#else // ^^ Exceptions are enabled // Exceptions are disabled vv + [[noreturn]] + void throw_exception(std::exception const& e); +#endif +} // namespace Catch; + +#define CATCH_PREPARE_EXCEPTION( type, msg ) \ + type( ( Catch::ReusableStringStream() << msg ).str() ) +#define CATCH_INTERNAL_ERROR( msg ) \ + Catch::throw_exception(CATCH_PREPARE_EXCEPTION( std::logic_error, CATCH_INTERNAL_LINEINFO << ": Internal Catch error: " << msg)) +#define CATCH_ERROR( msg ) \ + Catch::throw_exception(CATCH_PREPARE_EXCEPTION( std::domain_error, msg )) +#define CATCH_RUNTIME_ERROR( msg ) \ + Catch::throw_exception(CATCH_PREPARE_EXCEPTION( std::runtime_error, msg )) +#define CATCH_ENFORCE( condition, msg ) \ + do{ if( !(condition) ) CATCH_ERROR( msg ); } while(false) + +// end catch_enforce.h +#include <memory> +#include <vector> +#include <cassert> + +#include <utility> +#include <exception> + +namespace Catch { + +class GeneratorException : public std::exception { + const char* const m_msg = ""; + +public: + GeneratorException(const char* msg): + m_msg(msg) + {} + + const char* what() const noexcept override final; +}; + +namespace Generators { + + // !TBD move this into its own location? + namespace pf{ + template<typename T, typename... Args> + std::unique_ptr<T> make_unique( Args&&... args ) { + return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); + } + } + + template<typename T> + struct IGenerator : GeneratorUntypedBase { + virtual ~IGenerator() = default; + + // Returns the current element of the generator + // + // \Precondition The generator is either freshly constructed, + // or the last call to `next()` returned true + virtual T const& get() const = 0; + using type = T; + }; + + template<typename T> + class SingleValueGenerator final : public IGenerator<T> { + T m_value; + public: + SingleValueGenerator(T const& value) : m_value( value ) {} + SingleValueGenerator(T&& value) : m_value(std::move(value)) {} + + T const& get() const override { + return m_value; + } + bool next() override { + return false; + } + }; + + template<typename T> + class FixedValuesGenerator final : public IGenerator<T> { + std::vector<T> m_values; + size_t m_idx = 0; + public: + FixedValuesGenerator( std::initializer_list<T> values ) : m_values( values ) {} + + T const& get() const override { + return m_values[m_idx]; + } + bool next() override { + ++m_idx; + return m_idx < m_values.size(); + } + }; + + template <typename T> + class GeneratorWrapper final { + std::unique_ptr<IGenerator<T>> m_generator; + public: + GeneratorWrapper(std::unique_ptr<IGenerator<T>> generator): + m_generator(std::move(generator)) + {} + T const& get() const { + return m_generator->get(); + } + bool next() { + return m_generator->next(); + } + }; + + template <typename T> + GeneratorWrapper<T> value(T&& value) { + return GeneratorWrapper<T>(pf::make_unique<SingleValueGenerator<T>>(std::forward<T>(value))); + } + template <typename T> + GeneratorWrapper<T> values(std::initializer_list<T> values) { + return GeneratorWrapper<T>(pf::make_unique<FixedValuesGenerator<T>>(values)); + } + + template<typename T> + class Generators : public IGenerator<T> { + std::vector<GeneratorWrapper<T>> m_generators; + size_t m_current = 0; + + void populate(GeneratorWrapper<T>&& generator) { + m_generators.emplace_back(std::move(generator)); + } + void populate(T&& val) { + m_generators.emplace_back(value(std::move(val))); + } + template<typename U> + void populate(U&& val) { + populate(T(std::move(val))); + } + template<typename U, typename... Gs> + void populate(U&& valueOrGenerator, Gs... moreGenerators) { + populate(std::forward<U>(valueOrGenerator)); + populate(std::forward<Gs>(moreGenerators)...); + } + + public: + template <typename... Gs> + Generators(Gs... moreGenerators) { + m_generators.reserve(sizeof...(Gs)); + populate(std::forward<Gs>(moreGenerators)...); + } + + T const& get() const override { + return m_generators[m_current].get(); + } + + bool next() override { + if (m_current >= m_generators.size()) { + return false; + } + const bool current_status = m_generators[m_current].next(); + if (!current_status) { + ++m_current; + } + return m_current < m_generators.size(); + } + }; + + template<typename... Ts> + GeneratorWrapper<std::tuple<Ts...>> table( std::initializer_list<std::tuple<typename std::decay<Ts>::type...>> tuples ) { + return values<std::tuple<Ts...>>( tuples ); + } + + // Tag type to signal that a generator sequence should convert arguments to a specific type + template <typename T> + struct as {}; + + template<typename T, typename... Gs> + auto makeGenerators( GeneratorWrapper<T>&& generator, Gs... moreGenerators ) -> Generators<T> { + return Generators<T>(std::move(generator), std::forward<Gs>(moreGenerators)...); + } + template<typename T> + auto makeGenerators( GeneratorWrapper<T>&& generator ) -> Generators<T> { + return Generators<T>(std::move(generator)); + } + template<typename T, typename... Gs> + auto makeGenerators( T&& val, Gs... moreGenerators ) -> Generators<T> { + return makeGenerators( value( std::forward<T>( val ) ), std::forward<Gs>( moreGenerators )... ); + } + template<typename T, typename U, typename... Gs> + auto makeGenerators( as<T>, U&& val, Gs... moreGenerators ) -> Generators<T> { + return makeGenerators( value( T( std::forward<U>( val ) ) ), std::forward<Gs>( moreGenerators )... ); + } + + auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker&; + + template<typename L> + // Note: The type after -> is weird, because VS2015 cannot parse + // the expression used in the typedef inside, when it is in + // return type. Yeah. + auto generate( SourceLineInfo const& lineInfo, L const& generatorExpression ) -> decltype(std::declval<decltype(generatorExpression())>().get()) { + using UnderlyingType = typename decltype(generatorExpression())::type; + + IGeneratorTracker& tracker = acquireGeneratorTracker( lineInfo ); + if (!tracker.hasGenerator()) { + tracker.setGenerator(pf::make_unique<Generators<UnderlyingType>>(generatorExpression())); + } + + auto const& generator = static_cast<IGenerator<UnderlyingType> const&>( *tracker.getGenerator() ); + return generator.get(); + } + +} // namespace Generators +} // namespace Catch + +#define GENERATE( ... ) \ + Catch::Generators::generate( CATCH_INTERNAL_LINEINFO, []{ using namespace Catch::Generators; return makeGenerators( __VA_ARGS__ ); } ) + +// end catch_generators.hpp +// start catch_generators_generic.hpp + +namespace Catch { +namespace Generators { + + template <typename T> + class TakeGenerator : public IGenerator<T> { + GeneratorWrapper<T> m_generator; + size_t m_returned = 0; + size_t m_target; + public: + TakeGenerator(size_t target, GeneratorWrapper<T>&& generator): + m_generator(std::move(generator)), + m_target(target) + { + assert(target != 0 && "Empty generators are not allowed"); + } + T const& get() const override { + return m_generator.get(); + } + bool next() override { + ++m_returned; + if (m_returned >= m_target) { + return false; + } + + const auto success = m_generator.next(); + // If the underlying generator does not contain enough values + // then we cut short as well + if (!success) { + m_returned = m_target; + } + return success; + } + }; + + template <typename T> + GeneratorWrapper<T> take(size_t target, GeneratorWrapper<T>&& generator) { + return GeneratorWrapper<T>(pf::make_unique<TakeGenerator<T>>(target, std::move(generator))); + } + + template <typename T, typename Predicate> + class FilterGenerator : public IGenerator<T> { + GeneratorWrapper<T> m_generator; + Predicate m_predicate; + public: + template <typename P = Predicate> + FilterGenerator(P&& pred, GeneratorWrapper<T>&& generator): + m_generator(std::move(generator)), + m_predicate(std::forward<P>(pred)) + { + if (!m_predicate(m_generator.get())) { + // It might happen that there are no values that pass the + // filter. In that case we throw an exception. + auto has_initial_value = next(); + if (!has_initial_value) { + Catch::throw_exception(GeneratorException("No valid value found in filtered generator")); + } + } + } + + T const& get() const override { + return m_generator.get(); + } + + bool next() override { + bool success = m_generator.next(); + if (!success) { + return false; + } + while (!m_predicate(m_generator.get()) && (success = m_generator.next()) == true); + return success; + } + }; + + template <typename T, typename Predicate> + GeneratorWrapper<T> filter(Predicate&& pred, GeneratorWrapper<T>&& generator) { + return GeneratorWrapper<T>(std::unique_ptr<IGenerator<T>>(pf::make_unique<FilterGenerator<T, Predicate>>(std::forward<Predicate>(pred), std::move(generator)))); + } + + template <typename T> + class RepeatGenerator : public IGenerator<T> { + GeneratorWrapper<T> m_generator; + mutable std::vector<T> m_returned; + size_t m_target_repeats; + size_t m_current_repeat = 0; + size_t m_repeat_index = 0; + public: + RepeatGenerator(size_t repeats, GeneratorWrapper<T>&& generator): + m_generator(std::move(generator)), + m_target_repeats(repeats) + { + assert(m_target_repeats > 0 && "Repeat generator must repeat at least once"); + } + + T const& get() const override { + if (m_current_repeat == 0) { + m_returned.push_back(m_generator.get()); + return m_returned.back(); + } + return m_returned[m_repeat_index]; + } + + bool next() override { + // There are 2 basic cases: + // 1) We are still reading the generator + // 2) We are reading our own cache + + // In the first case, we need to poke the underlying generator. + // If it happily moves, we are left in that state, otherwise it is time to start reading from our cache + if (m_current_repeat == 0) { + const auto success = m_generator.next(); + if (!success) { + ++m_current_repeat; + } + return m_current_repeat < m_target_repeats; + } + + // In the second case, we need to move indices forward and check that we haven't run up against the end + ++m_repeat_index; + if (m_repeat_index == m_returned.size()) { + m_repeat_index = 0; + ++m_current_repeat; + } + return m_current_repeat < m_target_repeats; + } + }; + + template <typename T> + GeneratorWrapper<T> repeat(size_t repeats, GeneratorWrapper<T>&& generator) { + return GeneratorWrapper<T>(pf::make_unique<RepeatGenerator<T>>(repeats, std::move(generator))); + } + + template <typename T, typename U, typename Func> + class MapGenerator : public IGenerator<T> { + // TBD: provide static assert for mapping function, for friendly error message + GeneratorWrapper<U> m_generator; + Func m_function; + // To avoid returning dangling reference, we have to save the values + T m_cache; + public: + template <typename F2 = Func> + MapGenerator(F2&& function, GeneratorWrapper<U>&& generator) : + m_generator(std::move(generator)), + m_function(std::forward<F2>(function)), + m_cache(m_function(m_generator.get())) + {} + + T const& get() const override { + return m_cache; + } + bool next() override { + const auto success = m_generator.next(); + if (success) { + m_cache = m_function(m_generator.get()); + } + return success; + } + }; + + template <typename T, typename U, typename Func> + GeneratorWrapper<T> map(Func&& function, GeneratorWrapper<U>&& generator) { + return GeneratorWrapper<T>( + pf::make_unique<MapGenerator<T, U, Func>>(std::forward<Func>(function), std::move(generator)) + ); + } + template <typename T, typename Func> + GeneratorWrapper<T> map(Func&& function, GeneratorWrapper<T>&& generator) { + return GeneratorWrapper<T>( + pf::make_unique<MapGenerator<T, T, Func>>(std::forward<Func>(function), std::move(generator)) + ); + } + + template <typename T> + class ChunkGenerator final : public IGenerator<std::vector<T>> { + std::vector<T> m_chunk; + size_t m_chunk_size; + GeneratorWrapper<T> m_generator; + bool m_used_up = false; + public: + ChunkGenerator(size_t size, GeneratorWrapper<T> generator) : + m_chunk_size(size), m_generator(std::move(generator)) + { + m_chunk.reserve(m_chunk_size); + m_chunk.push_back(m_generator.get()); + for (size_t i = 1; i < m_chunk_size; ++i) { + if (!m_generator.next()) { + Catch::throw_exception(GeneratorException("Not enough values to initialize the first chunk")); + } + m_chunk.push_back(m_generator.get()); + } + } + std::vector<T> const& get() const override { + return m_chunk; + } + bool next() override { + m_chunk.clear(); + for (size_t idx = 0; idx < m_chunk_size; ++idx) { + if (!m_generator.next()) { + return false; + } + m_chunk.push_back(m_generator.get()); + } + return true; + } + }; + + template <typename T> + GeneratorWrapper<std::vector<T>> chunk(size_t size, GeneratorWrapper<T>&& generator) { + return GeneratorWrapper<std::vector<T>>( + pf::make_unique<ChunkGenerator<T>>(size, std::move(generator)) + ); + } + +} // namespace Generators +} // namespace Catch + +// end catch_generators_generic.hpp +// start catch_generators_specific.hpp + +// start catch_context.h + +#include <memory> + +namespace Catch { + + struct IResultCapture; + struct IRunner; + struct IConfig; + struct IMutableContext; + + using IConfigPtr = std::shared_ptr<IConfig const>; + + struct IContext + { + virtual ~IContext(); + + virtual IResultCapture* getResultCapture() = 0; + virtual IRunner* getRunner() = 0; + virtual IConfigPtr const& getConfig() const = 0; + }; + + struct IMutableContext : IContext + { + virtual ~IMutableContext(); + virtual void setResultCapture( IResultCapture* resultCapture ) = 0; + virtual void setRunner( IRunner* runner ) = 0; + virtual void setConfig( IConfigPtr const& config ) = 0; + + private: + static IMutableContext *currentContext; + friend IMutableContext& getCurrentMutableContext(); + friend void cleanUpContext(); + static void createContext(); + }; + + inline IMutableContext& getCurrentMutableContext() + { + if( !IMutableContext::currentContext ) + IMutableContext::createContext(); + return *IMutableContext::currentContext; + } + + inline IContext& getCurrentContext() + { + return getCurrentMutableContext(); + } + + void cleanUpContext(); +} + +// end catch_context.h +// start catch_interfaces_config.h + +#include <iosfwd> +#include <string> +#include <vector> +#include <memory> + +namespace Catch { + + enum class Verbosity { + Quiet = 0, + Normal, + High + }; + + struct WarnAbout { enum What { + Nothing = 0x00, + NoAssertions = 0x01, + NoTests = 0x02 + }; }; + + struct ShowDurations { enum OrNot { + DefaultForReporter, + Always, + Never + }; }; + struct RunTests { enum InWhatOrder { + InDeclarationOrder, + InLexicographicalOrder, + InRandomOrder + }; }; + struct UseColour { enum YesOrNo { + Auto, + Yes, + No + }; }; + struct WaitForKeypress { enum When { + Never, + BeforeStart = 1, + BeforeExit = 2, + BeforeStartAndExit = BeforeStart | BeforeExit + }; }; + + class TestSpec; + + struct IConfig : NonCopyable { + + virtual ~IConfig(); + + virtual bool allowThrows() const = 0; + virtual std::ostream& stream() const = 0; + virtual std::string name() const = 0; + virtual bool includeSuccessfulResults() const = 0; + virtual bool shouldDebugBreak() const = 0; + virtual bool warnAboutMissingAssertions() const = 0; + virtual bool warnAboutNoTests() const = 0; + virtual int abortAfter() const = 0; + virtual bool showInvisibles() const = 0; + virtual ShowDurations::OrNot showDurations() const = 0; + virtual TestSpec const& testSpec() const = 0; + virtual bool hasTestFilters() const = 0; + virtual RunTests::InWhatOrder runOrder() const = 0; + virtual unsigned int rngSeed() const = 0; + virtual int benchmarkResolutionMultiple() const = 0; + virtual UseColour::YesOrNo useColour() const = 0; + virtual std::vector<std::string> const& getSectionsToRun() const = 0; + virtual Verbosity verbosity() const = 0; + }; + + using IConfigPtr = std::shared_ptr<IConfig const>; +} + +// end catch_interfaces_config.h +#include <random> + +namespace Catch { +namespace Generators { + +template <typename Float> +class RandomFloatingGenerator final : public IGenerator<Float> { + // FIXME: What is the right seed? + std::minstd_rand m_rand; + std::uniform_real_distribution<Float> m_dist; + Float m_current_number; +public: + + RandomFloatingGenerator(Float a, Float b): + m_rand(getCurrentContext().getConfig()->rngSeed()), + m_dist(a, b) { + static_cast<void>(next()); + } + + Float const& get() const override { + return m_current_number; + } + bool next() override { + m_current_number = m_dist(m_rand); + return true; + } +}; + +template <typename Integer> +class RandomIntegerGenerator final : public IGenerator<Integer> { + std::minstd_rand m_rand; + std::uniform_int_distribution<Integer> m_dist; + Integer m_current_number; +public: + + RandomIntegerGenerator(Integer a, Integer b): + m_rand(getCurrentContext().getConfig()->rngSeed()), + m_dist(a, b) { + static_cast<void>(next()); + } + + Integer const& get() const override { + return m_current_number; + } + bool next() override { + m_current_number = m_dist(m_rand); + return true; + } +}; + +// TODO: Ideally this would be also constrained against the various char types, +// but I don't expect users to run into that in practice. +template <typename T> +typename std::enable_if<std::is_integral<T>::value && !std::is_same<T, bool>::value, +GeneratorWrapper<T>>::type +random(T a, T b) { + return GeneratorWrapper<T>( + pf::make_unique<RandomIntegerGenerator<T>>(a, b) + ); +} + +template <typename T> +typename std::enable_if<std::is_floating_point<T>::value, +GeneratorWrapper<T>>::type +random(T a, T b) { + return GeneratorWrapper<T>( + pf::make_unique<RandomFloatingGenerator<T>>(a, b) + ); +} + +template <typename T> +class RangeGenerator final : public IGenerator<T> { + T m_current; + T m_end; + T m_step; + bool m_positive; + +public: + RangeGenerator(T const& start, T const& end, T const& step): + m_current(start), + m_end(end), + m_step(step), + m_positive(m_step > T(0)) + { + assert(m_current != m_end && "Range start and end cannot be equal"); + assert(m_step != T(0) && "Step size cannot be zero"); + assert(((m_positive && m_current <= m_end) || (!m_positive && m_current >= m_end)) && "Step moves away from end"); + } + + RangeGenerator(T const& start, T const& end): + RangeGenerator(start, end, (start < end) ? T(1) : T(-1)) + {} + + T const& get() const override { + return m_current; + } + + bool next() override { + m_current += m_step; + return (m_positive) ? (m_current < m_end) : (m_current > m_end); + } +}; + +template <typename T> +GeneratorWrapper<T> range(T const& start, T const& end, T const& step) { + static_assert(std::is_integral<T>::value && !std::is_same<T, bool>::value, "Type must be an integer"); + return GeneratorWrapper<T>(pf::make_unique<RangeGenerator<T>>(start, end, step)); +} + +template <typename T> +GeneratorWrapper<T> range(T const& start, T const& end) { + static_assert(std::is_integral<T>::value && !std::is_same<T, bool>::value, "Type must be an integer"); + return GeneratorWrapper<T>(pf::make_unique<RangeGenerator<T>>(start, end)); +} + +} // namespace Generators +} // namespace Catch + +// end catch_generators_specific.hpp + +// These files are included here so the single_include script doesn't put them +// in the conditionally compiled sections +// start catch_test_case_info.h + +#include <string> +#include <vector> +#include <memory> + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +namespace Catch { + + struct ITestInvoker; + + struct TestCaseInfo { + enum SpecialProperties{ + None = 0, + IsHidden = 1 << 1, + ShouldFail = 1 << 2, + MayFail = 1 << 3, + Throws = 1 << 4, + NonPortable = 1 << 5, + Benchmark = 1 << 6 + }; + + TestCaseInfo( std::string const& _name, + std::string const& _className, + std::string const& _description, + std::vector<std::string> const& _tags, + SourceLineInfo const& _lineInfo ); + + friend void setTags( TestCaseInfo& testCaseInfo, std::vector<std::string> tags ); + + bool isHidden() const; + bool throws() const; + bool okToFail() const; + bool expectedToFail() const; + + std::string tagsAsString() const; + + std::string name; + std::string className; + std::string description; + std::vector<std::string> tags; + std::vector<std::string> lcaseTags; + SourceLineInfo lineInfo; + SpecialProperties properties; + }; + + class TestCase : public TestCaseInfo { + public: + + TestCase( ITestInvoker* testCase, TestCaseInfo&& info ); + + TestCase withName( std::string const& _newName ) const; + + void invoke() const; + + TestCaseInfo const& getTestCaseInfo() const; + + bool operator == ( TestCase const& other ) const; + bool operator < ( TestCase const& other ) const; + + private: + std::shared_ptr<ITestInvoker> test; + }; + + TestCase makeTestCase( ITestInvoker* testCase, + std::string const& className, + NameAndTags const& nameAndTags, + SourceLineInfo const& lineInfo ); +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// end catch_test_case_info.h +// start catch_interfaces_runner.h + +namespace Catch { + + struct IRunner { + virtual ~IRunner(); + virtual bool aborting() const = 0; + }; +} + +// end catch_interfaces_runner.h + +#ifdef __OBJC__ +// start catch_objc.hpp + +#import <objc/runtime.h> + +#include <string> + +// NB. Any general catch headers included here must be included +// in catch.hpp first to make sure they are included by the single +// header for non obj-usage + +/////////////////////////////////////////////////////////////////////////////// +// This protocol is really only here for (self) documenting purposes, since +// all its methods are optional. +@protocol OcFixture + +@optional + +-(void) setUp; +-(void) tearDown; + +@end + +namespace Catch { + + class OcMethod : public ITestInvoker { + + public: + OcMethod( Class cls, SEL sel ) : m_cls( cls ), m_sel( sel ) {} + + virtual void invoke() const { + id obj = [[m_cls alloc] init]; + + performOptionalSelector( obj, @selector(setUp) ); + performOptionalSelector( obj, m_sel ); + performOptionalSelector( obj, @selector(tearDown) ); + + arcSafeRelease( obj ); + } + private: + virtual ~OcMethod() {} + + Class m_cls; + SEL m_sel; + }; + + namespace Detail{ + + inline std::string getAnnotation( Class cls, + std::string const& annotationName, + std::string const& testCaseName ) { + NSString* selStr = [[NSString alloc] initWithFormat:@"Catch_%s_%s", annotationName.c_str(), testCaseName.c_str()]; + SEL sel = NSSelectorFromString( selStr ); + arcSafeRelease( selStr ); + id value = performOptionalSelector( cls, sel ); + if( value ) + return [(NSString*)value UTF8String]; + return ""; + } + } + + inline std::size_t registerTestMethods() { + std::size_t noTestMethods = 0; + int noClasses = objc_getClassList( nullptr, 0 ); + + Class* classes = (CATCH_UNSAFE_UNRETAINED Class *)malloc( sizeof(Class) * noClasses); + objc_getClassList( classes, noClasses ); + + for( int c = 0; c < noClasses; c++ ) { + Class cls = classes[c]; + { + u_int count; + Method* methods = class_copyMethodList( cls, &count ); + for( u_int m = 0; m < count ; m++ ) { + SEL selector = method_getName(methods[m]); + std::string methodName = sel_getName(selector); + if( startsWith( methodName, "Catch_TestCase_" ) ) { + std::string testCaseName = methodName.substr( 15 ); + std::string name = Detail::getAnnotation( cls, "Name", testCaseName ); + std::string desc = Detail::getAnnotation( cls, "Description", testCaseName ); + const char* className = class_getName( cls ); + + getMutableRegistryHub().registerTest( makeTestCase( new OcMethod( cls, selector ), className, NameAndTags( name.c_str(), desc.c_str() ), SourceLineInfo("",0) ) ); + noTestMethods++; + } + } + free(methods); + } + } + return noTestMethods; + } + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) + + namespace Matchers { + namespace Impl { + namespace NSStringMatchers { + + struct StringHolder : MatcherBase<NSString*>{ + StringHolder( NSString* substr ) : m_substr( [substr copy] ){} + StringHolder( StringHolder const& other ) : m_substr( [other.m_substr copy] ){} + StringHolder() { + arcSafeRelease( m_substr ); + } + + bool match( NSString* arg ) const override { + return false; + } + + NSString* CATCH_ARC_STRONG m_substr; + }; + + struct Equals : StringHolder { + Equals( NSString* substr ) : StringHolder( substr ){} + + bool match( NSString* str ) const override { + return (str != nil || m_substr == nil ) && + [str isEqualToString:m_substr]; + } + + std::string describe() const override { + return "equals string: " + Catch::Detail::stringify( m_substr ); + } + }; + + struct Contains : StringHolder { + Contains( NSString* substr ) : StringHolder( substr ){} + + bool match( NSString* str ) const { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location != NSNotFound; + } + + std::string describe() const override { + return "contains string: " + Catch::Detail::stringify( m_substr ); + } + }; + + struct StartsWith : StringHolder { + StartsWith( NSString* substr ) : StringHolder( substr ){} + + bool match( NSString* str ) const override { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location == 0; + } + + std::string describe() const override { + return "starts with: " + Catch::Detail::stringify( m_substr ); + } + }; + struct EndsWith : StringHolder { + EndsWith( NSString* substr ) : StringHolder( substr ){} + + bool match( NSString* str ) const override { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location == [str length] - [m_substr length]; + } + + std::string describe() const override { + return "ends with: " + Catch::Detail::stringify( m_substr ); + } + }; + + } // namespace NSStringMatchers + } // namespace Impl + + inline Impl::NSStringMatchers::Equals + Equals( NSString* substr ){ return Impl::NSStringMatchers::Equals( substr ); } + + inline Impl::NSStringMatchers::Contains + Contains( NSString* substr ){ return Impl::NSStringMatchers::Contains( substr ); } + + inline Impl::NSStringMatchers::StartsWith + StartsWith( NSString* substr ){ return Impl::NSStringMatchers::StartsWith( substr ); } + + inline Impl::NSStringMatchers::EndsWith + EndsWith( NSString* substr ){ return Impl::NSStringMatchers::EndsWith( substr ); } + + } // namespace Matchers + + using namespace Matchers; + +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +} // namespace Catch + +/////////////////////////////////////////////////////////////////////////////// +#define OC_MAKE_UNIQUE_NAME( root, uniqueSuffix ) root##uniqueSuffix +#define OC_TEST_CASE2( name, desc, uniqueSuffix ) \ ++(NSString*) OC_MAKE_UNIQUE_NAME( Catch_Name_test_, uniqueSuffix ) \ +{ \ +return @ name; \ +} \ ++(NSString*) OC_MAKE_UNIQUE_NAME( Catch_Description_test_, uniqueSuffix ) \ +{ \ +return @ desc; \ +} \ +-(void) OC_MAKE_UNIQUE_NAME( Catch_TestCase_test_, uniqueSuffix ) + +#define OC_TEST_CASE( name, desc ) OC_TEST_CASE2( name, desc, __LINE__ ) + +// end catch_objc.hpp +#endif + +#ifdef CATCH_CONFIG_EXTERNAL_INTERFACES +// start catch_external_interfaces.h + +// start catch_reporter_bases.hpp + +// start catch_interfaces_reporter.h + +// start catch_config.hpp + +// start catch_test_spec_parser.h + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +// start catch_test_spec.h + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +// start catch_wildcard_pattern.h + +namespace Catch +{ + class WildcardPattern { + enum WildcardPosition { + NoWildcard = 0, + WildcardAtStart = 1, + WildcardAtEnd = 2, + WildcardAtBothEnds = WildcardAtStart | WildcardAtEnd + }; + + public: + + WildcardPattern( std::string const& pattern, CaseSensitive::Choice caseSensitivity ); + virtual ~WildcardPattern() = default; + virtual bool matches( std::string const& str ) const; + + private: + std::string adjustCase( std::string const& str ) const; + CaseSensitive::Choice m_caseSensitivity; + WildcardPosition m_wildcard = NoWildcard; + std::string m_pattern; + }; +} + +// end catch_wildcard_pattern.h +#include <string> +#include <vector> +#include <memory> + +namespace Catch { + + class TestSpec { + struct Pattern { + virtual ~Pattern(); + virtual bool matches( TestCaseInfo const& testCase ) const = 0; + }; + using PatternPtr = std::shared_ptr<Pattern>; + + class NamePattern : public Pattern { + public: + NamePattern( std::string const& name ); + virtual ~NamePattern(); + virtual bool matches( TestCaseInfo const& testCase ) const override; + private: + WildcardPattern m_wildcardPattern; + }; + + class TagPattern : public Pattern { + public: + TagPattern( std::string const& tag ); + virtual ~TagPattern(); + virtual bool matches( TestCaseInfo const& testCase ) const override; + private: + std::string m_tag; + }; + + class ExcludedPattern : public Pattern { + public: + ExcludedPattern( PatternPtr const& underlyingPattern ); + virtual ~ExcludedPattern(); + virtual bool matches( TestCaseInfo const& testCase ) const override; + private: + PatternPtr m_underlyingPattern; + }; + + struct Filter { + std::vector<PatternPtr> m_patterns; + + bool matches( TestCaseInfo const& testCase ) const; + }; + + public: + bool hasFilters() const; + bool matches( TestCaseInfo const& testCase ) const; + + private: + std::vector<Filter> m_filters; + + friend class TestSpecParser; + }; +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// end catch_test_spec.h +// start catch_interfaces_tag_alias_registry.h + +#include <string> + +namespace Catch { + + struct TagAlias; + + struct ITagAliasRegistry { + virtual ~ITagAliasRegistry(); + // Nullptr if not present + virtual TagAlias const* find( std::string const& alias ) const = 0; + virtual std::string expandAliases( std::string const& unexpandedTestSpec ) const = 0; + + static ITagAliasRegistry const& get(); + }; + +} // end namespace Catch + +// end catch_interfaces_tag_alias_registry.h +namespace Catch { + + class TestSpecParser { + enum Mode{ None, Name, QuotedName, Tag, EscapedName }; + Mode m_mode = None; + bool m_exclusion = false; + std::size_t m_start = std::string::npos, m_pos = 0; + std::string m_arg; + std::vector<std::size_t> m_escapeChars; + TestSpec::Filter m_currentFilter; + TestSpec m_testSpec; + ITagAliasRegistry const* m_tagAliases = nullptr; + + public: + TestSpecParser( ITagAliasRegistry const& tagAliases ); + + TestSpecParser& parse( std::string const& arg ); + TestSpec testSpec(); + + private: + void visitChar( char c ); + void startNewMode( Mode mode, std::size_t start ); + void escape(); + std::string subString() const; + + template<typename T> + void addPattern() { + std::string token = subString(); + for( std::size_t i = 0; i < m_escapeChars.size(); ++i ) + token = token.substr( 0, m_escapeChars[i]-m_start-i ) + token.substr( m_escapeChars[i]-m_start-i+1 ); + m_escapeChars.clear(); + if( startsWith( token, "exclude:" ) ) { + m_exclusion = true; + token = token.substr( 8 ); + } + if( !token.empty() ) { + TestSpec::PatternPtr pattern = std::make_shared<T>( token ); + if( m_exclusion ) + pattern = std::make_shared<TestSpec::ExcludedPattern>( pattern ); + m_currentFilter.m_patterns.push_back( pattern ); + } + m_exclusion = false; + m_mode = None; + } + + void addFilter(); + }; + TestSpec parseTestSpec( std::string const& arg ); + +} // namespace Catch + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// end catch_test_spec_parser.h +// Libstdc++ doesn't like incomplete classes for unique_ptr + +#include <memory> +#include <vector> +#include <string> + +#ifndef CATCH_CONFIG_CONSOLE_WIDTH +#define CATCH_CONFIG_CONSOLE_WIDTH 80 +#endif + +namespace Catch { + + struct IStream; + + struct ConfigData { + bool listTests = false; + bool listTags = false; + bool listReporters = false; + bool listTestNamesOnly = false; + + bool showSuccessfulTests = false; + bool shouldDebugBreak = false; + bool noThrow = false; + bool showHelp = false; + bool showInvisibles = false; + bool filenamesAsTags = false; + bool libIdentify = false; + + int abortAfter = -1; + unsigned int rngSeed = 0; + int benchmarkResolutionMultiple = 100; + + Verbosity verbosity = Verbosity::Normal; + WarnAbout::What warnings = WarnAbout::Nothing; + ShowDurations::OrNot showDurations = ShowDurations::DefaultForReporter; + RunTests::InWhatOrder runOrder = RunTests::InDeclarationOrder; + UseColour::YesOrNo useColour = UseColour::Auto; + WaitForKeypress::When waitForKeypress = WaitForKeypress::Never; + + std::string outputFilename; + std::string name; + std::string processName; +#ifndef CATCH_CONFIG_DEFAULT_REPORTER +#define CATCH_CONFIG_DEFAULT_REPORTER "console" +#endif + std::string reporterName = CATCH_CONFIG_DEFAULT_REPORTER; +#undef CATCH_CONFIG_DEFAULT_REPORTER + + std::vector<std::string> testsOrTags; + std::vector<std::string> sectionsToRun; + }; + + class Config : public IConfig { + public: + + Config() = default; + Config( ConfigData const& data ); + virtual ~Config() = default; + + std::string const& getFilename() const; + + bool listTests() const; + bool listTestNamesOnly() const; + bool listTags() const; + bool listReporters() const; + + std::string getProcessName() const; + std::string const& getReporterName() const; + + std::vector<std::string> const& getTestsOrTags() const; + std::vector<std::string> const& getSectionsToRun() const override; + + virtual TestSpec const& testSpec() const override; + bool hasTestFilters() const override; + + bool showHelp() const; + + // IConfig interface + bool allowThrows() const override; + std::ostream& stream() const override; + std::string name() const override; + bool includeSuccessfulResults() const override; + bool warnAboutMissingAssertions() const override; + bool warnAboutNoTests() const override; + ShowDurations::OrNot showDurations() const override; + RunTests::InWhatOrder runOrder() const override; + unsigned int rngSeed() const override; + int benchmarkResolutionMultiple() const override; + UseColour::YesOrNo useColour() const override; + bool shouldDebugBreak() const override; + int abortAfter() const override; + bool showInvisibles() const override; + Verbosity verbosity() const override; + + private: + + IStream const* openStream(); + ConfigData m_data; + + std::unique_ptr<IStream const> m_stream; + TestSpec m_testSpec; + bool m_hasTestFilters = false; + }; + +} // end namespace Catch + +// end catch_config.hpp +// start catch_assertionresult.h + +#include <string> + +namespace Catch { + + struct AssertionResultData + { + AssertionResultData() = delete; + + AssertionResultData( ResultWas::OfType _resultType, LazyExpression const& _lazyExpression ); + + std::string message; + mutable std::string reconstructedExpression; + LazyExpression lazyExpression; + ResultWas::OfType resultType; + + std::string reconstructExpression() const; + }; + + class AssertionResult { + public: + AssertionResult() = delete; + AssertionResult( AssertionInfo const& info, AssertionResultData const& data ); + + bool isOk() const; + bool succeeded() const; + ResultWas::OfType getResultType() const; + bool hasExpression() const; + bool hasMessage() const; + std::string getExpression() const; + std::string getExpressionInMacro() const; + bool hasExpandedExpression() const; + std::string getExpandedExpression() const; + std::string getMessage() const; + SourceLineInfo getSourceInfo() const; + StringRef getTestMacroName() const; + + //protected: + AssertionInfo m_info; + AssertionResultData m_resultData; + }; + +} // end namespace Catch + +// end catch_assertionresult.h +// start catch_option.hpp + +namespace Catch { + + // An optional type + template<typename T> + class Option { + public: + Option() : nullableValue( nullptr ) {} + Option( T const& _value ) + : nullableValue( new( storage ) T( _value ) ) + {} + Option( Option const& _other ) + : nullableValue( _other ? new( storage ) T( *_other ) : nullptr ) + {} + + ~Option() { + reset(); + } + + Option& operator= ( Option const& _other ) { + if( &_other != this ) { + reset(); + if( _other ) + nullableValue = new( storage ) T( *_other ); + } + return *this; + } + Option& operator = ( T const& _value ) { + reset(); + nullableValue = new( storage ) T( _value ); + return *this; + } + + void reset() { + if( nullableValue ) + nullableValue->~T(); + nullableValue = nullptr; + } + + T& operator*() { return *nullableValue; } + T const& operator*() const { return *nullableValue; } + T* operator->() { return nullableValue; } + const T* operator->() const { return nullableValue; } + + T valueOr( T const& defaultValue ) const { + return nullableValue ? *nullableValue : defaultValue; + } + + bool some() const { return nullableValue != nullptr; } + bool none() const { return nullableValue == nullptr; } + + bool operator !() const { return nullableValue == nullptr; } + explicit operator bool() const { + return some(); + } + + private: + T *nullableValue; + alignas(alignof(T)) char storage[sizeof(T)]; + }; + +} // end namespace Catch + +// end catch_option.hpp +#include <string> +#include <iosfwd> +#include <map> +#include <set> +#include <memory> + +namespace Catch { + + struct ReporterConfig { + explicit ReporterConfig( IConfigPtr const& _fullConfig ); + + ReporterConfig( IConfigPtr const& _fullConfig, std::ostream& _stream ); + + std::ostream& stream() const; + IConfigPtr fullConfig() const; + + private: + std::ostream* m_stream; + IConfigPtr m_fullConfig; + }; + + struct ReporterPreferences { + bool shouldRedirectStdOut = false; + bool shouldReportAllAssertions = false; + }; + + template<typename T> + struct LazyStat : Option<T> { + LazyStat& operator=( T const& _value ) { + Option<T>::operator=( _value ); + used = false; + return *this; + } + void reset() { + Option<T>::reset(); + used = false; + } + bool used = false; + }; + + struct TestRunInfo { + TestRunInfo( std::string const& _name ); + std::string name; + }; + struct GroupInfo { + GroupInfo( std::string const& _name, + std::size_t _groupIndex, + std::size_t _groupsCount ); + + std::string name; + std::size_t groupIndex; + std::size_t groupsCounts; + }; + + struct AssertionStats { + AssertionStats( AssertionResult const& _assertionResult, + std::vector<MessageInfo> const& _infoMessages, + Totals const& _totals ); + + AssertionStats( AssertionStats const& ) = default; + AssertionStats( AssertionStats && ) = default; + AssertionStats& operator = ( AssertionStats const& ) = delete; + AssertionStats& operator = ( AssertionStats && ) = delete; + virtual ~AssertionStats(); + + AssertionResult assertionResult; + std::vector<MessageInfo> infoMessages; + Totals totals; + }; + + struct SectionStats { + SectionStats( SectionInfo const& _sectionInfo, + Counts const& _assertions, + double _durationInSeconds, + bool _missingAssertions ); + SectionStats( SectionStats const& ) = default; + SectionStats( SectionStats && ) = default; + SectionStats& operator = ( SectionStats const& ) = default; + SectionStats& operator = ( SectionStats && ) = default; + virtual ~SectionStats(); + + SectionInfo sectionInfo; + Counts assertions; + double durationInSeconds; + bool missingAssertions; + }; + + struct TestCaseStats { + TestCaseStats( TestCaseInfo const& _testInfo, + Totals const& _totals, + std::string const& _stdOut, + std::string const& _stdErr, + bool _aborting ); + + TestCaseStats( TestCaseStats const& ) = default; + TestCaseStats( TestCaseStats && ) = default; + TestCaseStats& operator = ( TestCaseStats const& ) = default; + TestCaseStats& operator = ( TestCaseStats && ) = default; + virtual ~TestCaseStats(); + + TestCaseInfo testInfo; + Totals totals; + std::string stdOut; + std::string stdErr; + bool aborting; + }; + + struct TestGroupStats { + TestGroupStats( GroupInfo const& _groupInfo, + Totals const& _totals, + bool _aborting ); + TestGroupStats( GroupInfo const& _groupInfo ); + + TestGroupStats( TestGroupStats const& ) = default; + TestGroupStats( TestGroupStats && ) = default; + TestGroupStats& operator = ( TestGroupStats const& ) = default; + TestGroupStats& operator = ( TestGroupStats && ) = default; + virtual ~TestGroupStats(); + + GroupInfo groupInfo; + Totals totals; + bool aborting; + }; + + struct TestRunStats { + TestRunStats( TestRunInfo const& _runInfo, + Totals const& _totals, + bool _aborting ); + + TestRunStats( TestRunStats const& ) = default; + TestRunStats( TestRunStats && ) = default; + TestRunStats& operator = ( TestRunStats const& ) = default; + TestRunStats& operator = ( TestRunStats && ) = default; + virtual ~TestRunStats(); + + TestRunInfo runInfo; + Totals totals; + bool aborting; + }; + + struct BenchmarkInfo { + std::string name; + }; + struct BenchmarkStats { + BenchmarkInfo info; + std::size_t iterations; + uint64_t elapsedTimeInNanoseconds; + }; + + struct IStreamingReporter { + virtual ~IStreamingReporter() = default; + + // Implementing class must also provide the following static methods: + // static std::string getDescription(); + // static std::set<Verbosity> getSupportedVerbosities() + + virtual ReporterPreferences getPreferences() const = 0; + + virtual void noMatchingTestCases( std::string const& spec ) = 0; + + virtual void testRunStarting( TestRunInfo const& testRunInfo ) = 0; + virtual void testGroupStarting( GroupInfo const& groupInfo ) = 0; + + virtual void testCaseStarting( TestCaseInfo const& testInfo ) = 0; + virtual void sectionStarting( SectionInfo const& sectionInfo ) = 0; + + // *** experimental *** + virtual void benchmarkStarting( BenchmarkInfo const& ) {} + + virtual void assertionStarting( AssertionInfo const& assertionInfo ) = 0; + + // The return value indicates if the messages buffer should be cleared: + virtual bool assertionEnded( AssertionStats const& assertionStats ) = 0; + + // *** experimental *** + virtual void benchmarkEnded( BenchmarkStats const& ) {} + + virtual void sectionEnded( SectionStats const& sectionStats ) = 0; + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) = 0; + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) = 0; + virtual void testRunEnded( TestRunStats const& testRunStats ) = 0; + + virtual void skipTest( TestCaseInfo const& testInfo ) = 0; + + // Default empty implementation provided + virtual void fatalErrorEncountered( StringRef name ); + + virtual bool isMulti() const; + }; + using IStreamingReporterPtr = std::unique_ptr<IStreamingReporter>; + + struct IReporterFactory { + virtual ~IReporterFactory(); + virtual IStreamingReporterPtr create( ReporterConfig const& config ) const = 0; + virtual std::string getDescription() const = 0; + }; + using IReporterFactoryPtr = std::shared_ptr<IReporterFactory>; + + struct IReporterRegistry { + using FactoryMap = std::map<std::string, IReporterFactoryPtr>; + using Listeners = std::vector<IReporterFactoryPtr>; + + virtual ~IReporterRegistry(); + virtual IStreamingReporterPtr create( std::string const& name, IConfigPtr const& config ) const = 0; + virtual FactoryMap const& getFactories() const = 0; + virtual Listeners const& getListeners() const = 0; + }; + +} // end namespace Catch + +// end catch_interfaces_reporter.h +#include <algorithm> +#include <cstring> +#include <cfloat> +#include <cstdio> +#include <cassert> +#include <memory> +#include <ostream> + +namespace Catch { + void prepareExpandedExpression(AssertionResult& result); + + // Returns double formatted as %.3f (format expected on output) + std::string getFormattedDuration( double duration ); + + template<typename DerivedT> + struct StreamingReporterBase : IStreamingReporter { + + StreamingReporterBase( ReporterConfig const& _config ) + : m_config( _config.fullConfig() ), + stream( _config.stream() ) + { + m_reporterPrefs.shouldRedirectStdOut = false; + if( !DerivedT::getSupportedVerbosities().count( m_config->verbosity() ) ) + CATCH_ERROR( "Verbosity level not supported by this reporter" ); + } + + ReporterPreferences getPreferences() const override { + return m_reporterPrefs; + } + + static std::set<Verbosity> getSupportedVerbosities() { + return { Verbosity::Normal }; + } + + ~StreamingReporterBase() override = default; + + void noMatchingTestCases(std::string const&) override {} + + void testRunStarting(TestRunInfo const& _testRunInfo) override { + currentTestRunInfo = _testRunInfo; + } + void testGroupStarting(GroupInfo const& _groupInfo) override { + currentGroupInfo = _groupInfo; + } + + void testCaseStarting(TestCaseInfo const& _testInfo) override { + currentTestCaseInfo = _testInfo; + } + void sectionStarting(SectionInfo const& _sectionInfo) override { + m_sectionStack.push_back(_sectionInfo); + } + + void sectionEnded(SectionStats const& /* _sectionStats */) override { + m_sectionStack.pop_back(); + } + void testCaseEnded(TestCaseStats const& /* _testCaseStats */) override { + currentTestCaseInfo.reset(); + } + void testGroupEnded(TestGroupStats const& /* _testGroupStats */) override { + currentGroupInfo.reset(); + } + void testRunEnded(TestRunStats const& /* _testRunStats */) override { + currentTestCaseInfo.reset(); + currentGroupInfo.reset(); + currentTestRunInfo.reset(); + } + + void skipTest(TestCaseInfo const&) override { + // Don't do anything with this by default. + // It can optionally be overridden in the derived class. + } + + IConfigPtr m_config; + std::ostream& stream; + + LazyStat<TestRunInfo> currentTestRunInfo; + LazyStat<GroupInfo> currentGroupInfo; + LazyStat<TestCaseInfo> currentTestCaseInfo; + + std::vector<SectionInfo> m_sectionStack; + ReporterPreferences m_reporterPrefs; + }; + + template<typename DerivedT> + struct CumulativeReporterBase : IStreamingReporter { + template<typename T, typename ChildNodeT> + struct Node { + explicit Node( T const& _value ) : value( _value ) {} + virtual ~Node() {} + + using ChildNodes = std::vector<std::shared_ptr<ChildNodeT>>; + T value; + ChildNodes children; + }; + struct SectionNode { + explicit SectionNode(SectionStats const& _stats) : stats(_stats) {} + virtual ~SectionNode() = default; + + bool operator == (SectionNode const& other) const { + return stats.sectionInfo.lineInfo == other.stats.sectionInfo.lineInfo; + } + bool operator == (std::shared_ptr<SectionNode> const& other) const { + return operator==(*other); + } + + SectionStats stats; + using ChildSections = std::vector<std::shared_ptr<SectionNode>>; + using Assertions = std::vector<AssertionStats>; + ChildSections childSections; + Assertions assertions; + std::string stdOut; + std::string stdErr; + }; + + struct BySectionInfo { + BySectionInfo( SectionInfo const& other ) : m_other( other ) {} + BySectionInfo( BySectionInfo const& other ) : m_other( other.m_other ) {} + bool operator() (std::shared_ptr<SectionNode> const& node) const { + return ((node->stats.sectionInfo.name == m_other.name) && + (node->stats.sectionInfo.lineInfo == m_other.lineInfo)); + } + void operator=(BySectionInfo const&) = delete; + + private: + SectionInfo const& m_other; + }; + + using TestCaseNode = Node<TestCaseStats, SectionNode>; + using TestGroupNode = Node<TestGroupStats, TestCaseNode>; + using TestRunNode = Node<TestRunStats, TestGroupNode>; + + CumulativeReporterBase( ReporterConfig const& _config ) + : m_config( _config.fullConfig() ), + stream( _config.stream() ) + { + m_reporterPrefs.shouldRedirectStdOut = false; + if( !DerivedT::getSupportedVerbosities().count( m_config->verbosity() ) ) + CATCH_ERROR( "Verbosity level not supported by this reporter" ); + } + ~CumulativeReporterBase() override = default; + + ReporterPreferences getPreferences() const override { + return m_reporterPrefs; + } + + static std::set<Verbosity> getSupportedVerbosities() { + return { Verbosity::Normal }; + } + + void testRunStarting( TestRunInfo const& ) override {} + void testGroupStarting( GroupInfo const& ) override {} + + void testCaseStarting( TestCaseInfo const& ) override {} + + void sectionStarting( SectionInfo const& sectionInfo ) override { + SectionStats incompleteStats( sectionInfo, Counts(), 0, false ); + std::shared_ptr<SectionNode> node; + if( m_sectionStack.empty() ) { + if( !m_rootSection ) + m_rootSection = std::make_shared<SectionNode>( incompleteStats ); + node = m_rootSection; + } + else { + SectionNode& parentNode = *m_sectionStack.back(); + auto it = + std::find_if( parentNode.childSections.begin(), + parentNode.childSections.end(), + BySectionInfo( sectionInfo ) ); + if( it == parentNode.childSections.end() ) { + node = std::make_shared<SectionNode>( incompleteStats ); + parentNode.childSections.push_back( node ); + } + else + node = *it; + } + m_sectionStack.push_back( node ); + m_deepestSection = std::move(node); + } + + void assertionStarting(AssertionInfo const&) override {} + + bool assertionEnded(AssertionStats const& assertionStats) override { + assert(!m_sectionStack.empty()); + // AssertionResult holds a pointer to a temporary DecomposedExpression, + // which getExpandedExpression() calls to build the expression string. + // Our section stack copy of the assertionResult will likely outlive the + // temporary, so it must be expanded or discarded now to avoid calling + // a destroyed object later. + prepareExpandedExpression(const_cast<AssertionResult&>( assertionStats.assertionResult ) ); + SectionNode& sectionNode = *m_sectionStack.back(); + sectionNode.assertions.push_back(assertionStats); + return true; + } + void sectionEnded(SectionStats const& sectionStats) override { + assert(!m_sectionStack.empty()); + SectionNode& node = *m_sectionStack.back(); + node.stats = sectionStats; + m_sectionStack.pop_back(); + } + void testCaseEnded(TestCaseStats const& testCaseStats) override { + auto node = std::make_shared<TestCaseNode>(testCaseStats); + assert(m_sectionStack.size() == 0); + node->children.push_back(m_rootSection); + m_testCases.push_back(node); + m_rootSection.reset(); + + assert(m_deepestSection); + m_deepestSection->stdOut = testCaseStats.stdOut; + m_deepestSection->stdErr = testCaseStats.stdErr; + } + void testGroupEnded(TestGroupStats const& testGroupStats) override { + auto node = std::make_shared<TestGroupNode>(testGroupStats); + node->children.swap(m_testCases); + m_testGroups.push_back(node); + } + void testRunEnded(TestRunStats const& testRunStats) override { + auto node = std::make_shared<TestRunNode>(testRunStats); + node->children.swap(m_testGroups); + m_testRuns.push_back(node); + testRunEndedCumulative(); + } + virtual void testRunEndedCumulative() = 0; + + void skipTest(TestCaseInfo const&) override {} + + IConfigPtr m_config; + std::ostream& stream; + std::vector<AssertionStats> m_assertions; + std::vector<std::vector<std::shared_ptr<SectionNode>>> m_sections; + std::vector<std::shared_ptr<TestCaseNode>> m_testCases; + std::vector<std::shared_ptr<TestGroupNode>> m_testGroups; + + std::vector<std::shared_ptr<TestRunNode>> m_testRuns; + + std::shared_ptr<SectionNode> m_rootSection; + std::shared_ptr<SectionNode> m_deepestSection; + std::vector<std::shared_ptr<SectionNode>> m_sectionStack; + ReporterPreferences m_reporterPrefs; + }; + + template<char C> + char const* getLineOfChars() { + static char line[CATCH_CONFIG_CONSOLE_WIDTH] = {0}; + if( !*line ) { + std::memset( line, C, CATCH_CONFIG_CONSOLE_WIDTH-1 ); + line[CATCH_CONFIG_CONSOLE_WIDTH-1] = 0; + } + return line; + } + + struct TestEventListenerBase : StreamingReporterBase<TestEventListenerBase> { + TestEventListenerBase( ReporterConfig const& _config ); + + static std::set<Verbosity> getSupportedVerbosities(); + + void assertionStarting(AssertionInfo const&) override; + bool assertionEnded(AssertionStats const&) override; + }; + +} // end namespace Catch + +// end catch_reporter_bases.hpp +// start catch_console_colour.h + +namespace Catch { + + struct Colour { + enum Code { + None = 0, + + White, + Red, + Green, + Blue, + Cyan, + Yellow, + Grey, + + Bright = 0x10, + + BrightRed = Bright | Red, + BrightGreen = Bright | Green, + LightGrey = Bright | Grey, + BrightWhite = Bright | White, + BrightYellow = Bright | Yellow, + + // By intention + FileName = LightGrey, + Warning = BrightYellow, + ResultError = BrightRed, + ResultSuccess = BrightGreen, + ResultExpectedFailure = Warning, + + Error = BrightRed, + Success = Green, + + OriginalExpression = Cyan, + ReconstructedExpression = BrightYellow, + + SecondaryText = LightGrey, + Headers = White + }; + + // Use constructed object for RAII guard + Colour( Code _colourCode ); + Colour( Colour&& other ) noexcept; + Colour& operator=( Colour&& other ) noexcept; + ~Colour(); + + // Use static method for one-shot changes + static void use( Code _colourCode ); + + private: + bool m_moved = false; + }; + + std::ostream& operator << ( std::ostream& os, Colour const& ); + +} // end namespace Catch + +// end catch_console_colour.h +// start catch_reporter_registrars.hpp + + +namespace Catch { + + template<typename T> + class ReporterRegistrar { + + class ReporterFactory : public IReporterFactory { + + virtual IStreamingReporterPtr create( ReporterConfig const& config ) const override { + return std::unique_ptr<T>( new T( config ) ); + } + + virtual std::string getDescription() const override { + return T::getDescription(); + } + }; + + public: + + explicit ReporterRegistrar( std::string const& name ) { + getMutableRegistryHub().registerReporter( name, std::make_shared<ReporterFactory>() ); + } + }; + + template<typename T> + class ListenerRegistrar { + + class ListenerFactory : public IReporterFactory { + + virtual IStreamingReporterPtr create( ReporterConfig const& config ) const override { + return std::unique_ptr<T>( new T( config ) ); + } + virtual std::string getDescription() const override { + return std::string(); + } + }; + + public: + + ListenerRegistrar() { + getMutableRegistryHub().registerListener( std::make_shared<ListenerFactory>() ); + } + }; +} + +#if !defined(CATCH_CONFIG_DISABLE) + +#define CATCH_REGISTER_REPORTER( name, reporterType ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::ReporterRegistrar<reporterType> catch_internal_RegistrarFor##reporterType( name ); } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS + +#define CATCH_REGISTER_LISTENER( listenerType ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::ListenerRegistrar<listenerType> catch_internal_RegistrarFor##listenerType; } \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS +#else // CATCH_CONFIG_DISABLE + +#define CATCH_REGISTER_REPORTER(name, reporterType) +#define CATCH_REGISTER_LISTENER(listenerType) + +#endif // CATCH_CONFIG_DISABLE + +// end catch_reporter_registrars.hpp +// Allow users to base their work off existing reporters +// start catch_reporter_compact.h + +namespace Catch { + + struct CompactReporter : StreamingReporterBase<CompactReporter> { + + using StreamingReporterBase::StreamingReporterBase; + + ~CompactReporter() override; + + static std::string getDescription(); + + ReporterPreferences getPreferences() const override; + + void noMatchingTestCases(std::string const& spec) override; + + void assertionStarting(AssertionInfo const&) override; + + bool assertionEnded(AssertionStats const& _assertionStats) override; + + void sectionEnded(SectionStats const& _sectionStats) override; + + void testRunEnded(TestRunStats const& _testRunStats) override; + + }; + +} // end namespace Catch + +// end catch_reporter_compact.h +// start catch_reporter_console.h + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable:4061) // Not all labels are EXPLICITLY handled in switch + // Note that 4062 (not all labels are handled + // and default is missing) is enabled +#endif + +namespace Catch { + // Fwd decls + struct SummaryColumn; + class TablePrinter; + + struct ConsoleReporter : StreamingReporterBase<ConsoleReporter> { + std::unique_ptr<TablePrinter> m_tablePrinter; + + ConsoleReporter(ReporterConfig const& config); + ~ConsoleReporter() override; + static std::string getDescription(); + + void noMatchingTestCases(std::string const& spec) override; + + void assertionStarting(AssertionInfo const&) override; + + bool assertionEnded(AssertionStats const& _assertionStats) override; + + void sectionStarting(SectionInfo const& _sectionInfo) override; + void sectionEnded(SectionStats const& _sectionStats) override; + + void benchmarkStarting(BenchmarkInfo const& info) override; + void benchmarkEnded(BenchmarkStats const& stats) override; + + void testCaseEnded(TestCaseStats const& _testCaseStats) override; + void testGroupEnded(TestGroupStats const& _testGroupStats) override; + void testRunEnded(TestRunStats const& _testRunStats) override; + + private: + + void lazyPrint(); + + void lazyPrintWithoutClosingBenchmarkTable(); + void lazyPrintRunInfo(); + void lazyPrintGroupInfo(); + void printTestCaseAndSectionHeader(); + + void printClosedHeader(std::string const& _name); + void printOpenHeader(std::string const& _name); + + // if string has a : in first line will set indent to follow it on + // subsequent lines + void printHeaderString(std::string const& _string, std::size_t indent = 0); + + void printTotals(Totals const& totals); + void printSummaryRow(std::string const& label, std::vector<SummaryColumn> const& cols, std::size_t row); + + void printTotalsDivider(Totals const& totals); + void printSummaryDivider(); + + private: + bool m_headerPrinted = false; + }; + +} // end namespace Catch + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +// end catch_reporter_console.h +// start catch_reporter_junit.h + +// start catch_xmlwriter.h + +#include <vector> + +namespace Catch { + + class XmlEncode { + public: + enum ForWhat { ForTextNodes, ForAttributes }; + + XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes ); + + void encodeTo( std::ostream& os ) const; + + friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ); + + private: + std::string m_str; + ForWhat m_forWhat; + }; + + class XmlWriter { + public: + + class ScopedElement { + public: + ScopedElement( XmlWriter* writer ); + + ScopedElement( ScopedElement&& other ) noexcept; + ScopedElement& operator=( ScopedElement&& other ) noexcept; + + ~ScopedElement(); + + ScopedElement& writeText( std::string const& text, bool indent = true ); + + template<typename T> + ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { + m_writer->writeAttribute( name, attribute ); + return *this; + } + + private: + mutable XmlWriter* m_writer = nullptr; + }; + + XmlWriter( std::ostream& os = Catch::cout() ); + ~XmlWriter(); + + XmlWriter( XmlWriter const& ) = delete; + XmlWriter& operator=( XmlWriter const& ) = delete; + + XmlWriter& startElement( std::string const& name ); + + ScopedElement scopedElement( std::string const& name ); + + XmlWriter& endElement(); + + XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ); + + XmlWriter& writeAttribute( std::string const& name, bool attribute ); + + template<typename T> + XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { + ReusableStringStream rss; + rss << attribute; + return writeAttribute( name, rss.str() ); + } + + XmlWriter& writeText( std::string const& text, bool indent = true ); + + XmlWriter& writeComment( std::string const& text ); + + void writeStylesheetRef( std::string const& url ); + + XmlWriter& writeBlankLine(); + + void ensureTagClosed(); + + private: + + void writeDeclaration(); + + void newlineIfNecessary(); + + bool m_tagIsOpen = false; + bool m_needsNewline = false; + std::vector<std::string> m_tags; + std::string m_indent; + std::ostream& m_os; + }; + +} + +// end catch_xmlwriter.h +namespace Catch { + + class JunitReporter : public CumulativeReporterBase<JunitReporter> { + public: + JunitReporter(ReporterConfig const& _config); + + ~JunitReporter() override; + + static std::string getDescription(); + + void noMatchingTestCases(std::string const& /*spec*/) override; + + void testRunStarting(TestRunInfo const& runInfo) override; + + void testGroupStarting(GroupInfo const& groupInfo) override; + + void testCaseStarting(TestCaseInfo const& testCaseInfo) override; + bool assertionEnded(AssertionStats const& assertionStats) override; + + void testCaseEnded(TestCaseStats const& testCaseStats) override; + + void testGroupEnded(TestGroupStats const& testGroupStats) override; + + void testRunEndedCumulative() override; + + void writeGroup(TestGroupNode const& groupNode, double suiteTime); + + void writeTestCase(TestCaseNode const& testCaseNode); + + void writeSection(std::string const& className, + std::string const& rootName, + SectionNode const& sectionNode); + + void writeAssertions(SectionNode const& sectionNode); + void writeAssertion(AssertionStats const& stats); + + XmlWriter xml; + Timer suiteTimer; + std::string stdOutForSuite; + std::string stdErrForSuite; + unsigned int unexpectedExceptions = 0; + bool m_okToFail = false; + }; + +} // end namespace Catch + +// end catch_reporter_junit.h +// start catch_reporter_xml.h + +namespace Catch { + class XmlReporter : public StreamingReporterBase<XmlReporter> { + public: + XmlReporter(ReporterConfig const& _config); + + ~XmlReporter() override; + + static std::string getDescription(); + + virtual std::string getStylesheetRef() const; + + void writeSourceInfo(SourceLineInfo const& sourceInfo); + + public: // StreamingReporterBase + + void noMatchingTestCases(std::string const& s) override; + + void testRunStarting(TestRunInfo const& testInfo) override; + + void testGroupStarting(GroupInfo const& groupInfo) override; + + void testCaseStarting(TestCaseInfo const& testInfo) override; + + void sectionStarting(SectionInfo const& sectionInfo) override; + + void assertionStarting(AssertionInfo const&) override; + + bool assertionEnded(AssertionStats const& assertionStats) override; + + void sectionEnded(SectionStats const& sectionStats) override; + + void testCaseEnded(TestCaseStats const& testCaseStats) override; + + void testGroupEnded(TestGroupStats const& testGroupStats) override; + + void testRunEnded(TestRunStats const& testRunStats) override; + + private: + Timer m_testCaseTimer; + XmlWriter m_xml; + int m_sectionDepth = 0; + }; + +} // end namespace Catch + +// end catch_reporter_xml.h + +// end catch_external_interfaces.h +#endif + +#endif // ! CATCH_CONFIG_IMPL_ONLY + +#ifdef CATCH_IMPL +// start catch_impl.hpp + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wweak-vtables" +#endif + +// Keep these here for external reporters +// start catch_test_case_tracker.h + +#include <string> +#include <vector> +#include <memory> + +namespace Catch { +namespace TestCaseTracking { + + struct NameAndLocation { + std::string name; + SourceLineInfo location; + + NameAndLocation( std::string const& _name, SourceLineInfo const& _location ); + }; + + struct ITracker; + + using ITrackerPtr = std::shared_ptr<ITracker>; + + struct ITracker { + virtual ~ITracker(); + + // static queries + virtual NameAndLocation const& nameAndLocation() const = 0; + + // dynamic queries + virtual bool isComplete() const = 0; // Successfully completed or failed + virtual bool isSuccessfullyCompleted() const = 0; + virtual bool isOpen() const = 0; // Started but not complete + virtual bool hasChildren() const = 0; + + virtual ITracker& parent() = 0; + + // actions + virtual void close() = 0; // Successfully complete + virtual void fail() = 0; + virtual void markAsNeedingAnotherRun() = 0; + + virtual void addChild( ITrackerPtr const& child ) = 0; + virtual ITrackerPtr findChild( NameAndLocation const& nameAndLocation ) = 0; + virtual void openChild() = 0; + + // Debug/ checking + virtual bool isSectionTracker() const = 0; + virtual bool isGeneratorTracker() const = 0; + }; + + class TrackerContext { + + enum RunState { + NotStarted, + Executing, + CompletedCycle + }; + + ITrackerPtr m_rootTracker; + ITracker* m_currentTracker = nullptr; + RunState m_runState = NotStarted; + + public: + + static TrackerContext& instance(); + + ITracker& startRun(); + void endRun(); + + void startCycle(); + void completeCycle(); + + bool completedCycle() const; + ITracker& currentTracker(); + void setCurrentTracker( ITracker* tracker ); + }; + + class TrackerBase : public ITracker { + protected: + enum CycleState { + NotStarted, + Executing, + ExecutingChildren, + NeedsAnotherRun, + CompletedSuccessfully, + Failed + }; + + using Children = std::vector<ITrackerPtr>; + NameAndLocation m_nameAndLocation; + TrackerContext& m_ctx; + ITracker* m_parent; + Children m_children; + CycleState m_runState = NotStarted; + + public: + TrackerBase( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ); + + NameAndLocation const& nameAndLocation() const override; + bool isComplete() const override; + bool isSuccessfullyCompleted() const override; + bool isOpen() const override; + bool hasChildren() const override; + + void addChild( ITrackerPtr const& child ) override; + + ITrackerPtr findChild( NameAndLocation const& nameAndLocation ) override; + ITracker& parent() override; + + void openChild() override; + + bool isSectionTracker() const override; + bool isGeneratorTracker() const override; + + void open(); + + void close() override; + void fail() override; + void markAsNeedingAnotherRun() override; + + private: + void moveToParent(); + void moveToThis(); + }; + + class SectionTracker : public TrackerBase { + std::vector<std::string> m_filters; + public: + SectionTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ); + + bool isSectionTracker() const override; + + bool isComplete() const override; + + static SectionTracker& acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation ); + + void tryOpen(); + + void addInitialFilters( std::vector<std::string> const& filters ); + void addNextFilters( std::vector<std::string> const& filters ); + }; + +} // namespace TestCaseTracking + +using TestCaseTracking::ITracker; +using TestCaseTracking::TrackerContext; +using TestCaseTracking::SectionTracker; + +} // namespace Catch + +// end catch_test_case_tracker.h + +// start catch_leak_detector.h + +namespace Catch { + + struct LeakDetector { + LeakDetector(); + ~LeakDetector(); + }; + +} +// end catch_leak_detector.h +// Cpp files will be included in the single-header file here +// start catch_approx.cpp + +#include <cmath> +#include <limits> + +namespace { + +// Performs equivalent check of std::fabs(lhs - rhs) <= margin +// But without the subtraction to allow for INFINITY in comparison +bool marginComparison(double lhs, double rhs, double margin) { + return (lhs + margin >= rhs) && (rhs + margin >= lhs); +} + +} + +namespace Catch { +namespace Detail { + + Approx::Approx ( double value ) + : m_epsilon( std::numeric_limits<float>::epsilon()*100 ), + m_margin( 0.0 ), + m_scale( 0.0 ), + m_value( value ) + {} + + Approx Approx::custom() { + return Approx( 0 ); + } + + Approx Approx::operator-() const { + auto temp(*this); + temp.m_value = -temp.m_value; + return temp; + } + + std::string Approx::toString() const { + ReusableStringStream rss; + rss << "Approx( " << ::Catch::Detail::stringify( m_value ) << " )"; + return rss.str(); + } + + bool Approx::equalityComparisonImpl(const double other) const { + // First try with fixed margin, then compute margin based on epsilon, scale and Approx's value + // Thanks to Richard Harris for his help refining the scaled margin value + return marginComparison(m_value, other, m_margin) || marginComparison(m_value, other, m_epsilon * (m_scale + std::fabs(m_value))); + } + + void Approx::setMargin(double margin) { + CATCH_ENFORCE(margin >= 0, + "Invalid Approx::margin: " << margin << '.' + << " Approx::Margin has to be non-negative."); + m_margin = margin; + } + + void Approx::setEpsilon(double epsilon) { + CATCH_ENFORCE(epsilon >= 0 && epsilon <= 1.0, + "Invalid Approx::epsilon: " << epsilon << '.' + << " Approx::epsilon has to be in [0, 1]"); + m_epsilon = epsilon; + } + +} // end namespace Detail + +namespace literals { + Detail::Approx operator "" _a(long double val) { + return Detail::Approx(val); + } + Detail::Approx operator "" _a(unsigned long long val) { + return Detail::Approx(val); + } +} // end namespace literals + +std::string StringMaker<Catch::Detail::Approx>::convert(Catch::Detail::Approx const& value) { + return value.toString(); +} + +} // end namespace Catch +// end catch_approx.cpp +// start catch_assertionhandler.cpp + +// start catch_debugger.h + +namespace Catch { + bool isDebuggerActive(); +} + +#ifdef CATCH_PLATFORM_MAC + + #define CATCH_TRAP() __asm__("int $3\n" : : ) /* NOLINT */ + +#elif defined(CATCH_PLATFORM_LINUX) + // If we can use inline assembler, do it because this allows us to break + // directly at the location of the failing check instead of breaking inside + // raise() called from it, i.e. one stack frame below. + #if defined(__GNUC__) && (defined(__i386) || defined(__x86_64)) + #define CATCH_TRAP() asm volatile ("int $3") /* NOLINT */ + #else // Fall back to the generic way. + #include <signal.h> + + #define CATCH_TRAP() raise(SIGTRAP) + #endif +#elif defined(_MSC_VER) + #define CATCH_TRAP() __debugbreak() +#elif defined(__MINGW32__) + extern "C" __declspec(dllimport) void __stdcall DebugBreak(); + #define CATCH_TRAP() DebugBreak() +#endif + +#ifdef CATCH_TRAP + #define CATCH_BREAK_INTO_DEBUGGER() []{ if( Catch::isDebuggerActive() ) { CATCH_TRAP(); } }() +#else + #define CATCH_BREAK_INTO_DEBUGGER() []{}() +#endif + +// end catch_debugger.h +// start catch_run_context.h + +// start catch_fatal_condition.h + +// start catch_windows_h_proxy.h + + +#if defined(CATCH_PLATFORM_WINDOWS) + +#if !defined(NOMINMAX) && !defined(CATCH_CONFIG_NO_NOMINMAX) +# define CATCH_DEFINED_NOMINMAX +# define NOMINMAX +#endif +#if !defined(WIN32_LEAN_AND_MEAN) && !defined(CATCH_CONFIG_NO_WIN32_LEAN_AND_MEAN) +# define CATCH_DEFINED_WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +#endif + +#ifdef __AFXDLL +#include <AfxWin.h> +#else +#include <windows.h> +#endif + +#ifdef CATCH_DEFINED_NOMINMAX +# undef NOMINMAX +#endif +#ifdef CATCH_DEFINED_WIN32_LEAN_AND_MEAN +# undef WIN32_LEAN_AND_MEAN +#endif + +#endif // defined(CATCH_PLATFORM_WINDOWS) + +// end catch_windows_h_proxy.h +#if defined( CATCH_CONFIG_WINDOWS_SEH ) + +namespace Catch { + + struct FatalConditionHandler { + + static LONG CALLBACK handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo); + FatalConditionHandler(); + static void reset(); + ~FatalConditionHandler(); + + private: + static bool isSet; + static ULONG guaranteeSize; + static PVOID exceptionHandlerHandle; + }; + +} // namespace Catch + +#elif defined ( CATCH_CONFIG_POSIX_SIGNALS ) + +#include <signal.h> + +namespace Catch { + + struct FatalConditionHandler { + + static bool isSet; + static struct sigaction oldSigActions[]; + static stack_t oldSigStack; + static char altStackMem[]; + + static void handleSignal( int sig ); + + FatalConditionHandler(); + ~FatalConditionHandler(); + static void reset(); + }; + +} // namespace Catch + +#else + +namespace Catch { + struct FatalConditionHandler { + void reset(); + }; +} + +#endif + +// end catch_fatal_condition.h +#include <string> + +namespace Catch { + + struct IMutableContext; + + /////////////////////////////////////////////////////////////////////////// + + class RunContext : public IResultCapture, public IRunner { + + public: + RunContext( RunContext const& ) = delete; + RunContext& operator =( RunContext const& ) = delete; + + explicit RunContext( IConfigPtr const& _config, IStreamingReporterPtr&& reporter ); + + ~RunContext() override; + + void testGroupStarting( std::string const& testSpec, std::size_t groupIndex, std::size_t groupsCount ); + void testGroupEnded( std::string const& testSpec, Totals const& totals, std::size_t groupIndex, std::size_t groupsCount ); + + Totals runTest(TestCase const& testCase); + + IConfigPtr config() const; + IStreamingReporter& reporter() const; + + public: // IResultCapture + + // Assertion handlers + void handleExpr + ( AssertionInfo const& info, + ITransientExpression const& expr, + AssertionReaction& reaction ) override; + void handleMessage + ( AssertionInfo const& info, + ResultWas::OfType resultType, + StringRef const& message, + AssertionReaction& reaction ) override; + void handleUnexpectedExceptionNotThrown + ( AssertionInfo const& info, + AssertionReaction& reaction ) override; + void handleUnexpectedInflightException + ( AssertionInfo const& info, + std::string const& message, + AssertionReaction& reaction ) override; + void handleIncomplete + ( AssertionInfo const& info ) override; + void handleNonExpr + ( AssertionInfo const &info, + ResultWas::OfType resultType, + AssertionReaction &reaction ) override; + + bool sectionStarted( SectionInfo const& sectionInfo, Counts& assertions ) override; + + void sectionEnded( SectionEndInfo const& endInfo ) override; + void sectionEndedEarly( SectionEndInfo const& endInfo ) override; + + auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& override; + + void benchmarkStarting( BenchmarkInfo const& info ) override; + void benchmarkEnded( BenchmarkStats const& stats ) override; + + void pushScopedMessage( MessageInfo const& message ) override; + void popScopedMessage( MessageInfo const& message ) override; + + void emplaceUnscopedMessage( MessageBuilder const& builder ) override; + + std::string getCurrentTestName() const override; + + const AssertionResult* getLastResult() const override; + + void exceptionEarlyReported() override; + + void handleFatalErrorCondition( StringRef message ) override; + + bool lastAssertionPassed() override; + + void assertionPassed() override; + + public: + // !TBD We need to do this another way! + bool aborting() const final; + + private: + + void runCurrentTest( std::string& redirectedCout, std::string& redirectedCerr ); + void invokeActiveTestCase(); + + void resetAssertionInfo(); + bool testForMissingAssertions( Counts& assertions ); + + void assertionEnded( AssertionResult const& result ); + void reportExpr + ( AssertionInfo const &info, + ResultWas::OfType resultType, + ITransientExpression const *expr, + bool negated ); + + void populateReaction( AssertionReaction& reaction ); + + private: + + void handleUnfinishedSections(); + + TestRunInfo m_runInfo; + IMutableContext& m_context; + TestCase const* m_activeTestCase = nullptr; + ITracker* m_testCaseTracker = nullptr; + Option<AssertionResult> m_lastResult; + + IConfigPtr m_config; + Totals m_totals; + IStreamingReporterPtr m_reporter; + std::vector<MessageInfo> m_messages; + std::vector<ScopedMessage> m_messageScopes; /* Keeps owners of so-called unscoped messages. */ + AssertionInfo m_lastAssertionInfo; + std::vector<SectionEndInfo> m_unfinishedSections; + std::vector<ITracker*> m_activeSections; + TrackerContext m_trackerContext; + bool m_lastAssertionPassed = false; + bool m_shouldReportUnexpected = true; + bool m_includeSuccessfulResults; + }; + +} // end namespace Catch + +// end catch_run_context.h +namespace Catch { + + namespace { + auto operator <<( std::ostream& os, ITransientExpression const& expr ) -> std::ostream& { + expr.streamReconstructedExpression( os ); + return os; + } + } + + LazyExpression::LazyExpression( bool isNegated ) + : m_isNegated( isNegated ) + {} + + LazyExpression::LazyExpression( LazyExpression const& other ) : m_isNegated( other.m_isNegated ) {} + + LazyExpression::operator bool() const { + return m_transientExpression != nullptr; + } + + auto operator << ( std::ostream& os, LazyExpression const& lazyExpr ) -> std::ostream& { + if( lazyExpr.m_isNegated ) + os << "!"; + + if( lazyExpr ) { + if( lazyExpr.m_isNegated && lazyExpr.m_transientExpression->isBinaryExpression() ) + os << "(" << *lazyExpr.m_transientExpression << ")"; + else + os << *lazyExpr.m_transientExpression; + } + else { + os << "{** error - unchecked empty expression requested **}"; + } + return os; + } + + AssertionHandler::AssertionHandler + ( StringRef const& macroName, + SourceLineInfo const& lineInfo, + StringRef capturedExpression, + ResultDisposition::Flags resultDisposition ) + : m_assertionInfo{ macroName, lineInfo, capturedExpression, resultDisposition }, + m_resultCapture( getResultCapture() ) + {} + + void AssertionHandler::handleExpr( ITransientExpression const& expr ) { + m_resultCapture.handleExpr( m_assertionInfo, expr, m_reaction ); + } + void AssertionHandler::handleMessage(ResultWas::OfType resultType, StringRef const& message) { + m_resultCapture.handleMessage( m_assertionInfo, resultType, message, m_reaction ); + } + + auto AssertionHandler::allowThrows() const -> bool { + return getCurrentContext().getConfig()->allowThrows(); + } + + void AssertionHandler::complete() { + setCompleted(); + if( m_reaction.shouldDebugBreak ) { + + // If you find your debugger stopping you here then go one level up on the + // call-stack for the code that caused it (typically a failed assertion) + + // (To go back to the test and change execution, jump over the throw, next) + CATCH_BREAK_INTO_DEBUGGER(); + } + if (m_reaction.shouldThrow) { +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + throw Catch::TestFailureException(); +#else + CATCH_ERROR( "Test failure requires aborting test!" ); +#endif + } + } + void AssertionHandler::setCompleted() { + m_completed = true; + } + + void AssertionHandler::handleUnexpectedInflightException() { + m_resultCapture.handleUnexpectedInflightException( m_assertionInfo, Catch::translateActiveException(), m_reaction ); + } + + void AssertionHandler::handleExceptionThrownAsExpected() { + m_resultCapture.handleNonExpr(m_assertionInfo, ResultWas::Ok, m_reaction); + } + void AssertionHandler::handleExceptionNotThrownAsExpected() { + m_resultCapture.handleNonExpr(m_assertionInfo, ResultWas::Ok, m_reaction); + } + + void AssertionHandler::handleUnexpectedExceptionNotThrown() { + m_resultCapture.handleUnexpectedExceptionNotThrown( m_assertionInfo, m_reaction ); + } + + void AssertionHandler::handleThrowingCallSkipped() { + m_resultCapture.handleNonExpr(m_assertionInfo, ResultWas::Ok, m_reaction); + } + + // This is the overload that takes a string and infers the Equals matcher from it + // The more general overload, that takes any string matcher, is in catch_capture_matchers.cpp + void handleExceptionMatchExpr( AssertionHandler& handler, std::string const& str, StringRef const& matcherString ) { + handleExceptionMatchExpr( handler, Matchers::Equals( str ), matcherString ); + } + +} // namespace Catch +// end catch_assertionhandler.cpp +// start catch_assertionresult.cpp + +namespace Catch { + AssertionResultData::AssertionResultData(ResultWas::OfType _resultType, LazyExpression const & _lazyExpression): + lazyExpression(_lazyExpression), + resultType(_resultType) {} + + std::string AssertionResultData::reconstructExpression() const { + + if( reconstructedExpression.empty() ) { + if( lazyExpression ) { + ReusableStringStream rss; + rss << lazyExpression; + reconstructedExpression = rss.str(); + } + } + return reconstructedExpression; + } + + AssertionResult::AssertionResult( AssertionInfo const& info, AssertionResultData const& data ) + : m_info( info ), + m_resultData( data ) + {} + + // Result was a success + bool AssertionResult::succeeded() const { + return Catch::isOk( m_resultData.resultType ); + } + + // Result was a success, or failure is suppressed + bool AssertionResult::isOk() const { + return Catch::isOk( m_resultData.resultType ) || shouldSuppressFailure( m_info.resultDisposition ); + } + + ResultWas::OfType AssertionResult::getResultType() const { + return m_resultData.resultType; + } + + bool AssertionResult::hasExpression() const { + return m_info.capturedExpression[0] != 0; + } + + bool AssertionResult::hasMessage() const { + return !m_resultData.message.empty(); + } + + std::string AssertionResult::getExpression() const { + if( isFalseTest( m_info.resultDisposition ) ) + return "!(" + m_info.capturedExpression + ")"; + else + return m_info.capturedExpression; + } + + std::string AssertionResult::getExpressionInMacro() const { + std::string expr; + if( m_info.macroName[0] == 0 ) + expr = m_info.capturedExpression; + else { + expr.reserve( m_info.macroName.size() + m_info.capturedExpression.size() + 4 ); + expr += m_info.macroName; + expr += "( "; + expr += m_info.capturedExpression; + expr += " )"; + } + return expr; + } + + bool AssertionResult::hasExpandedExpression() const { + return hasExpression() && getExpandedExpression() != getExpression(); + } + + std::string AssertionResult::getExpandedExpression() const { + std::string expr = m_resultData.reconstructExpression(); + return expr.empty() + ? getExpression() + : expr; + } + + std::string AssertionResult::getMessage() const { + return m_resultData.message; + } + SourceLineInfo AssertionResult::getSourceInfo() const { + return m_info.lineInfo; + } + + StringRef AssertionResult::getTestMacroName() const { + return m_info.macroName; + } + +} // end namespace Catch +// end catch_assertionresult.cpp +// start catch_benchmark.cpp + +namespace Catch { + + auto BenchmarkLooper::getResolution() -> uint64_t { + return getEstimatedClockResolution() * getCurrentContext().getConfig()->benchmarkResolutionMultiple(); + } + + void BenchmarkLooper::reportStart() { + getResultCapture().benchmarkStarting( { m_name } ); + } + auto BenchmarkLooper::needsMoreIterations() -> bool { + auto elapsed = m_timer.getElapsedNanoseconds(); + + // Exponentially increasing iterations until we're confident in our timer resolution + if( elapsed < m_resolution ) { + m_iterationsToRun *= 10; + return true; + } + + getResultCapture().benchmarkEnded( { { m_name }, m_count, elapsed } ); + return false; + } + +} // end namespace Catch +// end catch_benchmark.cpp +// start catch_capture_matchers.cpp + +namespace Catch { + + using StringMatcher = Matchers::Impl::MatcherBase<std::string>; + + // This is the general overload that takes a any string matcher + // There is another overload, in catch_assertionhandler.h/.cpp, that only takes a string and infers + // the Equals matcher (so the header does not mention matchers) + void handleExceptionMatchExpr( AssertionHandler& handler, StringMatcher const& matcher, StringRef const& matcherString ) { + std::string exceptionMessage = Catch::translateActiveException(); + MatchExpr<std::string, StringMatcher const&> expr( exceptionMessage, matcher, matcherString ); + handler.handleExpr( expr ); + } + +} // namespace Catch +// end catch_capture_matchers.cpp +// start catch_commandline.cpp + +// start catch_commandline.h + +// start catch_clara.h + +// Use Catch's value for console width (store Clara's off to the side, if present) +#ifdef CLARA_CONFIG_CONSOLE_WIDTH +#define CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH +#undef CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH +#endif +#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH CATCH_CONFIG_CONSOLE_WIDTH-1 + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wweak-vtables" +#pragma clang diagnostic ignored "-Wexit-time-destructors" +#pragma clang diagnostic ignored "-Wshadow" +#endif + +// start clara.hpp +// Copyright 2017 Two Blue Cubes Ltd. All rights reserved. +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +// See https://github.com/philsquared/Clara for more details + +// Clara v1.1.5 + + +#ifndef CATCH_CLARA_CONFIG_CONSOLE_WIDTH +#define CATCH_CLARA_CONFIG_CONSOLE_WIDTH 80 +#endif + +#ifndef CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH +#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH CATCH_CLARA_CONFIG_CONSOLE_WIDTH +#endif + +#ifndef CLARA_CONFIG_OPTIONAL_TYPE +#ifdef __has_include +#if __has_include(<optional>) && __cplusplus >= 201703L +#include <optional> +#define CLARA_CONFIG_OPTIONAL_TYPE std::optional +#endif +#endif +#endif + +// ----------- #included from clara_textflow.hpp ----------- + +// TextFlowCpp +// +// A single-header library for wrapping and laying out basic text, by Phil Nash +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +// This project is hosted at https://github.com/philsquared/textflowcpp + + +#include <cassert> +#include <ostream> +#include <sstream> +#include <vector> + +#ifndef CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH +#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH 80 +#endif + +namespace Catch { +namespace clara { +namespace TextFlow { + +inline auto isWhitespace(char c) -> bool { + static std::string chars = " \t\n\r"; + return chars.find(c) != std::string::npos; +} +inline auto isBreakableBefore(char c) -> bool { + static std::string chars = "[({<|"; + return chars.find(c) != std::string::npos; +} +inline auto isBreakableAfter(char c) -> bool { + static std::string chars = "])}>.,:;*+-=&/\\"; + return chars.find(c) != std::string::npos; +} + +class Columns; + +class Column { + std::vector<std::string> m_strings; + size_t m_width = CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH; + size_t m_indent = 0; + size_t m_initialIndent = std::string::npos; + +public: + class iterator { + friend Column; + + Column const& m_column; + size_t m_stringIndex = 0; + size_t m_pos = 0; + + size_t m_len = 0; + size_t m_end = 0; + bool m_suffix = false; + + iterator(Column const& column, size_t stringIndex) + : m_column(column), + m_stringIndex(stringIndex) {} + + auto line() const -> std::string const& { return m_column.m_strings[m_stringIndex]; } + + auto isBoundary(size_t at) const -> bool { + assert(at > 0); + assert(at <= line().size()); + + return at == line().size() || + (isWhitespace(line()[at]) && !isWhitespace(line()[at - 1])) || + isBreakableBefore(line()[at]) || + isBreakableAfter(line()[at - 1]); + } + + void calcLength() { + assert(m_stringIndex < m_column.m_strings.size()); + + m_suffix = false; + auto width = m_column.m_width - indent(); + m_end = m_pos; + while (m_end < line().size() && line()[m_end] != '\n') + ++m_end; + + if (m_end < m_pos + width) { + m_len = m_end - m_pos; + } else { + size_t len = width; + while (len > 0 && !isBoundary(m_pos + len)) + --len; + while (len > 0 && isWhitespace(line()[m_pos + len - 1])) + --len; + + if (len > 0) { + m_len = len; + } else { + m_suffix = true; + m_len = width - 1; + } + } + } + + auto indent() const -> size_t { + auto initial = m_pos == 0 && m_stringIndex == 0 ? m_column.m_initialIndent : std::string::npos; + return initial == std::string::npos ? m_column.m_indent : initial; + } + + auto addIndentAndSuffix(std::string const &plain) const -> std::string { + return std::string(indent(), ' ') + (m_suffix ? plain + "-" : plain); + } + + public: + using difference_type = std::ptrdiff_t; + using value_type = std::string; + using pointer = value_type * ; + using reference = value_type & ; + using iterator_category = std::forward_iterator_tag; + + explicit iterator(Column const& column) : m_column(column) { + assert(m_column.m_width > m_column.m_indent); + assert(m_column.m_initialIndent == std::string::npos || m_column.m_width > m_column.m_initialIndent); + calcLength(); + if (m_len == 0) + m_stringIndex++; // Empty string + } + + auto operator *() const -> std::string { + assert(m_stringIndex < m_column.m_strings.size()); + assert(m_pos <= m_end); + return addIndentAndSuffix(line().substr(m_pos, m_len)); + } + + auto operator ++() -> iterator& { + m_pos += m_len; + if (m_pos < line().size() && line()[m_pos] == '\n') + m_pos += 1; + else + while (m_pos < line().size() && isWhitespace(line()[m_pos])) + ++m_pos; + + if (m_pos == line().size()) { + m_pos = 0; + ++m_stringIndex; + } + if (m_stringIndex < m_column.m_strings.size()) + calcLength(); + return *this; + } + auto operator ++(int) -> iterator { + iterator prev(*this); + operator++(); + return prev; + } + + auto operator ==(iterator const& other) const -> bool { + return + m_pos == other.m_pos && + m_stringIndex == other.m_stringIndex && + &m_column == &other.m_column; + } + auto operator !=(iterator const& other) const -> bool { + return !operator==(other); + } + }; + using const_iterator = iterator; + + explicit Column(std::string const& text) { m_strings.push_back(text); } + + auto width(size_t newWidth) -> Column& { + assert(newWidth > 0); + m_width = newWidth; + return *this; + } + auto indent(size_t newIndent) -> Column& { + m_indent = newIndent; + return *this; + } + auto initialIndent(size_t newIndent) -> Column& { + m_initialIndent = newIndent; + return *this; + } + + auto width() const -> size_t { return m_width; } + auto begin() const -> iterator { return iterator(*this); } + auto end() const -> iterator { return { *this, m_strings.size() }; } + + inline friend std::ostream& operator << (std::ostream& os, Column const& col) { + bool first = true; + for (auto line : col) { + if (first) + first = false; + else + os << "\n"; + os << line; + } + return os; + } + + auto operator + (Column const& other)->Columns; + + auto toString() const -> std::string { + std::ostringstream oss; + oss << *this; + return oss.str(); + } +}; + +class Spacer : public Column { + +public: + explicit Spacer(size_t spaceWidth) : Column("") { + width(spaceWidth); + } +}; + +class Columns { + std::vector<Column> m_columns; + +public: + + class iterator { + friend Columns; + struct EndTag {}; + + std::vector<Column> const& m_columns; + std::vector<Column::iterator> m_iterators; + size_t m_activeIterators; + + iterator(Columns const& columns, EndTag) + : m_columns(columns.m_columns), + m_activeIterators(0) { + m_iterators.reserve(m_columns.size()); + + for (auto const& col : m_columns) + m_iterators.push_back(col.end()); + } + + public: + using difference_type = std::ptrdiff_t; + using value_type = std::string; + using pointer = value_type * ; + using reference = value_type & ; + using iterator_category = std::forward_iterator_tag; + + explicit iterator(Columns const& columns) + : m_columns(columns.m_columns), + m_activeIterators(m_columns.size()) { + m_iterators.reserve(m_columns.size()); + + for (auto const& col : m_columns) + m_iterators.push_back(col.begin()); + } + + auto operator ==(iterator const& other) const -> bool { + return m_iterators == other.m_iterators; + } + auto operator !=(iterator const& other) const -> bool { + return m_iterators != other.m_iterators; + } + auto operator *() const -> std::string { + std::string row, padding; + + for (size_t i = 0; i < m_columns.size(); ++i) { + auto width = m_columns[i].width(); + if (m_iterators[i] != m_columns[i].end()) { + std::string col = *m_iterators[i]; + row += padding + col; + if (col.size() < width) + padding = std::string(width - col.size(), ' '); + else + padding = ""; + } else { + padding += std::string(width, ' '); + } + } + return row; + } + auto operator ++() -> iterator& { + for (size_t i = 0; i < m_columns.size(); ++i) { + if (m_iterators[i] != m_columns[i].end()) + ++m_iterators[i]; + } + return *this; + } + auto operator ++(int) -> iterator { + iterator prev(*this); + operator++(); + return prev; + } + }; + using const_iterator = iterator; + + auto begin() const -> iterator { return iterator(*this); } + auto end() const -> iterator { return { *this, iterator::EndTag() }; } + + auto operator += (Column const& col) -> Columns& { + m_columns.push_back(col); + return *this; + } + auto operator + (Column const& col) -> Columns { + Columns combined = *this; + combined += col; + return combined; + } + + inline friend std::ostream& operator << (std::ostream& os, Columns const& cols) { + + bool first = true; + for (auto line : cols) { + if (first) + first = false; + else + os << "\n"; + os << line; + } + return os; + } + + auto toString() const -> std::string { + std::ostringstream oss; + oss << *this; + return oss.str(); + } +}; + +inline auto Column::operator + (Column const& other) -> Columns { + Columns cols; + cols += *this; + cols += other; + return cols; +} +} + +} +} + +// ----------- end of #include from clara_textflow.hpp ----------- +// ........... back in clara.hpp + +#include <cctype> +#include <string> +#include <memory> +#include <set> +#include <algorithm> + +#if !defined(CATCH_PLATFORM_WINDOWS) && ( defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) ) +#define CATCH_PLATFORM_WINDOWS +#endif + +namespace Catch { namespace clara { +namespace detail { + + // Traits for extracting arg and return type of lambdas (for single argument lambdas) + template<typename L> + struct UnaryLambdaTraits : UnaryLambdaTraits<decltype( &L::operator() )> {}; + + template<typename ClassT, typename ReturnT, typename... Args> + struct UnaryLambdaTraits<ReturnT( ClassT::* )( Args... ) const> { + static const bool isValid = false; + }; + + template<typename ClassT, typename ReturnT, typename ArgT> + struct UnaryLambdaTraits<ReturnT( ClassT::* )( ArgT ) const> { + static const bool isValid = true; + using ArgType = typename std::remove_const<typename std::remove_reference<ArgT>::type>::type; + using ReturnType = ReturnT; + }; + + class TokenStream; + + // Transport for raw args (copied from main args, or supplied via init list for testing) + class Args { + friend TokenStream; + std::string m_exeName; + std::vector<std::string> m_args; + + public: + Args( int argc, char const* const* argv ) + : m_exeName(argv[0]), + m_args(argv + 1, argv + argc) {} + + Args( std::initializer_list<std::string> args ) + : m_exeName( *args.begin() ), + m_args( args.begin()+1, args.end() ) + {} + + auto exeName() const -> std::string { + return m_exeName; + } + }; + + // Wraps a token coming from a token stream. These may not directly correspond to strings as a single string + // may encode an option + its argument if the : or = form is used + enum class TokenType { + Option, Argument + }; + struct Token { + TokenType type; + std::string token; + }; + + inline auto isOptPrefix( char c ) -> bool { + return c == '-' +#ifdef CATCH_PLATFORM_WINDOWS + || c == '/' +#endif + ; + } + + // Abstracts iterators into args as a stream of tokens, with option arguments uniformly handled + class TokenStream { + using Iterator = std::vector<std::string>::const_iterator; + Iterator it; + Iterator itEnd; + std::vector<Token> m_tokenBuffer; + + void loadBuffer() { + m_tokenBuffer.resize( 0 ); + + // Skip any empty strings + while( it != itEnd && it->empty() ) + ++it; + + if( it != itEnd ) { + auto const &next = *it; + if( isOptPrefix( next[0] ) ) { + auto delimiterPos = next.find_first_of( " :=" ); + if( delimiterPos != std::string::npos ) { + m_tokenBuffer.push_back( { TokenType::Option, next.substr( 0, delimiterPos ) } ); + m_tokenBuffer.push_back( { TokenType::Argument, next.substr( delimiterPos + 1 ) } ); + } else { + if( next[1] != '-' && next.size() > 2 ) { + std::string opt = "- "; + for( size_t i = 1; i < next.size(); ++i ) { + opt[1] = next[i]; + m_tokenBuffer.push_back( { TokenType::Option, opt } ); + } + } else { + m_tokenBuffer.push_back( { TokenType::Option, next } ); + } + } + } else { + m_tokenBuffer.push_back( { TokenType::Argument, next } ); + } + } + } + + public: + explicit TokenStream( Args const &args ) : TokenStream( args.m_args.begin(), args.m_args.end() ) {} + + TokenStream( Iterator it, Iterator itEnd ) : it( it ), itEnd( itEnd ) { + loadBuffer(); + } + + explicit operator bool() const { + return !m_tokenBuffer.empty() || it != itEnd; + } + + auto count() const -> size_t { return m_tokenBuffer.size() + (itEnd - it); } + + auto operator*() const -> Token { + assert( !m_tokenBuffer.empty() ); + return m_tokenBuffer.front(); + } + + auto operator->() const -> Token const * { + assert( !m_tokenBuffer.empty() ); + return &m_tokenBuffer.front(); + } + + auto operator++() -> TokenStream & { + if( m_tokenBuffer.size() >= 2 ) { + m_tokenBuffer.erase( m_tokenBuffer.begin() ); + } else { + if( it != itEnd ) + ++it; + loadBuffer(); + } + return *this; + } + }; + + class ResultBase { + public: + enum Type { + Ok, LogicError, RuntimeError + }; + + protected: + ResultBase( Type type ) : m_type( type ) {} + virtual ~ResultBase() = default; + + virtual void enforceOk() const = 0; + + Type m_type; + }; + + template<typename T> + class ResultValueBase : public ResultBase { + public: + auto value() const -> T const & { + enforceOk(); + return m_value; + } + + protected: + ResultValueBase( Type type ) : ResultBase( type ) {} + + ResultValueBase( ResultValueBase const &other ) : ResultBase( other ) { + if( m_type == ResultBase::Ok ) + new( &m_value ) T( other.m_value ); + } + + ResultValueBase( Type, T const &value ) : ResultBase( Ok ) { + new( &m_value ) T( value ); + } + + auto operator=( ResultValueBase const &other ) -> ResultValueBase & { + if( m_type == ResultBase::Ok ) + m_value.~T(); + ResultBase::operator=(other); + if( m_type == ResultBase::Ok ) + new( &m_value ) T( other.m_value ); + return *this; + } + + ~ResultValueBase() override { + if( m_type == Ok ) + m_value.~T(); + } + + union { + T m_value; + }; + }; + + template<> + class ResultValueBase<void> : public ResultBase { + protected: + using ResultBase::ResultBase; + }; + + template<typename T = void> + class BasicResult : public ResultValueBase<T> { + public: + template<typename U> + explicit BasicResult( BasicResult<U> const &other ) + : ResultValueBase<T>( other.type() ), + m_errorMessage( other.errorMessage() ) + { + assert( type() != ResultBase::Ok ); + } + + template<typename U> + static auto ok( U const &value ) -> BasicResult { return { ResultBase::Ok, value }; } + static auto ok() -> BasicResult { return { ResultBase::Ok }; } + static auto logicError( std::string const &message ) -> BasicResult { return { ResultBase::LogicError, message }; } + static auto runtimeError( std::string const &message ) -> BasicResult { return { ResultBase::RuntimeError, message }; } + + explicit operator bool() const { return m_type == ResultBase::Ok; } + auto type() const -> ResultBase::Type { return m_type; } + auto errorMessage() const -> std::string { return m_errorMessage; } + + protected: + void enforceOk() const override { + + // Errors shouldn't reach this point, but if they do + // the actual error message will be in m_errorMessage + assert( m_type != ResultBase::LogicError ); + assert( m_type != ResultBase::RuntimeError ); + if( m_type != ResultBase::Ok ) + std::abort(); + } + + std::string m_errorMessage; // Only populated if resultType is an error + + BasicResult( ResultBase::Type type, std::string const &message ) + : ResultValueBase<T>(type), + m_errorMessage(message) + { + assert( m_type != ResultBase::Ok ); + } + + using ResultValueBase<T>::ResultValueBase; + using ResultBase::m_type; + }; + + enum class ParseResultType { + Matched, NoMatch, ShortCircuitAll, ShortCircuitSame + }; + + class ParseState { + public: + + ParseState( ParseResultType type, TokenStream const &remainingTokens ) + : m_type(type), + m_remainingTokens( remainingTokens ) + {} + + auto type() const -> ParseResultType { return m_type; } + auto remainingTokens() const -> TokenStream { return m_remainingTokens; } + + private: + ParseResultType m_type; + TokenStream m_remainingTokens; + }; + + using Result = BasicResult<void>; + using ParserResult = BasicResult<ParseResultType>; + using InternalParseResult = BasicResult<ParseState>; + + struct HelpColumns { + std::string left; + std::string right; + }; + + template<typename T> + inline auto convertInto( std::string const &source, T& target ) -> ParserResult { + std::stringstream ss; + ss << source; + ss >> target; + if( ss.fail() ) + return ParserResult::runtimeError( "Unable to convert '" + source + "' to destination type" ); + else + return ParserResult::ok( ParseResultType::Matched ); + } + inline auto convertInto( std::string const &source, std::string& target ) -> ParserResult { + target = source; + return ParserResult::ok( ParseResultType::Matched ); + } + inline auto convertInto( std::string const &source, bool &target ) -> ParserResult { + std::string srcLC = source; + std::transform( srcLC.begin(), srcLC.end(), srcLC.begin(), []( char c ) { return static_cast<char>( std::tolower(c) ); } ); + if (srcLC == "y" || srcLC == "1" || srcLC == "true" || srcLC == "yes" || srcLC == "on") + target = true; + else if (srcLC == "n" || srcLC == "0" || srcLC == "false" || srcLC == "no" || srcLC == "off") + target = false; + else + return ParserResult::runtimeError( "Expected a boolean value but did not recognise: '" + source + "'" ); + return ParserResult::ok( ParseResultType::Matched ); + } +#ifdef CLARA_CONFIG_OPTIONAL_TYPE + template<typename T> + inline auto convertInto( std::string const &source, CLARA_CONFIG_OPTIONAL_TYPE<T>& target ) -> ParserResult { + T temp; + auto result = convertInto( source, temp ); + if( result ) + target = std::move(temp); + return result; + } +#endif // CLARA_CONFIG_OPTIONAL_TYPE + + struct NonCopyable { + NonCopyable() = default; + NonCopyable( NonCopyable const & ) = delete; + NonCopyable( NonCopyable && ) = delete; + NonCopyable &operator=( NonCopyable const & ) = delete; + NonCopyable &operator=( NonCopyable && ) = delete; + }; + + struct BoundRef : NonCopyable { + virtual ~BoundRef() = default; + virtual auto isContainer() const -> bool { return false; } + virtual auto isFlag() const -> bool { return false; } + }; + struct BoundValueRefBase : BoundRef { + virtual auto setValue( std::string const &arg ) -> ParserResult = 0; + }; + struct BoundFlagRefBase : BoundRef { + virtual auto setFlag( bool flag ) -> ParserResult = 0; + virtual auto isFlag() const -> bool { return true; } + }; + + template<typename T> + struct BoundValueRef : BoundValueRefBase { + T &m_ref; + + explicit BoundValueRef( T &ref ) : m_ref( ref ) {} + + auto setValue( std::string const &arg ) -> ParserResult override { + return convertInto( arg, m_ref ); + } + }; + + template<typename T> + struct BoundValueRef<std::vector<T>> : BoundValueRefBase { + std::vector<T> &m_ref; + + explicit BoundValueRef( std::vector<T> &ref ) : m_ref( ref ) {} + + auto isContainer() const -> bool override { return true; } + + auto setValue( std::string const &arg ) -> ParserResult override { + T temp; + auto result = convertInto( arg, temp ); + if( result ) + m_ref.push_back( temp ); + return result; + } + }; + + struct BoundFlagRef : BoundFlagRefBase { + bool &m_ref; + + explicit BoundFlagRef( bool &ref ) : m_ref( ref ) {} + + auto setFlag( bool flag ) -> ParserResult override { + m_ref = flag; + return ParserResult::ok( ParseResultType::Matched ); + } + }; + + template<typename ReturnType> + struct LambdaInvoker { + static_assert( std::is_same<ReturnType, ParserResult>::value, "Lambda must return void or clara::ParserResult" ); + + template<typename L, typename ArgType> + static auto invoke( L const &lambda, ArgType const &arg ) -> ParserResult { + return lambda( arg ); + } + }; + + template<> + struct LambdaInvoker<void> { + template<typename L, typename ArgType> + static auto invoke( L const &lambda, ArgType const &arg ) -> ParserResult { + lambda( arg ); + return ParserResult::ok( ParseResultType::Matched ); + } + }; + + template<typename ArgType, typename L> + inline auto invokeLambda( L const &lambda, std::string const &arg ) -> ParserResult { + ArgType temp{}; + auto result = convertInto( arg, temp ); + return !result + ? result + : LambdaInvoker<typename UnaryLambdaTraits<L>::ReturnType>::invoke( lambda, temp ); + } + + template<typename L> + struct BoundLambda : BoundValueRefBase { + L m_lambda; + + static_assert( UnaryLambdaTraits<L>::isValid, "Supplied lambda must take exactly one argument" ); + explicit BoundLambda( L const &lambda ) : m_lambda( lambda ) {} + + auto setValue( std::string const &arg ) -> ParserResult override { + return invokeLambda<typename UnaryLambdaTraits<L>::ArgType>( m_lambda, arg ); + } + }; + + template<typename L> + struct BoundFlagLambda : BoundFlagRefBase { + L m_lambda; + + static_assert( UnaryLambdaTraits<L>::isValid, "Supplied lambda must take exactly one argument" ); + static_assert( std::is_same<typename UnaryLambdaTraits<L>::ArgType, bool>::value, "flags must be boolean" ); + + explicit BoundFlagLambda( L const &lambda ) : m_lambda( lambda ) {} + + auto setFlag( bool flag ) -> ParserResult override { + return LambdaInvoker<typename UnaryLambdaTraits<L>::ReturnType>::invoke( m_lambda, flag ); + } + }; + + enum class Optionality { Optional, Required }; + + struct Parser; + + class ParserBase { + public: + virtual ~ParserBase() = default; + virtual auto validate() const -> Result { return Result::ok(); } + virtual auto parse( std::string const& exeName, TokenStream const &tokens) const -> InternalParseResult = 0; + virtual auto cardinality() const -> size_t { return 1; } + + auto parse( Args const &args ) const -> InternalParseResult { + return parse( args.exeName(), TokenStream( args ) ); + } + }; + + template<typename DerivedT> + class ComposableParserImpl : public ParserBase { + public: + template<typename T> + auto operator|( T const &other ) const -> Parser; + + template<typename T> + auto operator+( T const &other ) const -> Parser; + }; + + // Common code and state for Args and Opts + template<typename DerivedT> + class ParserRefImpl : public ComposableParserImpl<DerivedT> { + protected: + Optionality m_optionality = Optionality::Optional; + std::shared_ptr<BoundRef> m_ref; + std::string m_hint; + std::string m_description; + + explicit ParserRefImpl( std::shared_ptr<BoundRef> const &ref ) : m_ref( ref ) {} + + public: + template<typename T> + ParserRefImpl( T &ref, std::string const &hint ) + : m_ref( std::make_shared<BoundValueRef<T>>( ref ) ), + m_hint( hint ) + {} + + template<typename LambdaT> + ParserRefImpl( LambdaT const &ref, std::string const &hint ) + : m_ref( std::make_shared<BoundLambda<LambdaT>>( ref ) ), + m_hint(hint) + {} + + auto operator()( std::string const &description ) -> DerivedT & { + m_description = description; + return static_cast<DerivedT &>( *this ); + } + + auto optional() -> DerivedT & { + m_optionality = Optionality::Optional; + return static_cast<DerivedT &>( *this ); + }; + + auto required() -> DerivedT & { + m_optionality = Optionality::Required; + return static_cast<DerivedT &>( *this ); + }; + + auto isOptional() const -> bool { + return m_optionality == Optionality::Optional; + } + + auto cardinality() const -> size_t override { + if( m_ref->isContainer() ) + return 0; + else + return 1; + } + + auto hint() const -> std::string { return m_hint; } + }; + + class ExeName : public ComposableParserImpl<ExeName> { + std::shared_ptr<std::string> m_name; + std::shared_ptr<BoundValueRefBase> m_ref; + + template<typename LambdaT> + static auto makeRef(LambdaT const &lambda) -> std::shared_ptr<BoundValueRefBase> { + return std::make_shared<BoundLambda<LambdaT>>( lambda) ; + } + + public: + ExeName() : m_name( std::make_shared<std::string>( "<executable>" ) ) {} + + explicit ExeName( std::string &ref ) : ExeName() { + m_ref = std::make_shared<BoundValueRef<std::string>>( ref ); + } + + template<typename LambdaT> + explicit ExeName( LambdaT const& lambda ) : ExeName() { + m_ref = std::make_shared<BoundLambda<LambdaT>>( lambda ); + } + + // The exe name is not parsed out of the normal tokens, but is handled specially + auto parse( std::string const&, TokenStream const &tokens ) const -> InternalParseResult override { + return InternalParseResult::ok( ParseState( ParseResultType::NoMatch, tokens ) ); + } + + auto name() const -> std::string { return *m_name; } + auto set( std::string const& newName ) -> ParserResult { + + auto lastSlash = newName.find_last_of( "\\/" ); + auto filename = ( lastSlash == std::string::npos ) + ? newName + : newName.substr( lastSlash+1 ); + + *m_name = filename; + if( m_ref ) + return m_ref->setValue( filename ); + else + return ParserResult::ok( ParseResultType::Matched ); + } + }; + + class Arg : public ParserRefImpl<Arg> { + public: + using ParserRefImpl::ParserRefImpl; + + auto parse( std::string const &, TokenStream const &tokens ) const -> InternalParseResult override { + auto validationResult = validate(); + if( !validationResult ) + return InternalParseResult( validationResult ); + + auto remainingTokens = tokens; + auto const &token = *remainingTokens; + if( token.type != TokenType::Argument ) + return InternalParseResult::ok( ParseState( ParseResultType::NoMatch, remainingTokens ) ); + + assert( !m_ref->isFlag() ); + auto valueRef = static_cast<detail::BoundValueRefBase*>( m_ref.get() ); + + auto result = valueRef->setValue( remainingTokens->token ); + if( !result ) + return InternalParseResult( result ); + else + return InternalParseResult::ok( ParseState( ParseResultType::Matched, ++remainingTokens ) ); + } + }; + + inline auto normaliseOpt( std::string const &optName ) -> std::string { +#ifdef CATCH_PLATFORM_WINDOWS + if( optName[0] == '/' ) + return "-" + optName.substr( 1 ); + else +#endif + return optName; + } + + class Opt : public ParserRefImpl<Opt> { + protected: + std::vector<std::string> m_optNames; + + public: + template<typename LambdaT> + explicit Opt( LambdaT const &ref ) : ParserRefImpl( std::make_shared<BoundFlagLambda<LambdaT>>( ref ) ) {} + + explicit Opt( bool &ref ) : ParserRefImpl( std::make_shared<BoundFlagRef>( ref ) ) {} + + template<typename LambdaT> + Opt( LambdaT const &ref, std::string const &hint ) : ParserRefImpl( ref, hint ) {} + + template<typename T> + Opt( T &ref, std::string const &hint ) : ParserRefImpl( ref, hint ) {} + + auto operator[]( std::string const &optName ) -> Opt & { + m_optNames.push_back( optName ); + return *this; + } + + auto getHelpColumns() const -> std::vector<HelpColumns> { + std::ostringstream oss; + bool first = true; + for( auto const &opt : m_optNames ) { + if (first) + first = false; + else + oss << ", "; + oss << opt; + } + if( !m_hint.empty() ) + oss << " <" << m_hint << ">"; + return { { oss.str(), m_description } }; + } + + auto isMatch( std::string const &optToken ) const -> bool { + auto normalisedToken = normaliseOpt( optToken ); + for( auto const &name : m_optNames ) { + if( normaliseOpt( name ) == normalisedToken ) + return true; + } + return false; + } + + using ParserBase::parse; + + auto parse( std::string const&, TokenStream const &tokens ) const -> InternalParseResult override { + auto validationResult = validate(); + if( !validationResult ) + return InternalParseResult( validationResult ); + + auto remainingTokens = tokens; + if( remainingTokens && remainingTokens->type == TokenType::Option ) { + auto const &token = *remainingTokens; + if( isMatch(token.token ) ) { + if( m_ref->isFlag() ) { + auto flagRef = static_cast<detail::BoundFlagRefBase*>( m_ref.get() ); + auto result = flagRef->setFlag( true ); + if( !result ) + return InternalParseResult( result ); + if( result.value() == ParseResultType::ShortCircuitAll ) + return InternalParseResult::ok( ParseState( result.value(), remainingTokens ) ); + } else { + auto valueRef = static_cast<detail::BoundValueRefBase*>( m_ref.get() ); + ++remainingTokens; + if( !remainingTokens ) + return InternalParseResult::runtimeError( "Expected argument following " + token.token ); + auto const &argToken = *remainingTokens; + if( argToken.type != TokenType::Argument ) + return InternalParseResult::runtimeError( "Expected argument following " + token.token ); + auto result = valueRef->setValue( argToken.token ); + if( !result ) + return InternalParseResult( result ); + if( result.value() == ParseResultType::ShortCircuitAll ) + return InternalParseResult::ok( ParseState( result.value(), remainingTokens ) ); + } + return InternalParseResult::ok( ParseState( ParseResultType::Matched, ++remainingTokens ) ); + } + } + return InternalParseResult::ok( ParseState( ParseResultType::NoMatch, remainingTokens ) ); + } + + auto validate() const -> Result override { + if( m_optNames.empty() ) + return Result::logicError( "No options supplied to Opt" ); + for( auto const &name : m_optNames ) { + if( name.empty() ) + return Result::logicError( "Option name cannot be empty" ); +#ifdef CATCH_PLATFORM_WINDOWS + if( name[0] != '-' && name[0] != '/' ) + return Result::logicError( "Option name must begin with '-' or '/'" ); +#else + if( name[0] != '-' ) + return Result::logicError( "Option name must begin with '-'" ); +#endif + } + return ParserRefImpl::validate(); + } + }; + + struct Help : Opt { + Help( bool &showHelpFlag ) + : Opt([&]( bool flag ) { + showHelpFlag = flag; + return ParserResult::ok( ParseResultType::ShortCircuitAll ); + }) + { + static_cast<Opt &>( *this ) + ("display usage information") + ["-?"]["-h"]["--help"] + .optional(); + } + }; + + struct Parser : ParserBase { + + mutable ExeName m_exeName; + std::vector<Opt> m_options; + std::vector<Arg> m_args; + + auto operator|=( ExeName const &exeName ) -> Parser & { + m_exeName = exeName; + return *this; + } + + auto operator|=( Arg const &arg ) -> Parser & { + m_args.push_back(arg); + return *this; + } + + auto operator|=( Opt const &opt ) -> Parser & { + m_options.push_back(opt); + return *this; + } + + auto operator|=( Parser const &other ) -> Parser & { + m_options.insert(m_options.end(), other.m_options.begin(), other.m_options.end()); + m_args.insert(m_args.end(), other.m_args.begin(), other.m_args.end()); + return *this; + } + + template<typename T> + auto operator|( T const &other ) const -> Parser { + return Parser( *this ) |= other; + } + + // Forward deprecated interface with '+' instead of '|' + template<typename T> + auto operator+=( T const &other ) -> Parser & { return operator|=( other ); } + template<typename T> + auto operator+( T const &other ) const -> Parser { return operator|( other ); } + + auto getHelpColumns() const -> std::vector<HelpColumns> { + std::vector<HelpColumns> cols; + for (auto const &o : m_options) { + auto childCols = o.getHelpColumns(); + cols.insert( cols.end(), childCols.begin(), childCols.end() ); + } + return cols; + } + + void writeToStream( std::ostream &os ) const { + if (!m_exeName.name().empty()) { + os << "usage:\n" << " " << m_exeName.name() << " "; + bool required = true, first = true; + for( auto const &arg : m_args ) { + if (first) + first = false; + else + os << " "; + if( arg.isOptional() && required ) { + os << "["; + required = false; + } + os << "<" << arg.hint() << ">"; + if( arg.cardinality() == 0 ) + os << " ... "; + } + if( !required ) + os << "]"; + if( !m_options.empty() ) + os << " options"; + os << "\n\nwhere options are:" << std::endl; + } + + auto rows = getHelpColumns(); + size_t consoleWidth = CATCH_CLARA_CONFIG_CONSOLE_WIDTH; + size_t optWidth = 0; + for( auto const &cols : rows ) + optWidth = (std::max)(optWidth, cols.left.size() + 2); + + optWidth = (std::min)(optWidth, consoleWidth/2); + + for( auto const &cols : rows ) { + auto row = + TextFlow::Column( cols.left ).width( optWidth ).indent( 2 ) + + TextFlow::Spacer(4) + + TextFlow::Column( cols.right ).width( consoleWidth - 7 - optWidth ); + os << row << std::endl; + } + } + + friend auto operator<<( std::ostream &os, Parser const &parser ) -> std::ostream& { + parser.writeToStream( os ); + return os; + } + + auto validate() const -> Result override { + for( auto const &opt : m_options ) { + auto result = opt.validate(); + if( !result ) + return result; + } + for( auto const &arg : m_args ) { + auto result = arg.validate(); + if( !result ) + return result; + } + return Result::ok(); + } + + using ParserBase::parse; + + auto parse( std::string const& exeName, TokenStream const &tokens ) const -> InternalParseResult override { + + struct ParserInfo { + ParserBase const* parser = nullptr; + size_t count = 0; + }; + const size_t totalParsers = m_options.size() + m_args.size(); + assert( totalParsers < 512 ); + // ParserInfo parseInfos[totalParsers]; // <-- this is what we really want to do + ParserInfo parseInfos[512]; + + { + size_t i = 0; + for (auto const &opt : m_options) parseInfos[i++].parser = &opt; + for (auto const &arg : m_args) parseInfos[i++].parser = &arg; + } + + m_exeName.set( exeName ); + + auto result = InternalParseResult::ok( ParseState( ParseResultType::NoMatch, tokens ) ); + while( result.value().remainingTokens() ) { + bool tokenParsed = false; + + for( size_t i = 0; i < totalParsers; ++i ) { + auto& parseInfo = parseInfos[i]; + if( parseInfo.parser->cardinality() == 0 || parseInfo.count < parseInfo.parser->cardinality() ) { + result = parseInfo.parser->parse(exeName, result.value().remainingTokens()); + if (!result) + return result; + if (result.value().type() != ParseResultType::NoMatch) { + tokenParsed = true; + ++parseInfo.count; + break; + } + } + } + + if( result.value().type() == ParseResultType::ShortCircuitAll ) + return result; + if( !tokenParsed ) + return InternalParseResult::runtimeError( "Unrecognised token: " + result.value().remainingTokens()->token ); + } + // !TBD Check missing required options + return result; + } + }; + + template<typename DerivedT> + template<typename T> + auto ComposableParserImpl<DerivedT>::operator|( T const &other ) const -> Parser { + return Parser() | static_cast<DerivedT const &>( *this ) | other; + } +} // namespace detail + +// A Combined parser +using detail::Parser; + +// A parser for options +using detail::Opt; + +// A parser for arguments +using detail::Arg; + +// Wrapper for argc, argv from main() +using detail::Args; + +// Specifies the name of the executable +using detail::ExeName; + +// Convenience wrapper for option parser that specifies the help option +using detail::Help; + +// enum of result types from a parse +using detail::ParseResultType; + +// Result type for parser operation +using detail::ParserResult; + +}} // namespace Catch::clara + +// end clara.hpp +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// Restore Clara's value for console width, if present +#ifdef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#undef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#endif + +// end catch_clara.h +namespace Catch { + + clara::Parser makeCommandLineParser( ConfigData& config ); + +} // end namespace Catch + +// end catch_commandline.h +#include <fstream> +#include <ctime> + +namespace Catch { + + clara::Parser makeCommandLineParser( ConfigData& config ) { + + using namespace clara; + + auto const setWarning = [&]( std::string const& warning ) { + auto warningSet = [&]() { + if( warning == "NoAssertions" ) + return WarnAbout::NoAssertions; + + if ( warning == "NoTests" ) + return WarnAbout::NoTests; + + return WarnAbout::Nothing; + }(); + + if (warningSet == WarnAbout::Nothing) + return ParserResult::runtimeError( "Unrecognised warning: '" + warning + "'" ); + config.warnings = static_cast<WarnAbout::What>( config.warnings | warningSet ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const loadTestNamesFromFile = [&]( std::string const& filename ) { + std::ifstream f( filename.c_str() ); + if( !f.is_open() ) + return ParserResult::runtimeError( "Unable to load input file: '" + filename + "'" ); + + std::string line; + while( std::getline( f, line ) ) { + line = trim(line); + if( !line.empty() && !startsWith( line, '#' ) ) { + if( !startsWith( line, '"' ) ) + line = '"' + line + '"'; + config.testsOrTags.push_back( line + ',' ); + } + } + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setTestOrder = [&]( std::string const& order ) { + if( startsWith( "declared", order ) ) + config.runOrder = RunTests::InDeclarationOrder; + else if( startsWith( "lexical", order ) ) + config.runOrder = RunTests::InLexicographicalOrder; + else if( startsWith( "random", order ) ) + config.runOrder = RunTests::InRandomOrder; + else + return clara::ParserResult::runtimeError( "Unrecognised ordering: '" + order + "'" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setRngSeed = [&]( std::string const& seed ) { + if( seed != "time" ) + return clara::detail::convertInto( seed, config.rngSeed ); + config.rngSeed = static_cast<unsigned int>( std::time(nullptr) ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setColourUsage = [&]( std::string const& useColour ) { + auto mode = toLower( useColour ); + + if( mode == "yes" ) + config.useColour = UseColour::Yes; + else if( mode == "no" ) + config.useColour = UseColour::No; + else if( mode == "auto" ) + config.useColour = UseColour::Auto; + else + return ParserResult::runtimeError( "colour mode must be one of: auto, yes or no. '" + useColour + "' not recognised" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setWaitForKeypress = [&]( std::string const& keypress ) { + auto keypressLc = toLower( keypress ); + if( keypressLc == "start" ) + config.waitForKeypress = WaitForKeypress::BeforeStart; + else if( keypressLc == "exit" ) + config.waitForKeypress = WaitForKeypress::BeforeExit; + else if( keypressLc == "both" ) + config.waitForKeypress = WaitForKeypress::BeforeStartAndExit; + else + return ParserResult::runtimeError( "keypress argument must be one of: start, exit or both. '" + keypress + "' not recognised" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setVerbosity = [&]( std::string const& verbosity ) { + auto lcVerbosity = toLower( verbosity ); + if( lcVerbosity == "quiet" ) + config.verbosity = Verbosity::Quiet; + else if( lcVerbosity == "normal" ) + config.verbosity = Verbosity::Normal; + else if( lcVerbosity == "high" ) + config.verbosity = Verbosity::High; + else + return ParserResult::runtimeError( "Unrecognised verbosity, '" + verbosity + "'" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setReporter = [&]( std::string const& reporter ) { + IReporterRegistry::FactoryMap const& factories = getRegistryHub().getReporterRegistry().getFactories(); + + auto lcReporter = toLower( reporter ); + auto result = factories.find( lcReporter ); + + if( factories.end() != result ) + config.reporterName = lcReporter; + else + return ParserResult::runtimeError( "Unrecognized reporter, '" + reporter + "'. Check available with --list-reporters" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + + auto cli + = ExeName( config.processName ) + | Help( config.showHelp ) + | Opt( config.listTests ) + ["-l"]["--list-tests"] + ( "list all/matching test cases" ) + | Opt( config.listTags ) + ["-t"]["--list-tags"] + ( "list all/matching tags" ) + | Opt( config.showSuccessfulTests ) + ["-s"]["--success"] + ( "include successful tests in output" ) + | Opt( config.shouldDebugBreak ) + ["-b"]["--break"] + ( "break into debugger on failure" ) + | Opt( config.noThrow ) + ["-e"]["--nothrow"] + ( "skip exception tests" ) + | Opt( config.showInvisibles ) + ["-i"]["--invisibles"] + ( "show invisibles (tabs, newlines)" ) + | Opt( config.outputFilename, "filename" ) + ["-o"]["--out"] + ( "output filename" ) + | Opt( setReporter, "name" ) + ["-r"]["--reporter"] + ( "reporter to use (defaults to console)" ) + | Opt( config.name, "name" ) + ["-n"]["--name"] + ( "suite name" ) + | Opt( [&]( bool ){ config.abortAfter = 1; } ) + ["-a"]["--abort"] + ( "abort at first failure" ) + | Opt( [&]( int x ){ config.abortAfter = x; }, "no. failures" ) + ["-x"]["--abortx"] + ( "abort after x failures" ) + | Opt( setWarning, "warning name" ) + ["-w"]["--warn"] + ( "enable warnings" ) + | Opt( [&]( bool flag ) { config.showDurations = flag ? ShowDurations::Always : ShowDurations::Never; }, "yes|no" ) + ["-d"]["--durations"] + ( "show test durations" ) + | Opt( loadTestNamesFromFile, "filename" ) + ["-f"]["--input-file"] + ( "load test names to run from a file" ) + | Opt( config.filenamesAsTags ) + ["-#"]["--filenames-as-tags"] + ( "adds a tag for the filename" ) + | Opt( config.sectionsToRun, "section name" ) + ["-c"]["--section"] + ( "specify section to run" ) + | Opt( setVerbosity, "quiet|normal|high" ) + ["-v"]["--verbosity"] + ( "set output verbosity" ) + | Opt( config.listTestNamesOnly ) + ["--list-test-names-only"] + ( "list all/matching test cases names only" ) + | Opt( config.listReporters ) + ["--list-reporters"] + ( "list all reporters" ) + | Opt( setTestOrder, "decl|lex|rand" ) + ["--order"] + ( "test case order (defaults to decl)" ) + | Opt( setRngSeed, "'time'|number" ) + ["--rng-seed"] + ( "set a specific seed for random numbers" ) + | Opt( setColourUsage, "yes|no" ) + ["--use-colour"] + ( "should output be colourised" ) + | Opt( config.libIdentify ) + ["--libidentify"] + ( "report name and version according to libidentify standard" ) + | Opt( setWaitForKeypress, "start|exit|both" ) + ["--wait-for-keypress"] + ( "waits for a keypress before exiting" ) + | Opt( config.benchmarkResolutionMultiple, "multiplier" ) + ["--benchmark-resolution-multiple"] + ( "multiple of clock resolution to run benchmarks" ) + + | Arg( config.testsOrTags, "test name|pattern|tags" ) + ( "which test or tests to use" ); + + return cli; + } + +} // end namespace Catch +// end catch_commandline.cpp +// start catch_common.cpp + +#include <cstring> +#include <ostream> + +namespace Catch { + + bool SourceLineInfo::empty() const noexcept { + return file[0] == '\0'; + } + bool SourceLineInfo::operator == ( SourceLineInfo const& other ) const noexcept { + return line == other.line && (file == other.file || std::strcmp(file, other.file) == 0); + } + bool SourceLineInfo::operator < ( SourceLineInfo const& other ) const noexcept { + // We can assume that the same file will usually have the same pointer. + // Thus, if the pointers are the same, there is no point in calling the strcmp + return line < other.line || ( line == other.line && file != other.file && (std::strcmp(file, other.file) < 0)); + } + + std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ) { +#ifndef __GNUG__ + os << info.file << '(' << info.line << ')'; +#else + os << info.file << ':' << info.line; +#endif + return os; + } + + std::string StreamEndStop::operator+() const { + return std::string(); + } + + NonCopyable::NonCopyable() = default; + NonCopyable::~NonCopyable() = default; + +} +// end catch_common.cpp +// start catch_config.cpp + +namespace Catch { + + Config::Config( ConfigData const& data ) + : m_data( data ), + m_stream( openStream() ) + { + TestSpecParser parser(ITagAliasRegistry::get()); + if (data.testsOrTags.empty()) { + parser.parse("~[.]"); // All not hidden tests + } + else { + m_hasTestFilters = true; + for( auto const& testOrTags : data.testsOrTags ) + parser.parse( testOrTags ); + } + m_testSpec = parser.testSpec(); + } + + std::string const& Config::getFilename() const { + return m_data.outputFilename ; + } + + bool Config::listTests() const { return m_data.listTests; } + bool Config::listTestNamesOnly() const { return m_data.listTestNamesOnly; } + bool Config::listTags() const { return m_data.listTags; } + bool Config::listReporters() const { return m_data.listReporters; } + + std::string Config::getProcessName() const { return m_data.processName; } + std::string const& Config::getReporterName() const { return m_data.reporterName; } + + std::vector<std::string> const& Config::getTestsOrTags() const { return m_data.testsOrTags; } + std::vector<std::string> const& Config::getSectionsToRun() const { return m_data.sectionsToRun; } + + TestSpec const& Config::testSpec() const { return m_testSpec; } + bool Config::hasTestFilters() const { return m_hasTestFilters; } + + bool Config::showHelp() const { return m_data.showHelp; } + + // IConfig interface + bool Config::allowThrows() const { return !m_data.noThrow; } + std::ostream& Config::stream() const { return m_stream->stream(); } + std::string Config::name() const { return m_data.name.empty() ? m_data.processName : m_data.name; } + bool Config::includeSuccessfulResults() const { return m_data.showSuccessfulTests; } + bool Config::warnAboutMissingAssertions() const { return !!(m_data.warnings & WarnAbout::NoAssertions); } + bool Config::warnAboutNoTests() const { return !!(m_data.warnings & WarnAbout::NoTests); } + ShowDurations::OrNot Config::showDurations() const { return m_data.showDurations; } + RunTests::InWhatOrder Config::runOrder() const { return m_data.runOrder; } + unsigned int Config::rngSeed() const { return m_data.rngSeed; } + int Config::benchmarkResolutionMultiple() const { return m_data.benchmarkResolutionMultiple; } + UseColour::YesOrNo Config::useColour() const { return m_data.useColour; } + bool Config::shouldDebugBreak() const { return m_data.shouldDebugBreak; } + int Config::abortAfter() const { return m_data.abortAfter; } + bool Config::showInvisibles() const { return m_data.showInvisibles; } + Verbosity Config::verbosity() const { return m_data.verbosity; } + + IStream const* Config::openStream() { + return Catch::makeStream(m_data.outputFilename); + } + +} // end namespace Catch +// end catch_config.cpp +// start catch_console_colour.cpp + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +#endif + +// start catch_errno_guard.h + +namespace Catch { + + class ErrnoGuard { + public: + ErrnoGuard(); + ~ErrnoGuard(); + private: + int m_oldErrno; + }; + +} + +// end catch_errno_guard.h +#include <sstream> + +namespace Catch { + namespace { + + struct IColourImpl { + virtual ~IColourImpl() = default; + virtual void use( Colour::Code _colourCode ) = 0; + }; + + struct NoColourImpl : IColourImpl { + void use( Colour::Code ) {} + + static IColourImpl* instance() { + static NoColourImpl s_instance; + return &s_instance; + } + }; + + } // anon namespace +} // namespace Catch + +#if !defined( CATCH_CONFIG_COLOUR_NONE ) && !defined( CATCH_CONFIG_COLOUR_WINDOWS ) && !defined( CATCH_CONFIG_COLOUR_ANSI ) +# ifdef CATCH_PLATFORM_WINDOWS +# define CATCH_CONFIG_COLOUR_WINDOWS +# else +# define CATCH_CONFIG_COLOUR_ANSI +# endif +#endif + +#if defined ( CATCH_CONFIG_COLOUR_WINDOWS ) ///////////////////////////////////////// + +namespace Catch { +namespace { + + class Win32ColourImpl : public IColourImpl { + public: + Win32ColourImpl() : stdoutHandle( GetStdHandle(STD_OUTPUT_HANDLE) ) + { + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo( stdoutHandle, &csbiInfo ); + originalForegroundAttributes = csbiInfo.wAttributes & ~( BACKGROUND_GREEN | BACKGROUND_RED | BACKGROUND_BLUE | BACKGROUND_INTENSITY ); + originalBackgroundAttributes = csbiInfo.wAttributes & ~( FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE | FOREGROUND_INTENSITY ); + } + + virtual void use( Colour::Code _colourCode ) override { + switch( _colourCode ) { + case Colour::None: return setTextAttribute( originalForegroundAttributes ); + case Colour::White: return setTextAttribute( FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE ); + case Colour::Red: return setTextAttribute( FOREGROUND_RED ); + case Colour::Green: return setTextAttribute( FOREGROUND_GREEN ); + case Colour::Blue: return setTextAttribute( FOREGROUND_BLUE ); + case Colour::Cyan: return setTextAttribute( FOREGROUND_BLUE | FOREGROUND_GREEN ); + case Colour::Yellow: return setTextAttribute( FOREGROUND_RED | FOREGROUND_GREEN ); + case Colour::Grey: return setTextAttribute( 0 ); + + case Colour::LightGrey: return setTextAttribute( FOREGROUND_INTENSITY ); + case Colour::BrightRed: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_RED ); + case Colour::BrightGreen: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN ); + case Colour::BrightWhite: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE ); + case Colour::BrightYellow: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_RED | FOREGROUND_GREEN ); + + case Colour::Bright: CATCH_INTERNAL_ERROR( "not a colour" ); + + default: + CATCH_ERROR( "Unknown colour requested" ); + } + } + + private: + void setTextAttribute( WORD _textAttribute ) { + SetConsoleTextAttribute( stdoutHandle, _textAttribute | originalBackgroundAttributes ); + } + HANDLE stdoutHandle; + WORD originalForegroundAttributes; + WORD originalBackgroundAttributes; + }; + + IColourImpl* platformColourInstance() { + static Win32ColourImpl s_instance; + + IConfigPtr config = getCurrentContext().getConfig(); + UseColour::YesOrNo colourMode = config + ? config->useColour() + : UseColour::Auto; + if( colourMode == UseColour::Auto ) + colourMode = UseColour::Yes; + return colourMode == UseColour::Yes + ? &s_instance + : NoColourImpl::instance(); + } + +} // end anon namespace +} // end namespace Catch + +#elif defined( CATCH_CONFIG_COLOUR_ANSI ) ////////////////////////////////////// + +#include <unistd.h> + +namespace Catch { +namespace { + + // use POSIX/ ANSI console terminal codes + // Thanks to Adam Strzelecki for original contribution + // (http://github.com/nanoant) + // https://github.com/philsquared/Catch/pull/131 + class PosixColourImpl : public IColourImpl { + public: + virtual void use( Colour::Code _colourCode ) override { + switch( _colourCode ) { + case Colour::None: + case Colour::White: return setColour( "[0m" ); + case Colour::Red: return setColour( "[0;31m" ); + case Colour::Green: return setColour( "[0;32m" ); + case Colour::Blue: return setColour( "[0;34m" ); + case Colour::Cyan: return setColour( "[0;36m" ); + case Colour::Yellow: return setColour( "[0;33m" ); + case Colour::Grey: return setColour( "[1;30m" ); + + case Colour::LightGrey: return setColour( "[0;37m" ); + case Colour::BrightRed: return setColour( "[1;31m" ); + case Colour::BrightGreen: return setColour( "[1;32m" ); + case Colour::BrightWhite: return setColour( "[1;37m" ); + case Colour::BrightYellow: return setColour( "[1;33m" ); + + case Colour::Bright: CATCH_INTERNAL_ERROR( "not a colour" ); + default: CATCH_INTERNAL_ERROR( "Unknown colour requested" ); + } + } + static IColourImpl* instance() { + static PosixColourImpl s_instance; + return &s_instance; + } + + private: + void setColour( const char* _escapeCode ) { + getCurrentContext().getConfig()->stream() + << '\033' << _escapeCode; + } + }; + + bool useColourOnPlatform() { + return +#ifdef CATCH_PLATFORM_MAC + !isDebuggerActive() && +#endif +#if !(defined(__DJGPP__) && defined(__STRICT_ANSI__)) + isatty(STDOUT_FILENO) +#else + false +#endif + ; + } + IColourImpl* platformColourInstance() { + ErrnoGuard guard; + IConfigPtr config = getCurrentContext().getConfig(); + UseColour::YesOrNo colourMode = config + ? config->useColour() + : UseColour::Auto; + if( colourMode == UseColour::Auto ) + colourMode = useColourOnPlatform() + ? UseColour::Yes + : UseColour::No; + return colourMode == UseColour::Yes + ? PosixColourImpl::instance() + : NoColourImpl::instance(); + } + +} // end anon namespace +} // end namespace Catch + +#else // not Windows or ANSI /////////////////////////////////////////////// + +namespace Catch { + + static IColourImpl* platformColourInstance() { return NoColourImpl::instance(); } + +} // end namespace Catch + +#endif // Windows/ ANSI/ None + +namespace Catch { + + Colour::Colour( Code _colourCode ) { use( _colourCode ); } + Colour::Colour( Colour&& rhs ) noexcept { + m_moved = rhs.m_moved; + rhs.m_moved = true; + } + Colour& Colour::operator=( Colour&& rhs ) noexcept { + m_moved = rhs.m_moved; + rhs.m_moved = true; + return *this; + } + + Colour::~Colour(){ if( !m_moved ) use( None ); } + + void Colour::use( Code _colourCode ) { + static IColourImpl* impl = platformColourInstance(); + impl->use( _colourCode ); + } + + std::ostream& operator << ( std::ostream& os, Colour const& ) { + return os; + } + +} // end namespace Catch + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif + +// end catch_console_colour.cpp +// start catch_context.cpp + +namespace Catch { + + class Context : public IMutableContext, NonCopyable { + + public: // IContext + virtual IResultCapture* getResultCapture() override { + return m_resultCapture; + } + virtual IRunner* getRunner() override { + return m_runner; + } + + virtual IConfigPtr const& getConfig() const override { + return m_config; + } + + virtual ~Context() override; + + public: // IMutableContext + virtual void setResultCapture( IResultCapture* resultCapture ) override { + m_resultCapture = resultCapture; + } + virtual void setRunner( IRunner* runner ) override { + m_runner = runner; + } + virtual void setConfig( IConfigPtr const& config ) override { + m_config = config; + } + + friend IMutableContext& getCurrentMutableContext(); + + private: + IConfigPtr m_config; + IRunner* m_runner = nullptr; + IResultCapture* m_resultCapture = nullptr; + }; + + IMutableContext *IMutableContext::currentContext = nullptr; + + void IMutableContext::createContext() + { + currentContext = new Context(); + } + + void cleanUpContext() { + delete IMutableContext::currentContext; + IMutableContext::currentContext = nullptr; + } + IContext::~IContext() = default; + IMutableContext::~IMutableContext() = default; + Context::~Context() = default; +} +// end catch_context.cpp +// start catch_debug_console.cpp + +// start catch_debug_console.h + +#include <string> + +namespace Catch { + void writeToDebugConsole( std::string const& text ); +} + +// end catch_debug_console.h +#ifdef CATCH_PLATFORM_WINDOWS + + namespace Catch { + void writeToDebugConsole( std::string const& text ) { + ::OutputDebugStringA( text.c_str() ); + } + } + +#else + + namespace Catch { + void writeToDebugConsole( std::string const& text ) { + // !TBD: Need a version for Mac/ XCode and other IDEs + Catch::cout() << text; + } + } + +#endif // Platform +// end catch_debug_console.cpp +// start catch_debugger.cpp + +#ifdef CATCH_PLATFORM_MAC + +# include <assert.h> +# include <stdbool.h> +# include <sys/types.h> +# include <unistd.h> +# include <sys/sysctl.h> +# include <cstddef> +# include <ostream> + +namespace Catch { + + // The following function is taken directly from the following technical note: + // http://developer.apple.com/library/mac/#qa/qa2004/qa1361.html + + // Returns true if the current process is being debugged (either + // running under the debugger or has a debugger attached post facto). + bool isDebuggerActive(){ + + int mib[4]; + struct kinfo_proc info; + std::size_t size; + + // Initialize the flags so that, if sysctl fails for some bizarre + // reason, we get a predictable result. + + info.kp_proc.p_flag = 0; + + // Initialize mib, which tells sysctl the info we want, in this case + // we're looking for information about a specific process ID. + + mib[0] = CTL_KERN; + mib[1] = KERN_PROC; + mib[2] = KERN_PROC_PID; + mib[3] = getpid(); + + // Call sysctl. + + size = sizeof(info); + if( sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &size, nullptr, 0) != 0 ) { + Catch::cerr() << "\n** Call to sysctl failed - unable to determine if debugger is active **\n" << std::endl; + return false; + } + + // We're being debugged if the P_TRACED flag is set. + + return ( (info.kp_proc.p_flag & P_TRACED) != 0 ); + } + } // namespace Catch + +#elif defined(CATCH_PLATFORM_LINUX) + #include <fstream> + #include <string> + + namespace Catch{ + // The standard POSIX way of detecting a debugger is to attempt to + // ptrace() the process, but this needs to be done from a child and not + // this process itself to still allow attaching to this process later + // if wanted, so is rather heavy. Under Linux we have the PID of the + // "debugger" (which doesn't need to be gdb, of course, it could also + // be strace, for example) in /proc/$PID/status, so just get it from + // there instead. + bool isDebuggerActive(){ + // Libstdc++ has a bug, where std::ifstream sets errno to 0 + // This way our users can properly assert over errno values + ErrnoGuard guard; + std::ifstream in("/proc/self/status"); + for( std::string line; std::getline(in, line); ) { + static const int PREFIX_LEN = 11; + if( line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0 ) { + // We're traced if the PID is not 0 and no other PID starts + // with 0 digit, so it's enough to check for just a single + // character. + return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0'; + } + } + + return false; + } + } // namespace Catch +#elif defined(_MSC_VER) + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + namespace Catch { + bool isDebuggerActive() { + return IsDebuggerPresent() != 0; + } + } +#elif defined(__MINGW32__) + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + namespace Catch { + bool isDebuggerActive() { + return IsDebuggerPresent() != 0; + } + } +#else + namespace Catch { + bool isDebuggerActive() { return false; } + } +#endif // Platform +// end catch_debugger.cpp +// start catch_decomposer.cpp + +namespace Catch { + + ITransientExpression::~ITransientExpression() = default; + + void formatReconstructedExpression( std::ostream &os, std::string const& lhs, StringRef op, std::string const& rhs ) { + if( lhs.size() + rhs.size() < 40 && + lhs.find('\n') == std::string::npos && + rhs.find('\n') == std::string::npos ) + os << lhs << " " << op << " " << rhs; + else + os << lhs << "\n" << op << "\n" << rhs; + } +} +// end catch_decomposer.cpp +// start catch_enforce.cpp + +namespace Catch { +#if defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) && !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS_CUSTOM_HANDLER) + [[noreturn]] + void throw_exception(std::exception const& e) { + Catch::cerr() << "Catch will terminate because it needed to throw an exception.\n" + << "The message was: " << e.what() << '\n'; + std::terminate(); + } +#endif +} // namespace Catch; +// end catch_enforce.cpp +// start catch_errno_guard.cpp + +#include <cerrno> + +namespace Catch { + ErrnoGuard::ErrnoGuard():m_oldErrno(errno){} + ErrnoGuard::~ErrnoGuard() { errno = m_oldErrno; } +} +// end catch_errno_guard.cpp +// start catch_exception_translator_registry.cpp + +// start catch_exception_translator_registry.h + +#include <vector> +#include <string> +#include <memory> + +namespace Catch { + + class ExceptionTranslatorRegistry : public IExceptionTranslatorRegistry { + public: + ~ExceptionTranslatorRegistry(); + virtual void registerTranslator( const IExceptionTranslator* translator ); + virtual std::string translateActiveException() const override; + std::string tryTranslators() const; + + private: + std::vector<std::unique_ptr<IExceptionTranslator const>> m_translators; + }; +} + +// end catch_exception_translator_registry.h +#ifdef __OBJC__ +#import "Foundation/Foundation.h" +#endif + +namespace Catch { + + ExceptionTranslatorRegistry::~ExceptionTranslatorRegistry() { + } + + void ExceptionTranslatorRegistry::registerTranslator( const IExceptionTranslator* translator ) { + m_translators.push_back( std::unique_ptr<const IExceptionTranslator>( translator ) ); + } + +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + std::string ExceptionTranslatorRegistry::translateActiveException() const { + try { +#ifdef __OBJC__ + // In Objective-C try objective-c exceptions first + @try { + return tryTranslators(); + } + @catch (NSException *exception) { + return Catch::Detail::stringify( [exception description] ); + } +#else + // Compiling a mixed mode project with MSVC means that CLR + // exceptions will be caught in (...) as well. However, these + // do not fill-in std::current_exception and thus lead to crash + // when attempting rethrow. + // /EHa switch also causes structured exceptions to be caught + // here, but they fill-in current_exception properly, so + // at worst the output should be a little weird, instead of + // causing a crash. + if (std::current_exception() == nullptr) { + return "Non C++ exception. Possibly a CLR exception."; + } + return tryTranslators(); +#endif + } + catch( TestFailureException& ) { + std::rethrow_exception(std::current_exception()); + } + catch( std::exception& ex ) { + return ex.what(); + } + catch( std::string& msg ) { + return msg; + } + catch( const char* msg ) { + return msg; + } + catch(...) { + return "Unknown exception"; + } + } + + std::string ExceptionTranslatorRegistry::tryTranslators() const { + if (m_translators.empty()) { + std::rethrow_exception(std::current_exception()); + } else { + return m_translators[0]->translate(m_translators.begin() + 1, m_translators.end()); + } + } + +#else // ^^ Exceptions are enabled // Exceptions are disabled vv + std::string ExceptionTranslatorRegistry::translateActiveException() const { + CATCH_INTERNAL_ERROR("Attempted to translate active exception under CATCH_CONFIG_DISABLE_EXCEPTIONS!"); + } + + std::string ExceptionTranslatorRegistry::tryTranslators() const { + CATCH_INTERNAL_ERROR("Attempted to use exception translators under CATCH_CONFIG_DISABLE_EXCEPTIONS!"); + } +#endif + +} +// end catch_exception_translator_registry.cpp +// start catch_fatal_condition.cpp + +#if defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#endif + +#if defined( CATCH_CONFIG_WINDOWS_SEH ) || defined( CATCH_CONFIG_POSIX_SIGNALS ) + +namespace { + // Report the error condition + void reportFatal( char const * const message ) { + Catch::getCurrentContext().getResultCapture()->handleFatalErrorCondition( message ); + } +} + +#endif // signals/SEH handling + +#if defined( CATCH_CONFIG_WINDOWS_SEH ) + +namespace Catch { + struct SignalDefs { DWORD id; const char* name; }; + + // There is no 1-1 mapping between signals and windows exceptions. + // Windows can easily distinguish between SO and SigSegV, + // but SigInt, SigTerm, etc are handled differently. + static SignalDefs signalDefs[] = { + { EXCEPTION_ILLEGAL_INSTRUCTION, "SIGILL - Illegal instruction signal" }, + { EXCEPTION_STACK_OVERFLOW, "SIGSEGV - Stack overflow" }, + { EXCEPTION_ACCESS_VIOLATION, "SIGSEGV - Segmentation violation signal" }, + { EXCEPTION_INT_DIVIDE_BY_ZERO, "Divide by zero error" }, + }; + + LONG CALLBACK FatalConditionHandler::handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo) { + for (auto const& def : signalDefs) { + if (ExceptionInfo->ExceptionRecord->ExceptionCode == def.id) { + reportFatal(def.name); + } + } + // If its not an exception we care about, pass it along. + // This stops us from eating debugger breaks etc. + return EXCEPTION_CONTINUE_SEARCH; + } + + FatalConditionHandler::FatalConditionHandler() { + isSet = true; + // 32k seems enough for Catch to handle stack overflow, + // but the value was found experimentally, so there is no strong guarantee + guaranteeSize = 32 * 1024; + exceptionHandlerHandle = nullptr; + // Register as first handler in current chain + exceptionHandlerHandle = AddVectoredExceptionHandler(1, handleVectoredException); + // Pass in guarantee size to be filled + SetThreadStackGuarantee(&guaranteeSize); + } + + void FatalConditionHandler::reset() { + if (isSet) { + RemoveVectoredExceptionHandler(exceptionHandlerHandle); + SetThreadStackGuarantee(&guaranteeSize); + exceptionHandlerHandle = nullptr; + isSet = false; + } + } + + FatalConditionHandler::~FatalConditionHandler() { + reset(); + } + +bool FatalConditionHandler::isSet = false; +ULONG FatalConditionHandler::guaranteeSize = 0; +PVOID FatalConditionHandler::exceptionHandlerHandle = nullptr; + +} // namespace Catch + +#elif defined( CATCH_CONFIG_POSIX_SIGNALS ) + +namespace Catch { + + struct SignalDefs { + int id; + const char* name; + }; + + // 32kb for the alternate stack seems to be sufficient. However, this value + // is experimentally determined, so that's not guaranteed. + constexpr static std::size_t sigStackSize = 32768 >= MINSIGSTKSZ ? 32768 : MINSIGSTKSZ; + + static SignalDefs signalDefs[] = { + { SIGINT, "SIGINT - Terminal interrupt signal" }, + { SIGILL, "SIGILL - Illegal instruction signal" }, + { SIGFPE, "SIGFPE - Floating point error signal" }, + { SIGSEGV, "SIGSEGV - Segmentation violation signal" }, + { SIGTERM, "SIGTERM - Termination request signal" }, + { SIGABRT, "SIGABRT - Abort (abnormal termination) signal" } + }; + + void FatalConditionHandler::handleSignal( int sig ) { + char const * name = "<unknown signal>"; + for (auto const& def : signalDefs) { + if (sig == def.id) { + name = def.name; + break; + } + } + reset(); + reportFatal(name); + raise( sig ); + } + + FatalConditionHandler::FatalConditionHandler() { + isSet = true; + stack_t sigStack; + sigStack.ss_sp = altStackMem; + sigStack.ss_size = sigStackSize; + sigStack.ss_flags = 0; + sigaltstack(&sigStack, &oldSigStack); + struct sigaction sa = { }; + + sa.sa_handler = handleSignal; + sa.sa_flags = SA_ONSTACK; + for (std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i) { + sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); + } + } + + FatalConditionHandler::~FatalConditionHandler() { + reset(); + } + + void FatalConditionHandler::reset() { + if( isSet ) { + // Set signals back to previous values -- hopefully nobody overwrote them in the meantime + for( std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i ) { + sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); + } + // Return the old stack + sigaltstack(&oldSigStack, nullptr); + isSet = false; + } + } + + bool FatalConditionHandler::isSet = false; + struct sigaction FatalConditionHandler::oldSigActions[sizeof(signalDefs)/sizeof(SignalDefs)] = {}; + stack_t FatalConditionHandler::oldSigStack = {}; + char FatalConditionHandler::altStackMem[sigStackSize] = {}; + +} // namespace Catch + +#else + +namespace Catch { + void FatalConditionHandler::reset() {} +} + +#endif // signals/SEH handling + +#if defined(__GNUC__) +# pragma GCC diagnostic pop +#endif +// end catch_fatal_condition.cpp +// start catch_generators.cpp + +// start catch_random_number_generator.h + +#include <algorithm> +#include <random> + +namespace Catch { + + struct IConfig; + + std::mt19937& rng(); + void seedRng( IConfig const& config ); + unsigned int rngSeed(); + +} + +// end catch_random_number_generator.h +#include <limits> +#include <set> + +namespace Catch { + +IGeneratorTracker::~IGeneratorTracker() {} + +const char* GeneratorException::what() const noexcept { + return m_msg; +} + +namespace Generators { + + GeneratorUntypedBase::~GeneratorUntypedBase() {} + + auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& { + return getResultCapture().acquireGeneratorTracker( lineInfo ); + } + +} // namespace Generators +} // namespace Catch +// end catch_generators.cpp +// start catch_interfaces_capture.cpp + +namespace Catch { + IResultCapture::~IResultCapture() = default; +} +// end catch_interfaces_capture.cpp +// start catch_interfaces_config.cpp + +namespace Catch { + IConfig::~IConfig() = default; +} +// end catch_interfaces_config.cpp +// start catch_interfaces_exception.cpp + +namespace Catch { + IExceptionTranslator::~IExceptionTranslator() = default; + IExceptionTranslatorRegistry::~IExceptionTranslatorRegistry() = default; +} +// end catch_interfaces_exception.cpp +// start catch_interfaces_registry_hub.cpp + +namespace Catch { + IRegistryHub::~IRegistryHub() = default; + IMutableRegistryHub::~IMutableRegistryHub() = default; +} +// end catch_interfaces_registry_hub.cpp +// start catch_interfaces_reporter.cpp + +// start catch_reporter_listening.h + +namespace Catch { + + class ListeningReporter : public IStreamingReporter { + using Reporters = std::vector<IStreamingReporterPtr>; + Reporters m_listeners; + IStreamingReporterPtr m_reporter = nullptr; + ReporterPreferences m_preferences; + + public: + ListeningReporter(); + + void addListener( IStreamingReporterPtr&& listener ); + void addReporter( IStreamingReporterPtr&& reporter ); + + public: // IStreamingReporter + + ReporterPreferences getPreferences() const override; + + void noMatchingTestCases( std::string const& spec ) override; + + static std::set<Verbosity> getSupportedVerbosities(); + + void benchmarkStarting( BenchmarkInfo const& benchmarkInfo ) override; + void benchmarkEnded( BenchmarkStats const& benchmarkStats ) override; + + void testRunStarting( TestRunInfo const& testRunInfo ) override; + void testGroupStarting( GroupInfo const& groupInfo ) override; + void testCaseStarting( TestCaseInfo const& testInfo ) override; + void sectionStarting( SectionInfo const& sectionInfo ) override; + void assertionStarting( AssertionInfo const& assertionInfo ) override; + + // The return value indicates if the messages buffer should be cleared: + bool assertionEnded( AssertionStats const& assertionStats ) override; + void sectionEnded( SectionStats const& sectionStats ) override; + void testCaseEnded( TestCaseStats const& testCaseStats ) override; + void testGroupEnded( TestGroupStats const& testGroupStats ) override; + void testRunEnded( TestRunStats const& testRunStats ) override; + + void skipTest( TestCaseInfo const& testInfo ) override; + bool isMulti() const override; + + }; + +} // end namespace Catch + +// end catch_reporter_listening.h +namespace Catch { + + ReporterConfig::ReporterConfig( IConfigPtr const& _fullConfig ) + : m_stream( &_fullConfig->stream() ), m_fullConfig( _fullConfig ) {} + + ReporterConfig::ReporterConfig( IConfigPtr const& _fullConfig, std::ostream& _stream ) + : m_stream( &_stream ), m_fullConfig( _fullConfig ) {} + + std::ostream& ReporterConfig::stream() const { return *m_stream; } + IConfigPtr ReporterConfig::fullConfig() const { return m_fullConfig; } + + TestRunInfo::TestRunInfo( std::string const& _name ) : name( _name ) {} + + GroupInfo::GroupInfo( std::string const& _name, + std::size_t _groupIndex, + std::size_t _groupsCount ) + : name( _name ), + groupIndex( _groupIndex ), + groupsCounts( _groupsCount ) + {} + + AssertionStats::AssertionStats( AssertionResult const& _assertionResult, + std::vector<MessageInfo> const& _infoMessages, + Totals const& _totals ) + : assertionResult( _assertionResult ), + infoMessages( _infoMessages ), + totals( _totals ) + { + assertionResult.m_resultData.lazyExpression.m_transientExpression = _assertionResult.m_resultData.lazyExpression.m_transientExpression; + + if( assertionResult.hasMessage() ) { + // Copy message into messages list. + // !TBD This should have been done earlier, somewhere + MessageBuilder builder( assertionResult.getTestMacroName(), assertionResult.getSourceInfo(), assertionResult.getResultType() ); + builder << assertionResult.getMessage(); + builder.m_info.message = builder.m_stream.str(); + + infoMessages.push_back( builder.m_info ); + } + } + + AssertionStats::~AssertionStats() = default; + + SectionStats::SectionStats( SectionInfo const& _sectionInfo, + Counts const& _assertions, + double _durationInSeconds, + bool _missingAssertions ) + : sectionInfo( _sectionInfo ), + assertions( _assertions ), + durationInSeconds( _durationInSeconds ), + missingAssertions( _missingAssertions ) + {} + + SectionStats::~SectionStats() = default; + + TestCaseStats::TestCaseStats( TestCaseInfo const& _testInfo, + Totals const& _totals, + std::string const& _stdOut, + std::string const& _stdErr, + bool _aborting ) + : testInfo( _testInfo ), + totals( _totals ), + stdOut( _stdOut ), + stdErr( _stdErr ), + aborting( _aborting ) + {} + + TestCaseStats::~TestCaseStats() = default; + + TestGroupStats::TestGroupStats( GroupInfo const& _groupInfo, + Totals const& _totals, + bool _aborting ) + : groupInfo( _groupInfo ), + totals( _totals ), + aborting( _aborting ) + {} + + TestGroupStats::TestGroupStats( GroupInfo const& _groupInfo ) + : groupInfo( _groupInfo ), + aborting( false ) + {} + + TestGroupStats::~TestGroupStats() = default; + + TestRunStats::TestRunStats( TestRunInfo const& _runInfo, + Totals const& _totals, + bool _aborting ) + : runInfo( _runInfo ), + totals( _totals ), + aborting( _aborting ) + {} + + TestRunStats::~TestRunStats() = default; + + void IStreamingReporter::fatalErrorEncountered( StringRef ) {} + bool IStreamingReporter::isMulti() const { return false; } + + IReporterFactory::~IReporterFactory() = default; + IReporterRegistry::~IReporterRegistry() = default; + +} // end namespace Catch +// end catch_interfaces_reporter.cpp +// start catch_interfaces_runner.cpp + +namespace Catch { + IRunner::~IRunner() = default; +} +// end catch_interfaces_runner.cpp +// start catch_interfaces_testcase.cpp + +namespace Catch { + ITestInvoker::~ITestInvoker() = default; + ITestCaseRegistry::~ITestCaseRegistry() = default; +} +// end catch_interfaces_testcase.cpp +// start catch_leak_detector.cpp + +#ifdef CATCH_CONFIG_WINDOWS_CRTDBG +#include <crtdbg.h> + +namespace Catch { + + LeakDetector::LeakDetector() { + int flag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); + flag |= _CRTDBG_LEAK_CHECK_DF; + flag |= _CRTDBG_ALLOC_MEM_DF; + _CrtSetDbgFlag(flag); + _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); + _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDERR); + // Change this to leaking allocation's number to break there + _CrtSetBreakAlloc(-1); + } +} + +#else + + Catch::LeakDetector::LeakDetector() {} + +#endif + +Catch::LeakDetector::~LeakDetector() { + Catch::cleanUp(); +} +// end catch_leak_detector.cpp +// start catch_list.cpp + +// start catch_list.h + +#include <set> + +namespace Catch { + + std::size_t listTests( Config const& config ); + + std::size_t listTestsNamesOnly( Config const& config ); + + struct TagInfo { + void add( std::string const& spelling ); + std::string all() const; + + std::set<std::string> spellings; + std::size_t count = 0; + }; + + std::size_t listTags( Config const& config ); + + std::size_t listReporters(); + + Option<std::size_t> list( std::shared_ptr<Config> const& config ); + +} // end namespace Catch + +// end catch_list.h +// start catch_text.h + +namespace Catch { + using namespace clara::TextFlow; +} + +// end catch_text.h +#include <limits> +#include <algorithm> +#include <iomanip> + +namespace Catch { + + std::size_t listTests( Config const& config ) { + TestSpec testSpec = config.testSpec(); + if( config.hasTestFilters() ) + Catch::cout() << "Matching test cases:\n"; + else { + Catch::cout() << "All available test cases:\n"; + } + + auto matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config ); + for( auto const& testCaseInfo : matchedTestCases ) { + Colour::Code colour = testCaseInfo.isHidden() + ? Colour::SecondaryText + : Colour::None; + Colour colourGuard( colour ); + + Catch::cout() << Column( testCaseInfo.name ).initialIndent( 2 ).indent( 4 ) << "\n"; + if( config.verbosity() >= Verbosity::High ) { + Catch::cout() << Column( Catch::Detail::stringify( testCaseInfo.lineInfo ) ).indent(4) << std::endl; + std::string description = testCaseInfo.description; + if( description.empty() ) + description = "(NO DESCRIPTION)"; + Catch::cout() << Column( description ).indent(4) << std::endl; + } + if( !testCaseInfo.tags.empty() ) + Catch::cout() << Column( testCaseInfo.tagsAsString() ).indent( 6 ) << "\n"; + } + + if( !config.hasTestFilters() ) + Catch::cout() << pluralise( matchedTestCases.size(), "test case" ) << '\n' << std::endl; + else + Catch::cout() << pluralise( matchedTestCases.size(), "matching test case" ) << '\n' << std::endl; + return matchedTestCases.size(); + } + + std::size_t listTestsNamesOnly( Config const& config ) { + TestSpec testSpec = config.testSpec(); + std::size_t matchedTests = 0; + std::vector<TestCase> matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config ); + for( auto const& testCaseInfo : matchedTestCases ) { + matchedTests++; + if( startsWith( testCaseInfo.name, '#' ) ) + Catch::cout() << '"' << testCaseInfo.name << '"'; + else + Catch::cout() << testCaseInfo.name; + if ( config.verbosity() >= Verbosity::High ) + Catch::cout() << "\t@" << testCaseInfo.lineInfo; + Catch::cout() << std::endl; + } + return matchedTests; + } + + void TagInfo::add( std::string const& spelling ) { + ++count; + spellings.insert( spelling ); + } + + std::string TagInfo::all() const { + std::string out; + for( auto const& spelling : spellings ) + out += "[" + spelling + "]"; + return out; + } + + std::size_t listTags( Config const& config ) { + TestSpec testSpec = config.testSpec(); + if( config.hasTestFilters() ) + Catch::cout() << "Tags for matching test cases:\n"; + else { + Catch::cout() << "All available tags:\n"; + } + + std::map<std::string, TagInfo> tagCounts; + + std::vector<TestCase> matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config ); + for( auto const& testCase : matchedTestCases ) { + for( auto const& tagName : testCase.getTestCaseInfo().tags ) { + std::string lcaseTagName = toLower( tagName ); + auto countIt = tagCounts.find( lcaseTagName ); + if( countIt == tagCounts.end() ) + countIt = tagCounts.insert( std::make_pair( lcaseTagName, TagInfo() ) ).first; + countIt->second.add( tagName ); + } + } + + for( auto const& tagCount : tagCounts ) { + ReusableStringStream rss; + rss << " " << std::setw(2) << tagCount.second.count << " "; + auto str = rss.str(); + auto wrapper = Column( tagCount.second.all() ) + .initialIndent( 0 ) + .indent( str.size() ) + .width( CATCH_CONFIG_CONSOLE_WIDTH-10 ); + Catch::cout() << str << wrapper << '\n'; + } + Catch::cout() << pluralise( tagCounts.size(), "tag" ) << '\n' << std::endl; + return tagCounts.size(); + } + + std::size_t listReporters() { + Catch::cout() << "Available reporters:\n"; + IReporterRegistry::FactoryMap const& factories = getRegistryHub().getReporterRegistry().getFactories(); + std::size_t maxNameLen = 0; + for( auto const& factoryKvp : factories ) + maxNameLen = (std::max)( maxNameLen, factoryKvp.first.size() ); + + for( auto const& factoryKvp : factories ) { + Catch::cout() + << Column( factoryKvp.first + ":" ) + .indent(2) + .width( 5+maxNameLen ) + + Column( factoryKvp.second->getDescription() ) + .initialIndent(0) + .indent(2) + .width( CATCH_CONFIG_CONSOLE_WIDTH - maxNameLen-8 ) + << "\n"; + } + Catch::cout() << std::endl; + return factories.size(); + } + + Option<std::size_t> list( std::shared_ptr<Config> const& config ) { + Option<std::size_t> listedCount; + getCurrentMutableContext().setConfig( config ); + if( config->listTests() ) + listedCount = listedCount.valueOr(0) + listTests( *config ); + if( config->listTestNamesOnly() ) + listedCount = listedCount.valueOr(0) + listTestsNamesOnly( *config ); + if( config->listTags() ) + listedCount = listedCount.valueOr(0) + listTags( *config ); + if( config->listReporters() ) + listedCount = listedCount.valueOr(0) + listReporters(); + return listedCount; + } + +} // end namespace Catch +// end catch_list.cpp +// start catch_matchers.cpp + +namespace Catch { +namespace Matchers { + namespace Impl { + + std::string MatcherUntypedBase::toString() const { + if( m_cachedToString.empty() ) + m_cachedToString = describe(); + return m_cachedToString; + } + + MatcherUntypedBase::~MatcherUntypedBase() = default; + + } // namespace Impl +} // namespace Matchers + +using namespace Matchers; +using Matchers::Impl::MatcherBase; + +} // namespace Catch +// end catch_matchers.cpp +// start catch_matchers_floating.cpp + +// start catch_polyfills.hpp + +namespace Catch { + bool isnan(float f); + bool isnan(double d); +} + +// end catch_polyfills.hpp +// start catch_to_string.hpp + +#include <string> + +namespace Catch { + template <typename T> + std::string to_string(T const& t) { +#if defined(CATCH_CONFIG_CPP11_TO_STRING) + return std::to_string(t); +#else + ReusableStringStream rss; + rss << t; + return rss.str(); +#endif + } +} // end namespace Catch + +// end catch_to_string.hpp +#include <cstdlib> +#include <cstdint> +#include <cstring> + +namespace Catch { +namespace Matchers { +namespace Floating { +enum class FloatingPointKind : uint8_t { + Float, + Double +}; +} +} +} + +namespace { + +template <typename T> +struct Converter; + +template <> +struct Converter<float> { + static_assert(sizeof(float) == sizeof(int32_t), "Important ULP matcher assumption violated"); + Converter(float f) { + std::memcpy(&i, &f, sizeof(f)); + } + int32_t i; +}; + +template <> +struct Converter<double> { + static_assert(sizeof(double) == sizeof(int64_t), "Important ULP matcher assumption violated"); + Converter(double d) { + std::memcpy(&i, &d, sizeof(d)); + } + int64_t i; +}; + +template <typename T> +auto convert(T t) -> Converter<T> { + return Converter<T>(t); +} + +template <typename FP> +bool almostEqualUlps(FP lhs, FP rhs, int maxUlpDiff) { + // Comparison with NaN should always be false. + // This way we can rule it out before getting into the ugly details + if (Catch::isnan(lhs) || Catch::isnan(rhs)) { + return false; + } + + auto lc = convert(lhs); + auto rc = convert(rhs); + + if ((lc.i < 0) != (rc.i < 0)) { + // Potentially we can have +0 and -0 + return lhs == rhs; + } + + auto ulpDiff = std::abs(lc.i - rc.i); + return ulpDiff <= maxUlpDiff; +} + +} + +namespace Catch { +namespace Matchers { +namespace Floating { + WithinAbsMatcher::WithinAbsMatcher(double target, double margin) + :m_target{ target }, m_margin{ margin } { + CATCH_ENFORCE(margin >= 0, "Invalid margin: " << margin << '.' + << " Margin has to be non-negative."); + } + + // Performs equivalent check of std::fabs(lhs - rhs) <= margin + // But without the subtraction to allow for INFINITY in comparison + bool WithinAbsMatcher::match(double const& matchee) const { + return (matchee + m_margin >= m_target) && (m_target + m_margin >= matchee); + } + + std::string WithinAbsMatcher::describe() const { + return "is within " + ::Catch::Detail::stringify(m_margin) + " of " + ::Catch::Detail::stringify(m_target); + } + + WithinUlpsMatcher::WithinUlpsMatcher(double target, int ulps, FloatingPointKind baseType) + :m_target{ target }, m_ulps{ ulps }, m_type{ baseType } { + CATCH_ENFORCE(ulps >= 0, "Invalid ULP setting: " << ulps << '.' + << " ULPs have to be non-negative."); + } + +#if defined(__clang__) +#pragma clang diagnostic push +// Clang <3.5 reports on the default branch in the switch below +#pragma clang diagnostic ignored "-Wunreachable-code" +#endif + + bool WithinUlpsMatcher::match(double const& matchee) const { + switch (m_type) { + case FloatingPointKind::Float: + return almostEqualUlps<float>(static_cast<float>(matchee), static_cast<float>(m_target), m_ulps); + case FloatingPointKind::Double: + return almostEqualUlps<double>(matchee, m_target, m_ulps); + default: + CATCH_INTERNAL_ERROR( "Unknown FloatingPointKind value" ); + } + } + +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + + std::string WithinUlpsMatcher::describe() const { + return "is within " + Catch::to_string(m_ulps) + " ULPs of " + ::Catch::Detail::stringify(m_target) + ((m_type == FloatingPointKind::Float)? "f" : ""); + } + +}// namespace Floating + +Floating::WithinUlpsMatcher WithinULP(double target, int maxUlpDiff) { + return Floating::WithinUlpsMatcher(target, maxUlpDiff, Floating::FloatingPointKind::Double); +} + +Floating::WithinUlpsMatcher WithinULP(float target, int maxUlpDiff) { + return Floating::WithinUlpsMatcher(target, maxUlpDiff, Floating::FloatingPointKind::Float); +} + +Floating::WithinAbsMatcher WithinAbs(double target, double margin) { + return Floating::WithinAbsMatcher(target, margin); +} + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_floating.cpp +// start catch_matchers_generic.cpp + +std::string Catch::Matchers::Generic::Detail::finalizeDescription(const std::string& desc) { + if (desc.empty()) { + return "matches undescribed predicate"; + } else { + return "matches predicate: \"" + desc + '"'; + } +} +// end catch_matchers_generic.cpp +// start catch_matchers_string.cpp + +#include <regex> + +namespace Catch { +namespace Matchers { + + namespace StdString { + + CasedString::CasedString( std::string const& str, CaseSensitive::Choice caseSensitivity ) + : m_caseSensitivity( caseSensitivity ), + m_str( adjustString( str ) ) + {} + std::string CasedString::adjustString( std::string const& str ) const { + return m_caseSensitivity == CaseSensitive::No + ? toLower( str ) + : str; + } + std::string CasedString::caseSensitivitySuffix() const { + return m_caseSensitivity == CaseSensitive::No + ? " (case insensitive)" + : std::string(); + } + + StringMatcherBase::StringMatcherBase( std::string const& operation, CasedString const& comparator ) + : m_comparator( comparator ), + m_operation( operation ) { + } + + std::string StringMatcherBase::describe() const { + std::string description; + description.reserve(5 + m_operation.size() + m_comparator.m_str.size() + + m_comparator.caseSensitivitySuffix().size()); + description += m_operation; + description += ": \""; + description += m_comparator.m_str; + description += "\""; + description += m_comparator.caseSensitivitySuffix(); + return description; + } + + EqualsMatcher::EqualsMatcher( CasedString const& comparator ) : StringMatcherBase( "equals", comparator ) {} + + bool EqualsMatcher::match( std::string const& source ) const { + return m_comparator.adjustString( source ) == m_comparator.m_str; + } + + ContainsMatcher::ContainsMatcher( CasedString const& comparator ) : StringMatcherBase( "contains", comparator ) {} + + bool ContainsMatcher::match( std::string const& source ) const { + return contains( m_comparator.adjustString( source ), m_comparator.m_str ); + } + + StartsWithMatcher::StartsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "starts with", comparator ) {} + + bool StartsWithMatcher::match( std::string const& source ) const { + return startsWith( m_comparator.adjustString( source ), m_comparator.m_str ); + } + + EndsWithMatcher::EndsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "ends with", comparator ) {} + + bool EndsWithMatcher::match( std::string const& source ) const { + return endsWith( m_comparator.adjustString( source ), m_comparator.m_str ); + } + + RegexMatcher::RegexMatcher(std::string regex, CaseSensitive::Choice caseSensitivity): m_regex(std::move(regex)), m_caseSensitivity(caseSensitivity) {} + + bool RegexMatcher::match(std::string const& matchee) const { + auto flags = std::regex::ECMAScript; // ECMAScript is the default syntax option anyway + if (m_caseSensitivity == CaseSensitive::Choice::No) { + flags |= std::regex::icase; + } + auto reg = std::regex(m_regex, flags); + return std::regex_match(matchee, reg); + } + + std::string RegexMatcher::describe() const { + return "matches " + ::Catch::Detail::stringify(m_regex) + ((m_caseSensitivity == CaseSensitive::Choice::Yes)? " case sensitively" : " case insensitively"); + } + + } // namespace StdString + + StdString::EqualsMatcher Equals( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::EqualsMatcher( StdString::CasedString( str, caseSensitivity) ); + } + StdString::ContainsMatcher Contains( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::ContainsMatcher( StdString::CasedString( str, caseSensitivity) ); + } + StdString::EndsWithMatcher EndsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::EndsWithMatcher( StdString::CasedString( str, caseSensitivity) ); + } + StdString::StartsWithMatcher StartsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::StartsWithMatcher( StdString::CasedString( str, caseSensitivity) ); + } + + StdString::RegexMatcher Matches(std::string const& regex, CaseSensitive::Choice caseSensitivity) { + return StdString::RegexMatcher(regex, caseSensitivity); + } + +} // namespace Matchers +} // namespace Catch +// end catch_matchers_string.cpp +// start catch_message.cpp + +// start catch_uncaught_exceptions.h + +namespace Catch { + bool uncaught_exceptions(); +} // end namespace Catch + +// end catch_uncaught_exceptions.h +#include <cassert> +#include <stack> + +namespace Catch { + + MessageInfo::MessageInfo( StringRef const& _macroName, + SourceLineInfo const& _lineInfo, + ResultWas::OfType _type ) + : macroName( _macroName ), + lineInfo( _lineInfo ), + type( _type ), + sequence( ++globalCount ) + {} + + bool MessageInfo::operator==( MessageInfo const& other ) const { + return sequence == other.sequence; + } + + bool MessageInfo::operator<( MessageInfo const& other ) const { + return sequence < other.sequence; + } + + // This may need protecting if threading support is added + unsigned int MessageInfo::globalCount = 0; + + //////////////////////////////////////////////////////////////////////////// + + Catch::MessageBuilder::MessageBuilder( StringRef const& macroName, + SourceLineInfo const& lineInfo, + ResultWas::OfType type ) + :m_info(macroName, lineInfo, type) {} + + //////////////////////////////////////////////////////////////////////////// + + ScopedMessage::ScopedMessage( MessageBuilder const& builder ) + : m_info( builder.m_info ), m_moved() + { + m_info.message = builder.m_stream.str(); + getResultCapture().pushScopedMessage( m_info ); + } + + ScopedMessage::ScopedMessage( ScopedMessage&& old ) + : m_info( old.m_info ), m_moved() + { + old.m_moved = true; + } + + ScopedMessage::~ScopedMessage() { + if ( !uncaught_exceptions() && !m_moved ){ + getResultCapture().popScopedMessage(m_info); + } + } + + Capturer::Capturer( StringRef macroName, SourceLineInfo const& lineInfo, ResultWas::OfType resultType, StringRef names ) { + auto trimmed = [&] (size_t start, size_t end) { + while (names[start] == ',' || isspace(names[start])) { + ++start; + } + while (names[end] == ',' || isspace(names[end])) { + --end; + } + return names.substr(start, end - start + 1); + }; + + size_t start = 0; + std::stack<char> openings; + for (size_t pos = 0; pos < names.size(); ++pos) { + char c = names[pos]; + switch (c) { + case '[': + case '{': + case '(': + // It is basically impossible to disambiguate between + // comparison and start of template args in this context +// case '<': + openings.push(c); + break; + case ']': + case '}': + case ')': +// case '>': + openings.pop(); + break; + case ',': + if (start != pos && openings.size() == 0) { + m_messages.emplace_back(macroName, lineInfo, resultType); + m_messages.back().message = trimmed(start, pos); + m_messages.back().message += " := "; + start = pos; + } + } + } + assert(openings.size() == 0 && "Mismatched openings"); + m_messages.emplace_back(macroName, lineInfo, resultType); + m_messages.back().message = trimmed(start, names.size() - 1); + m_messages.back().message += " := "; + } + Capturer::~Capturer() { + if ( !uncaught_exceptions() ){ + assert( m_captured == m_messages.size() ); + for( size_t i = 0; i < m_captured; ++i ) + m_resultCapture.popScopedMessage( m_messages[i] ); + } + } + + void Capturer::captureValue( size_t index, std::string const& value ) { + assert( index < m_messages.size() ); + m_messages[index].message += value; + m_resultCapture.pushScopedMessage( m_messages[index] ); + m_captured++; + } + +} // end namespace Catch +// end catch_message.cpp +// start catch_output_redirect.cpp + +// start catch_output_redirect.h +#ifndef TWOBLUECUBES_CATCH_OUTPUT_REDIRECT_H +#define TWOBLUECUBES_CATCH_OUTPUT_REDIRECT_H + +#include <cstdio> +#include <iosfwd> +#include <string> + +namespace Catch { + + class RedirectedStream { + std::ostream& m_originalStream; + std::ostream& m_redirectionStream; + std::streambuf* m_prevBuf; + + public: + RedirectedStream( std::ostream& originalStream, std::ostream& redirectionStream ); + ~RedirectedStream(); + }; + + class RedirectedStdOut { + ReusableStringStream m_rss; + RedirectedStream m_cout; + public: + RedirectedStdOut(); + auto str() const -> std::string; + }; + + // StdErr has two constituent streams in C++, std::cerr and std::clog + // This means that we need to redirect 2 streams into 1 to keep proper + // order of writes + class RedirectedStdErr { + ReusableStringStream m_rss; + RedirectedStream m_cerr; + RedirectedStream m_clog; + public: + RedirectedStdErr(); + auto str() const -> std::string; + }; + + class RedirectedStreams { + public: + RedirectedStreams(RedirectedStreams const&) = delete; + RedirectedStreams& operator=(RedirectedStreams const&) = delete; + RedirectedStreams(RedirectedStreams&&) = delete; + RedirectedStreams& operator=(RedirectedStreams&&) = delete; + + RedirectedStreams(std::string& redirectedCout, std::string& redirectedCerr); + ~RedirectedStreams(); + private: + std::string& m_redirectedCout; + std::string& m_redirectedCerr; + RedirectedStdOut m_redirectedStdOut; + RedirectedStdErr m_redirectedStdErr; + }; + +#if defined(CATCH_CONFIG_NEW_CAPTURE) + + // Windows's implementation of std::tmpfile is terrible (it tries + // to create a file inside system folder, thus requiring elevated + // privileges for the binary), so we have to use tmpnam(_s) and + // create the file ourselves there. + class TempFile { + public: + TempFile(TempFile const&) = delete; + TempFile& operator=(TempFile const&) = delete; + TempFile(TempFile&&) = delete; + TempFile& operator=(TempFile&&) = delete; + + TempFile(); + ~TempFile(); + + std::FILE* getFile(); + std::string getContents(); + + private: + std::FILE* m_file = nullptr; + #if defined(_MSC_VER) + char m_buffer[L_tmpnam] = { 0 }; + #endif + }; + + class OutputRedirect { + public: + OutputRedirect(OutputRedirect const&) = delete; + OutputRedirect& operator=(OutputRedirect const&) = delete; + OutputRedirect(OutputRedirect&&) = delete; + OutputRedirect& operator=(OutputRedirect&&) = delete; + + OutputRedirect(std::string& stdout_dest, std::string& stderr_dest); + ~OutputRedirect(); + + private: + int m_originalStdout = -1; + int m_originalStderr = -1; + TempFile m_stdoutFile; + TempFile m_stderrFile; + std::string& m_stdoutDest; + std::string& m_stderrDest; + }; + +#endif + +} // end namespace Catch + +#endif // TWOBLUECUBES_CATCH_OUTPUT_REDIRECT_H +// end catch_output_redirect.h +#include <cstdio> +#include <cstring> +#include <fstream> +#include <sstream> +#include <stdexcept> + +#if defined(CATCH_CONFIG_NEW_CAPTURE) + #if defined(_MSC_VER) + #include <io.h> //_dup and _dup2 + #define dup _dup + #define dup2 _dup2 + #define fileno _fileno + #else + #include <unistd.h> // dup and dup2 + #endif +#endif + +namespace Catch { + + RedirectedStream::RedirectedStream( std::ostream& originalStream, std::ostream& redirectionStream ) + : m_originalStream( originalStream ), + m_redirectionStream( redirectionStream ), + m_prevBuf( m_originalStream.rdbuf() ) + { + m_originalStream.rdbuf( m_redirectionStream.rdbuf() ); + } + + RedirectedStream::~RedirectedStream() { + m_originalStream.rdbuf( m_prevBuf ); + } + + RedirectedStdOut::RedirectedStdOut() : m_cout( Catch::cout(), m_rss.get() ) {} + auto RedirectedStdOut::str() const -> std::string { return m_rss.str(); } + + RedirectedStdErr::RedirectedStdErr() + : m_cerr( Catch::cerr(), m_rss.get() ), + m_clog( Catch::clog(), m_rss.get() ) + {} + auto RedirectedStdErr::str() const -> std::string { return m_rss.str(); } + + RedirectedStreams::RedirectedStreams(std::string& redirectedCout, std::string& redirectedCerr) + : m_redirectedCout(redirectedCout), + m_redirectedCerr(redirectedCerr) + {} + + RedirectedStreams::~RedirectedStreams() { + m_redirectedCout += m_redirectedStdOut.str(); + m_redirectedCerr += m_redirectedStdErr.str(); + } + +#if defined(CATCH_CONFIG_NEW_CAPTURE) + +#if defined(_MSC_VER) + TempFile::TempFile() { + if (tmpnam_s(m_buffer)) { + CATCH_RUNTIME_ERROR("Could not get a temp filename"); + } + if (fopen_s(&m_file, m_buffer, "w")) { + char buffer[100]; + if (strerror_s(buffer, errno)) { + CATCH_RUNTIME_ERROR("Could not translate errno to a string"); + } + CATCH_RUNTIME_ERROR("Coul dnot open the temp file: '" << m_buffer << "' because: " << buffer); + } + } +#else + TempFile::TempFile() { + m_file = std::tmpfile(); + if (!m_file) { + CATCH_RUNTIME_ERROR("Could not create a temp file."); + } + } + +#endif + + TempFile::~TempFile() { + // TBD: What to do about errors here? + std::fclose(m_file); + // We manually create the file on Windows only, on Linux + // it will be autodeleted +#if defined(_MSC_VER) + std::remove(m_buffer); +#endif + } + + FILE* TempFile::getFile() { + return m_file; + } + + std::string TempFile::getContents() { + std::stringstream sstr; + char buffer[100] = {}; + std::rewind(m_file); + while (std::fgets(buffer, sizeof(buffer), m_file)) { + sstr << buffer; + } + return sstr.str(); + } + + OutputRedirect::OutputRedirect(std::string& stdout_dest, std::string& stderr_dest) : + m_originalStdout(dup(1)), + m_originalStderr(dup(2)), + m_stdoutDest(stdout_dest), + m_stderrDest(stderr_dest) { + dup2(fileno(m_stdoutFile.getFile()), 1); + dup2(fileno(m_stderrFile.getFile()), 2); + } + + OutputRedirect::~OutputRedirect() { + Catch::cout() << std::flush; + fflush(stdout); + // Since we support overriding these streams, we flush cerr + // even though std::cerr is unbuffered + Catch::cerr() << std::flush; + Catch::clog() << std::flush; + fflush(stderr); + + dup2(m_originalStdout, 1); + dup2(m_originalStderr, 2); + + m_stdoutDest += m_stdoutFile.getContents(); + m_stderrDest += m_stderrFile.getContents(); + } + +#endif // CATCH_CONFIG_NEW_CAPTURE + +} // namespace Catch + +#if defined(CATCH_CONFIG_NEW_CAPTURE) + #if defined(_MSC_VER) + #undef dup + #undef dup2 + #undef fileno + #endif +#endif +// end catch_output_redirect.cpp +// start catch_polyfills.cpp + +#include <cmath> + +namespace Catch { + +#if !defined(CATCH_CONFIG_POLYFILL_ISNAN) + bool isnan(float f) { + return std::isnan(f); + } + bool isnan(double d) { + return std::isnan(d); + } +#else + // For now we only use this for embarcadero + bool isnan(float f) { + return std::_isnan(f); + } + bool isnan(double d) { + return std::_isnan(d); + } +#endif + +} // end namespace Catch +// end catch_polyfills.cpp +// start catch_random_number_generator.cpp + +namespace Catch { + + std::mt19937& rng() { + static std::mt19937 s_rng; + return s_rng; + } + + void seedRng( IConfig const& config ) { + if( config.rngSeed() != 0 ) { + std::srand( config.rngSeed() ); + rng().seed( config.rngSeed() ); + } + } + + unsigned int rngSeed() { + return getCurrentContext().getConfig()->rngSeed(); + } +} +// end catch_random_number_generator.cpp +// start catch_registry_hub.cpp + +// start catch_test_case_registry_impl.h + +#include <vector> +#include <set> +#include <algorithm> +#include <ios> + +namespace Catch { + + class TestCase; + struct IConfig; + + std::vector<TestCase> sortTests( IConfig const& config, std::vector<TestCase> const& unsortedTestCases ); + bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ); + + void enforceNoDuplicateTestCases( std::vector<TestCase> const& functions ); + + std::vector<TestCase> filterTests( std::vector<TestCase> const& testCases, TestSpec const& testSpec, IConfig const& config ); + std::vector<TestCase> const& getAllTestCasesSorted( IConfig const& config ); + + class TestRegistry : public ITestCaseRegistry { + public: + virtual ~TestRegistry() = default; + + virtual void registerTest( TestCase const& testCase ); + + std::vector<TestCase> const& getAllTests() const override; + std::vector<TestCase> const& getAllTestsSorted( IConfig const& config ) const override; + + private: + std::vector<TestCase> m_functions; + mutable RunTests::InWhatOrder m_currentSortOrder = RunTests::InDeclarationOrder; + mutable std::vector<TestCase> m_sortedFunctions; + std::size_t m_unnamedCount = 0; + std::ios_base::Init m_ostreamInit; // Forces cout/ cerr to be initialised + }; + + /////////////////////////////////////////////////////////////////////////// + + class TestInvokerAsFunction : public ITestInvoker { + void(*m_testAsFunction)(); + public: + TestInvokerAsFunction( void(*testAsFunction)() ) noexcept; + + void invoke() const override; + }; + + std::string extractClassName( StringRef const& classOrQualifiedMethodName ); + + /////////////////////////////////////////////////////////////////////////// + +} // end namespace Catch + +// end catch_test_case_registry_impl.h +// start catch_reporter_registry.h + +#include <map> + +namespace Catch { + + class ReporterRegistry : public IReporterRegistry { + + public: + + ~ReporterRegistry() override; + + IStreamingReporterPtr create( std::string const& name, IConfigPtr const& config ) const override; + + void registerReporter( std::string const& name, IReporterFactoryPtr const& factory ); + void registerListener( IReporterFactoryPtr const& factory ); + + FactoryMap const& getFactories() const override; + Listeners const& getListeners() const override; + + private: + FactoryMap m_factories; + Listeners m_listeners; + }; +} + +// end catch_reporter_registry.h +// start catch_tag_alias_registry.h + +// start catch_tag_alias.h + +#include <string> + +namespace Catch { + + struct TagAlias { + TagAlias(std::string const& _tag, SourceLineInfo _lineInfo); + + std::string tag; + SourceLineInfo lineInfo; + }; + +} // end namespace Catch + +// end catch_tag_alias.h +#include <map> + +namespace Catch { + + class TagAliasRegistry : public ITagAliasRegistry { + public: + ~TagAliasRegistry() override; + TagAlias const* find( std::string const& alias ) const override; + std::string expandAliases( std::string const& unexpandedTestSpec ) const override; + void add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ); + + private: + std::map<std::string, TagAlias> m_registry; + }; + +} // end namespace Catch + +// end catch_tag_alias_registry.h +// start catch_startup_exception_registry.h + +#include <vector> +#include <exception> + +namespace Catch { + + class StartupExceptionRegistry { + public: + void add(std::exception_ptr const& exception) noexcept; + std::vector<std::exception_ptr> const& getExceptions() const noexcept; + private: + std::vector<std::exception_ptr> m_exceptions; + }; + +} // end namespace Catch + +// end catch_startup_exception_registry.h +// start catch_singletons.hpp + +namespace Catch { + + struct ISingleton { + virtual ~ISingleton(); + }; + + void addSingleton( ISingleton* singleton ); + void cleanupSingletons(); + + template<typename SingletonImplT, typename InterfaceT = SingletonImplT, typename MutableInterfaceT = InterfaceT> + class Singleton : SingletonImplT, public ISingleton { + + static auto getInternal() -> Singleton* { + static Singleton* s_instance = nullptr; + if( !s_instance ) { + s_instance = new Singleton; + addSingleton( s_instance ); + } + return s_instance; + } + + public: + static auto get() -> InterfaceT const& { + return *getInternal(); + } + static auto getMutable() -> MutableInterfaceT& { + return *getInternal(); + } + }; + +} // namespace Catch + +// end catch_singletons.hpp +namespace Catch { + + namespace { + + class RegistryHub : public IRegistryHub, public IMutableRegistryHub, + private NonCopyable { + + public: // IRegistryHub + RegistryHub() = default; + IReporterRegistry const& getReporterRegistry() const override { + return m_reporterRegistry; + } + ITestCaseRegistry const& getTestCaseRegistry() const override { + return m_testCaseRegistry; + } + IExceptionTranslatorRegistry const& getExceptionTranslatorRegistry() const override { + return m_exceptionTranslatorRegistry; + } + ITagAliasRegistry const& getTagAliasRegistry() const override { + return m_tagAliasRegistry; + } + StartupExceptionRegistry const& getStartupExceptionRegistry() const override { + return m_exceptionRegistry; + } + + public: // IMutableRegistryHub + void registerReporter( std::string const& name, IReporterFactoryPtr const& factory ) override { + m_reporterRegistry.registerReporter( name, factory ); + } + void registerListener( IReporterFactoryPtr const& factory ) override { + m_reporterRegistry.registerListener( factory ); + } + void registerTest( TestCase const& testInfo ) override { + m_testCaseRegistry.registerTest( testInfo ); + } + void registerTranslator( const IExceptionTranslator* translator ) override { + m_exceptionTranslatorRegistry.registerTranslator( translator ); + } + void registerTagAlias( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) override { + m_tagAliasRegistry.add( alias, tag, lineInfo ); + } + void registerStartupException() noexcept override { + m_exceptionRegistry.add(std::current_exception()); + } + + private: + TestRegistry m_testCaseRegistry; + ReporterRegistry m_reporterRegistry; + ExceptionTranslatorRegistry m_exceptionTranslatorRegistry; + TagAliasRegistry m_tagAliasRegistry; + StartupExceptionRegistry m_exceptionRegistry; + }; + } + + using RegistryHubSingleton = Singleton<RegistryHub, IRegistryHub, IMutableRegistryHub>; + + IRegistryHub const& getRegistryHub() { + return RegistryHubSingleton::get(); + } + IMutableRegistryHub& getMutableRegistryHub() { + return RegistryHubSingleton::getMutable(); + } + void cleanUp() { + cleanupSingletons(); + cleanUpContext(); + } + std::string translateActiveException() { + return getRegistryHub().getExceptionTranslatorRegistry().translateActiveException(); + } + +} // end namespace Catch +// end catch_registry_hub.cpp +// start catch_reporter_registry.cpp + +namespace Catch { + + ReporterRegistry::~ReporterRegistry() = default; + + IStreamingReporterPtr ReporterRegistry::create( std::string const& name, IConfigPtr const& config ) const { + auto it = m_factories.find( name ); + if( it == m_factories.end() ) + return nullptr; + return it->second->create( ReporterConfig( config ) ); + } + + void ReporterRegistry::registerReporter( std::string const& name, IReporterFactoryPtr const& factory ) { + m_factories.emplace(name, factory); + } + void ReporterRegistry::registerListener( IReporterFactoryPtr const& factory ) { + m_listeners.push_back( factory ); + } + + IReporterRegistry::FactoryMap const& ReporterRegistry::getFactories() const { + return m_factories; + } + IReporterRegistry::Listeners const& ReporterRegistry::getListeners() const { + return m_listeners; + } + +} +// end catch_reporter_registry.cpp +// start catch_result_type.cpp + +namespace Catch { + + bool isOk( ResultWas::OfType resultType ) { + return ( resultType & ResultWas::FailureBit ) == 0; + } + bool isJustInfo( int flags ) { + return flags == ResultWas::Info; + } + + ResultDisposition::Flags operator | ( ResultDisposition::Flags lhs, ResultDisposition::Flags rhs ) { + return static_cast<ResultDisposition::Flags>( static_cast<int>( lhs ) | static_cast<int>( rhs ) ); + } + + bool shouldContinueOnFailure( int flags ) { return ( flags & ResultDisposition::ContinueOnFailure ) != 0; } + bool shouldSuppressFailure( int flags ) { return ( flags & ResultDisposition::SuppressFail ) != 0; } + +} // end namespace Catch +// end catch_result_type.cpp +// start catch_run_context.cpp + +#include <cassert> +#include <algorithm> +#include <sstream> + +namespace Catch { + + namespace Generators { + struct GeneratorTracker : TestCaseTracking::TrackerBase, IGeneratorTracker { + GeneratorBasePtr m_generator; + + GeneratorTracker( TestCaseTracking::NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ) + : TrackerBase( nameAndLocation, ctx, parent ) + {} + ~GeneratorTracker(); + + static GeneratorTracker& acquire( TrackerContext& ctx, TestCaseTracking::NameAndLocation const& nameAndLocation ) { + std::shared_ptr<GeneratorTracker> tracker; + + ITracker& currentTracker = ctx.currentTracker(); + if( TestCaseTracking::ITrackerPtr childTracker = currentTracker.findChild( nameAndLocation ) ) { + assert( childTracker ); + assert( childTracker->isGeneratorTracker() ); + tracker = std::static_pointer_cast<GeneratorTracker>( childTracker ); + } + else { + tracker = std::make_shared<GeneratorTracker>( nameAndLocation, ctx, ¤tTracker ); + currentTracker.addChild( tracker ); + } + + if( !ctx.completedCycle() && !tracker->isComplete() ) { + tracker->open(); + } + + return *tracker; + } + + // TrackerBase interface + bool isGeneratorTracker() const override { return true; } + auto hasGenerator() const -> bool override { + return !!m_generator; + } + void close() override { + TrackerBase::close(); + // Generator interface only finds out if it has another item on atual move + if (m_runState == CompletedSuccessfully && m_generator->next()) { + m_children.clear(); + m_runState = Executing; + } + } + + // IGeneratorTracker interface + auto getGenerator() const -> GeneratorBasePtr const& override { + return m_generator; + } + void setGenerator( GeneratorBasePtr&& generator ) override { + m_generator = std::move( generator ); + } + }; + GeneratorTracker::~GeneratorTracker() {} + } + + RunContext::RunContext(IConfigPtr const& _config, IStreamingReporterPtr&& reporter) + : m_runInfo(_config->name()), + m_context(getCurrentMutableContext()), + m_config(_config), + m_reporter(std::move(reporter)), + m_lastAssertionInfo{ StringRef(), SourceLineInfo("",0), StringRef(), ResultDisposition::Normal }, + m_includeSuccessfulResults( m_config->includeSuccessfulResults() || m_reporter->getPreferences().shouldReportAllAssertions ) + { + m_context.setRunner(this); + m_context.setConfig(m_config); + m_context.setResultCapture(this); + m_reporter->testRunStarting(m_runInfo); + } + + RunContext::~RunContext() { + m_reporter->testRunEnded(TestRunStats(m_runInfo, m_totals, aborting())); + } + + void RunContext::testGroupStarting(std::string const& testSpec, std::size_t groupIndex, std::size_t groupsCount) { + m_reporter->testGroupStarting(GroupInfo(testSpec, groupIndex, groupsCount)); + } + + void RunContext::testGroupEnded(std::string const& testSpec, Totals const& totals, std::size_t groupIndex, std::size_t groupsCount) { + m_reporter->testGroupEnded(TestGroupStats(GroupInfo(testSpec, groupIndex, groupsCount), totals, aborting())); + } + + Totals RunContext::runTest(TestCase const& testCase) { + Totals prevTotals = m_totals; + + std::string redirectedCout; + std::string redirectedCerr; + + auto const& testInfo = testCase.getTestCaseInfo(); + + m_reporter->testCaseStarting(testInfo); + + m_activeTestCase = &testCase; + + ITracker& rootTracker = m_trackerContext.startRun(); + assert(rootTracker.isSectionTracker()); + static_cast<SectionTracker&>(rootTracker).addInitialFilters(m_config->getSectionsToRun()); + do { + m_trackerContext.startCycle(); + m_testCaseTracker = &SectionTracker::acquire(m_trackerContext, TestCaseTracking::NameAndLocation(testInfo.name, testInfo.lineInfo)); + runCurrentTest(redirectedCout, redirectedCerr); + } while (!m_testCaseTracker->isSuccessfullyCompleted() && !aborting()); + + Totals deltaTotals = m_totals.delta(prevTotals); + if (testInfo.expectedToFail() && deltaTotals.testCases.passed > 0) { + deltaTotals.assertions.failed++; + deltaTotals.testCases.passed--; + deltaTotals.testCases.failed++; + } + m_totals.testCases += deltaTotals.testCases; + m_reporter->testCaseEnded(TestCaseStats(testInfo, + deltaTotals, + redirectedCout, + redirectedCerr, + aborting())); + + m_activeTestCase = nullptr; + m_testCaseTracker = nullptr; + + return deltaTotals; + } + + IConfigPtr RunContext::config() const { + return m_config; + } + + IStreamingReporter& RunContext::reporter() const { + return *m_reporter; + } + + void RunContext::assertionEnded(AssertionResult const & result) { + if (result.getResultType() == ResultWas::Ok) { + m_totals.assertions.passed++; + m_lastAssertionPassed = true; + } else if (!result.isOk()) { + m_lastAssertionPassed = false; + if( m_activeTestCase->getTestCaseInfo().okToFail() ) + m_totals.assertions.failedButOk++; + else + m_totals.assertions.failed++; + } + else { + m_lastAssertionPassed = true; + } + + // We have no use for the return value (whether messages should be cleared), because messages were made scoped + // and should be let to clear themselves out. + static_cast<void>(m_reporter->assertionEnded(AssertionStats(result, m_messages, m_totals))); + + if (result.getResultType() != ResultWas::Warning) + m_messageScopes.clear(); + + // Reset working state + resetAssertionInfo(); + m_lastResult = result; + } + void RunContext::resetAssertionInfo() { + m_lastAssertionInfo.macroName = StringRef(); + m_lastAssertionInfo.capturedExpression = "{Unknown expression after the reported line}"_sr; + } + + bool RunContext::sectionStarted(SectionInfo const & sectionInfo, Counts & assertions) { + ITracker& sectionTracker = SectionTracker::acquire(m_trackerContext, TestCaseTracking::NameAndLocation(sectionInfo.name, sectionInfo.lineInfo)); + if (!sectionTracker.isOpen()) + return false; + m_activeSections.push_back(§ionTracker); + + m_lastAssertionInfo.lineInfo = sectionInfo.lineInfo; + + m_reporter->sectionStarting(sectionInfo); + + assertions = m_totals.assertions; + + return true; + } + auto RunContext::acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& { + using namespace Generators; + GeneratorTracker& tracker = GeneratorTracker::acquire( m_trackerContext, TestCaseTracking::NameAndLocation( "generator", lineInfo ) ); + assert( tracker.isOpen() ); + m_lastAssertionInfo.lineInfo = lineInfo; + return tracker; + } + + bool RunContext::testForMissingAssertions(Counts& assertions) { + if (assertions.total() != 0) + return false; + if (!m_config->warnAboutMissingAssertions()) + return false; + if (m_trackerContext.currentTracker().hasChildren()) + return false; + m_totals.assertions.failed++; + assertions.failed++; + return true; + } + + void RunContext::sectionEnded(SectionEndInfo const & endInfo) { + Counts assertions = m_totals.assertions - endInfo.prevAssertions; + bool missingAssertions = testForMissingAssertions(assertions); + + if (!m_activeSections.empty()) { + m_activeSections.back()->close(); + m_activeSections.pop_back(); + } + + m_reporter->sectionEnded(SectionStats(endInfo.sectionInfo, assertions, endInfo.durationInSeconds, missingAssertions)); + m_messages.clear(); + m_messageScopes.clear(); + } + + void RunContext::sectionEndedEarly(SectionEndInfo const & endInfo) { + if (m_unfinishedSections.empty()) + m_activeSections.back()->fail(); + else + m_activeSections.back()->close(); + m_activeSections.pop_back(); + + m_unfinishedSections.push_back(endInfo); + } + void RunContext::benchmarkStarting( BenchmarkInfo const& info ) { + m_reporter->benchmarkStarting( info ); + } + void RunContext::benchmarkEnded( BenchmarkStats const& stats ) { + m_reporter->benchmarkEnded( stats ); + } + + void RunContext::pushScopedMessage(MessageInfo const & message) { + m_messages.push_back(message); + } + + void RunContext::popScopedMessage(MessageInfo const & message) { + m_messages.erase(std::remove(m_messages.begin(), m_messages.end(), message), m_messages.end()); + } + + void RunContext::emplaceUnscopedMessage( MessageBuilder const& builder ) { + m_messageScopes.emplace_back( builder ); + } + + std::string RunContext::getCurrentTestName() const { + return m_activeTestCase + ? m_activeTestCase->getTestCaseInfo().name + : std::string(); + } + + const AssertionResult * RunContext::getLastResult() const { + return &(*m_lastResult); + } + + void RunContext::exceptionEarlyReported() { + m_shouldReportUnexpected = false; + } + + void RunContext::handleFatalErrorCondition( StringRef message ) { + // First notify reporter that bad things happened + m_reporter->fatalErrorEncountered(message); + + // Don't rebuild the result -- the stringification itself can cause more fatal errors + // Instead, fake a result data. + AssertionResultData tempResult( ResultWas::FatalErrorCondition, { false } ); + tempResult.message = message; + AssertionResult result(m_lastAssertionInfo, tempResult); + + assertionEnded(result); + + handleUnfinishedSections(); + + // Recreate section for test case (as we will lose the one that was in scope) + auto const& testCaseInfo = m_activeTestCase->getTestCaseInfo(); + SectionInfo testCaseSection(testCaseInfo.lineInfo, testCaseInfo.name); + + Counts assertions; + assertions.failed = 1; + SectionStats testCaseSectionStats(testCaseSection, assertions, 0, false); + m_reporter->sectionEnded(testCaseSectionStats); + + auto const& testInfo = m_activeTestCase->getTestCaseInfo(); + + Totals deltaTotals; + deltaTotals.testCases.failed = 1; + deltaTotals.assertions.failed = 1; + m_reporter->testCaseEnded(TestCaseStats(testInfo, + deltaTotals, + std::string(), + std::string(), + false)); + m_totals.testCases.failed++; + testGroupEnded(std::string(), m_totals, 1, 1); + m_reporter->testRunEnded(TestRunStats(m_runInfo, m_totals, false)); + } + + bool RunContext::lastAssertionPassed() { + return m_lastAssertionPassed; + } + + void RunContext::assertionPassed() { + m_lastAssertionPassed = true; + ++m_totals.assertions.passed; + resetAssertionInfo(); + m_messageScopes.clear(); + } + + bool RunContext::aborting() const { + return m_totals.assertions.failed >= static_cast<std::size_t>(m_config->abortAfter()); + } + + void RunContext::runCurrentTest(std::string & redirectedCout, std::string & redirectedCerr) { + auto const& testCaseInfo = m_activeTestCase->getTestCaseInfo(); + SectionInfo testCaseSection(testCaseInfo.lineInfo, testCaseInfo.name); + m_reporter->sectionStarting(testCaseSection); + Counts prevAssertions = m_totals.assertions; + double duration = 0; + m_shouldReportUnexpected = true; + m_lastAssertionInfo = { "TEST_CASE"_sr, testCaseInfo.lineInfo, StringRef(), ResultDisposition::Normal }; + + seedRng(*m_config); + + Timer timer; + CATCH_TRY { + if (m_reporter->getPreferences().shouldRedirectStdOut) { +#if !defined(CATCH_CONFIG_EXPERIMENTAL_REDIRECT) + RedirectedStreams redirectedStreams(redirectedCout, redirectedCerr); + + timer.start(); + invokeActiveTestCase(); +#else + OutputRedirect r(redirectedCout, redirectedCerr); + timer.start(); + invokeActiveTestCase(); +#endif + } else { + timer.start(); + invokeActiveTestCase(); + } + duration = timer.getElapsedSeconds(); + } CATCH_CATCH_ANON (TestFailureException&) { + // This just means the test was aborted due to failure + } CATCH_CATCH_ALL { + // Under CATCH_CONFIG_FAST_COMPILE, unexpected exceptions under REQUIRE assertions + // are reported without translation at the point of origin. + if( m_shouldReportUnexpected ) { + AssertionReaction dummyReaction; + handleUnexpectedInflightException( m_lastAssertionInfo, translateActiveException(), dummyReaction ); + } + } + Counts assertions = m_totals.assertions - prevAssertions; + bool missingAssertions = testForMissingAssertions(assertions); + + m_testCaseTracker->close(); + handleUnfinishedSections(); + m_messages.clear(); + m_messageScopes.clear(); + + SectionStats testCaseSectionStats(testCaseSection, assertions, duration, missingAssertions); + m_reporter->sectionEnded(testCaseSectionStats); + } + + void RunContext::invokeActiveTestCase() { + FatalConditionHandler fatalConditionHandler; // Handle signals + m_activeTestCase->invoke(); + fatalConditionHandler.reset(); + } + + void RunContext::handleUnfinishedSections() { + // If sections ended prematurely due to an exception we stored their + // infos here so we can tear them down outside the unwind process. + for (auto it = m_unfinishedSections.rbegin(), + itEnd = m_unfinishedSections.rend(); + it != itEnd; + ++it) + sectionEnded(*it); + m_unfinishedSections.clear(); + } + + void RunContext::handleExpr( + AssertionInfo const& info, + ITransientExpression const& expr, + AssertionReaction& reaction + ) { + m_reporter->assertionStarting( info ); + + bool negated = isFalseTest( info.resultDisposition ); + bool result = expr.getResult() != negated; + + if( result ) { + if (!m_includeSuccessfulResults) { + assertionPassed(); + } + else { + reportExpr(info, ResultWas::Ok, &expr, negated); + } + } + else { + reportExpr(info, ResultWas::ExpressionFailed, &expr, negated ); + populateReaction( reaction ); + } + } + void RunContext::reportExpr( + AssertionInfo const &info, + ResultWas::OfType resultType, + ITransientExpression const *expr, + bool negated ) { + + m_lastAssertionInfo = info; + AssertionResultData data( resultType, LazyExpression( negated ) ); + + AssertionResult assertionResult{ info, data }; + assertionResult.m_resultData.lazyExpression.m_transientExpression = expr; + + assertionEnded( assertionResult ); + } + + void RunContext::handleMessage( + AssertionInfo const& info, + ResultWas::OfType resultType, + StringRef const& message, + AssertionReaction& reaction + ) { + m_reporter->assertionStarting( info ); + + m_lastAssertionInfo = info; + + AssertionResultData data( resultType, LazyExpression( false ) ); + data.message = message; + AssertionResult assertionResult{ m_lastAssertionInfo, data }; + assertionEnded( assertionResult ); + if( !assertionResult.isOk() ) + populateReaction( reaction ); + } + void RunContext::handleUnexpectedExceptionNotThrown( + AssertionInfo const& info, + AssertionReaction& reaction + ) { + handleNonExpr(info, Catch::ResultWas::DidntThrowException, reaction); + } + + void RunContext::handleUnexpectedInflightException( + AssertionInfo const& info, + std::string const& message, + AssertionReaction& reaction + ) { + m_lastAssertionInfo = info; + + AssertionResultData data( ResultWas::ThrewException, LazyExpression( false ) ); + data.message = message; + AssertionResult assertionResult{ info, data }; + assertionEnded( assertionResult ); + populateReaction( reaction ); + } + + void RunContext::populateReaction( AssertionReaction& reaction ) { + reaction.shouldDebugBreak = m_config->shouldDebugBreak(); + reaction.shouldThrow = aborting() || (m_lastAssertionInfo.resultDisposition & ResultDisposition::Normal); + } + + void RunContext::handleIncomplete( + AssertionInfo const& info + ) { + m_lastAssertionInfo = info; + + AssertionResultData data( ResultWas::ThrewException, LazyExpression( false ) ); + data.message = "Exception translation was disabled by CATCH_CONFIG_FAST_COMPILE"; + AssertionResult assertionResult{ info, data }; + assertionEnded( assertionResult ); + } + void RunContext::handleNonExpr( + AssertionInfo const &info, + ResultWas::OfType resultType, + AssertionReaction &reaction + ) { + m_lastAssertionInfo = info; + + AssertionResultData data( resultType, LazyExpression( false ) ); + AssertionResult assertionResult{ info, data }; + assertionEnded( assertionResult ); + + if( !assertionResult.isOk() ) + populateReaction( reaction ); + } + + IResultCapture& getResultCapture() { + if (auto* capture = getCurrentContext().getResultCapture()) + return *capture; + else + CATCH_INTERNAL_ERROR("No result capture instance"); + } +} +// end catch_run_context.cpp +// start catch_section.cpp + +namespace Catch { + + Section::Section( SectionInfo const& info ) + : m_info( info ), + m_sectionIncluded( getResultCapture().sectionStarted( m_info, m_assertions ) ) + { + m_timer.start(); + } + + Section::~Section() { + if( m_sectionIncluded ) { + SectionEndInfo endInfo{ m_info, m_assertions, m_timer.getElapsedSeconds() }; + if( uncaught_exceptions() ) + getResultCapture().sectionEndedEarly( endInfo ); + else + getResultCapture().sectionEnded( endInfo ); + } + } + + // This indicates whether the section should be executed or not + Section::operator bool() const { + return m_sectionIncluded; + } + +} // end namespace Catch +// end catch_section.cpp +// start catch_section_info.cpp + +namespace Catch { + + SectionInfo::SectionInfo + ( SourceLineInfo const& _lineInfo, + std::string const& _name ) + : name( _name ), + lineInfo( _lineInfo ) + {} + +} // end namespace Catch +// end catch_section_info.cpp +// start catch_session.cpp + +// start catch_session.h + +#include <memory> + +namespace Catch { + + class Session : NonCopyable { + public: + + Session(); + ~Session() override; + + void showHelp() const; + void libIdentify(); + + int applyCommandLine( int argc, char const * const * argv ); + #if defined(CATCH_CONFIG_WCHAR) && defined(WIN32) && defined(UNICODE) + int applyCommandLine( int argc, wchar_t const * const * argv ); + #endif + + void useConfigData( ConfigData const& configData ); + + template<typename CharT> + int run(int argc, CharT const * const argv[]) { + if (m_startupExceptions) + return 1; + int returnCode = applyCommandLine(argc, argv); + if (returnCode == 0) + returnCode = run(); + return returnCode; + } + + int run(); + + clara::Parser const& cli() const; + void cli( clara::Parser const& newParser ); + ConfigData& configData(); + Config& config(); + private: + int runInternal(); + + clara::Parser m_cli; + ConfigData m_configData; + std::shared_ptr<Config> m_config; + bool m_startupExceptions = false; + }; + +} // end namespace Catch + +// end catch_session.h +// start catch_version.h + +#include <iosfwd> + +namespace Catch { + + // Versioning information + struct Version { + Version( Version const& ) = delete; + Version& operator=( Version const& ) = delete; + Version( unsigned int _majorVersion, + unsigned int _minorVersion, + unsigned int _patchNumber, + char const * const _branchName, + unsigned int _buildNumber ); + + unsigned int const majorVersion; + unsigned int const minorVersion; + unsigned int const patchNumber; + + // buildNumber is only used if branchName is not null + char const * const branchName; + unsigned int const buildNumber; + + friend std::ostream& operator << ( std::ostream& os, Version const& version ); + }; + + Version const& libraryVersion(); +} + +// end catch_version.h +#include <cstdlib> +#include <iomanip> + +namespace Catch { + + namespace { + const int MaxExitCode = 255; + + IStreamingReporterPtr createReporter(std::string const& reporterName, IConfigPtr const& config) { + auto reporter = Catch::getRegistryHub().getReporterRegistry().create(reporterName, config); + CATCH_ENFORCE(reporter, "No reporter registered with name: '" << reporterName << "'"); + + return reporter; + } + + IStreamingReporterPtr makeReporter(std::shared_ptr<Config> const& config) { + if (Catch::getRegistryHub().getReporterRegistry().getListeners().empty()) { + return createReporter(config->getReporterName(), config); + } + + // On older platforms, returning std::unique_ptr<ListeningReporter> + // when the return type is std::unique_ptr<IStreamingReporter> + // doesn't compile without a std::move call. However, this causes + // a warning on newer platforms. Thus, we have to work around + // it a bit and downcast the pointer manually. + auto ret = std::unique_ptr<IStreamingReporter>(new ListeningReporter); + auto& multi = static_cast<ListeningReporter&>(*ret); + auto const& listeners = Catch::getRegistryHub().getReporterRegistry().getListeners(); + for (auto const& listener : listeners) { + multi.addListener(listener->create(Catch::ReporterConfig(config))); + } + multi.addReporter(createReporter(config->getReporterName(), config)); + return ret; + } + + Catch::Totals runTests(std::shared_ptr<Config> const& config) { + auto reporter = makeReporter(config); + + RunContext context(config, std::move(reporter)); + + Totals totals; + + context.testGroupStarting(config->name(), 1, 1); + + TestSpec testSpec = config->testSpec(); + + auto const& allTestCases = getAllTestCasesSorted(*config); + for (auto const& testCase : allTestCases) { + if (!context.aborting() && matchTest(testCase, testSpec, *config)) + totals += context.runTest(testCase); + else + context.reporter().skipTest(testCase); + } + + if (config->warnAboutNoTests() && totals.testCases.total() == 0) { + ReusableStringStream testConfig; + + bool first = true; + for (const auto& input : config->getTestsOrTags()) { + if (!first) { testConfig << ' '; } + first = false; + testConfig << input; + } + + context.reporter().noMatchingTestCases(testConfig.str()); + totals.error = -1; + } + + context.testGroupEnded(config->name(), totals, 1, 1); + return totals; + } + + void applyFilenamesAsTags(Catch::IConfig const& config) { + auto& tests = const_cast<std::vector<TestCase>&>(getAllTestCasesSorted(config)); + for (auto& testCase : tests) { + auto tags = testCase.tags; + + std::string filename = testCase.lineInfo.file; + auto lastSlash = filename.find_last_of("\\/"); + if (lastSlash != std::string::npos) { + filename.erase(0, lastSlash); + filename[0] = '#'; + } + + auto lastDot = filename.find_last_of('.'); + if (lastDot != std::string::npos) { + filename.erase(lastDot); + } + + tags.push_back(std::move(filename)); + setTags(testCase, tags); + } + } + + } // anon namespace + + Session::Session() { + static bool alreadyInstantiated = false; + if( alreadyInstantiated ) { + CATCH_TRY { CATCH_INTERNAL_ERROR( "Only one instance of Catch::Session can ever be used" ); } + CATCH_CATCH_ALL { getMutableRegistryHub().registerStartupException(); } + } + + // There cannot be exceptions at startup in no-exception mode. +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + const auto& exceptions = getRegistryHub().getStartupExceptionRegistry().getExceptions(); + if ( !exceptions.empty() ) { + m_startupExceptions = true; + Colour colourGuard( Colour::Red ); + Catch::cerr() << "Errors occurred during startup!" << '\n'; + // iterate over all exceptions and notify user + for ( const auto& ex_ptr : exceptions ) { + try { + std::rethrow_exception(ex_ptr); + } catch ( std::exception const& ex ) { + Catch::cerr() << Column( ex.what() ).indent(2) << '\n'; + } + } + } +#endif + + alreadyInstantiated = true; + m_cli = makeCommandLineParser( m_configData ); + } + Session::~Session() { + Catch::cleanUp(); + } + + void Session::showHelp() const { + Catch::cout() + << "\nCatch v" << libraryVersion() << "\n" + << m_cli << std::endl + << "For more detailed usage please see the project docs\n" << std::endl; + } + void Session::libIdentify() { + Catch::cout() + << std::left << std::setw(16) << "description: " << "A Catch test executable\n" + << std::left << std::setw(16) << "category: " << "testframework\n" + << std::left << std::setw(16) << "framework: " << "Catch Test\n" + << std::left << std::setw(16) << "version: " << libraryVersion() << std::endl; + } + + int Session::applyCommandLine( int argc, char const * const * argv ) { + if( m_startupExceptions ) + return 1; + + auto result = m_cli.parse( clara::Args( argc, argv ) ); + if( !result ) { + config(); + getCurrentMutableContext().setConfig(m_config); + Catch::cerr() + << Colour( Colour::Red ) + << "\nError(s) in input:\n" + << Column( result.errorMessage() ).indent( 2 ) + << "\n\n"; + Catch::cerr() << "Run with -? for usage\n" << std::endl; + return MaxExitCode; + } + + if( m_configData.showHelp ) + showHelp(); + if( m_configData.libIdentify ) + libIdentify(); + m_config.reset(); + return 0; + } + +#if defined(CATCH_CONFIG_WCHAR) && defined(WIN32) && defined(UNICODE) + int Session::applyCommandLine( int argc, wchar_t const * const * argv ) { + + char **utf8Argv = new char *[ argc ]; + + for ( int i = 0; i < argc; ++i ) { + int bufSize = WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, NULL, 0, NULL, NULL ); + + utf8Argv[ i ] = new char[ bufSize ]; + + WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, utf8Argv[i], bufSize, NULL, NULL ); + } + + int returnCode = applyCommandLine( argc, utf8Argv ); + + for ( int i = 0; i < argc; ++i ) + delete [] utf8Argv[ i ]; + + delete [] utf8Argv; + + return returnCode; + } +#endif + + void Session::useConfigData( ConfigData const& configData ) { + m_configData = configData; + m_config.reset(); + } + + int Session::run() { + if( ( m_configData.waitForKeypress & WaitForKeypress::BeforeStart ) != 0 ) { + Catch::cout() << "...waiting for enter/ return before starting" << std::endl; + static_cast<void>(std::getchar()); + } + int exitCode = runInternal(); + if( ( m_configData.waitForKeypress & WaitForKeypress::BeforeExit ) != 0 ) { + Catch::cout() << "...waiting for enter/ return before exiting, with code: " << exitCode << std::endl; + static_cast<void>(std::getchar()); + } + return exitCode; + } + + clara::Parser const& Session::cli() const { + return m_cli; + } + void Session::cli( clara::Parser const& newParser ) { + m_cli = newParser; + } + ConfigData& Session::configData() { + return m_configData; + } + Config& Session::config() { + if( !m_config ) + m_config = std::make_shared<Config>( m_configData ); + return *m_config; + } + + int Session::runInternal() { + if( m_startupExceptions ) + return 1; + + if (m_configData.showHelp || m_configData.libIdentify) { + return 0; + } + + CATCH_TRY { + config(); // Force config to be constructed + + seedRng( *m_config ); + + if( m_configData.filenamesAsTags ) + applyFilenamesAsTags( *m_config ); + + // Handle list request + if( Option<std::size_t> listed = list( m_config ) ) + return static_cast<int>( *listed ); + + auto totals = runTests( m_config ); + // Note that on unices only the lower 8 bits are usually used, clamping + // the return value to 255 prevents false negative when some multiple + // of 256 tests has failed + return (std::min) (MaxExitCode, (std::max) (totals.error, static_cast<int>(totals.assertions.failed))); + } +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + catch( std::exception& ex ) { + Catch::cerr() << ex.what() << std::endl; + return MaxExitCode; + } +#endif + } + +} // end namespace Catch +// end catch_session.cpp +// start catch_singletons.cpp + +#include <vector> + +namespace Catch { + + namespace { + static auto getSingletons() -> std::vector<ISingleton*>*& { + static std::vector<ISingleton*>* g_singletons = nullptr; + if( !g_singletons ) + g_singletons = new std::vector<ISingleton*>(); + return g_singletons; + } + } + + ISingleton::~ISingleton() {} + + void addSingleton(ISingleton* singleton ) { + getSingletons()->push_back( singleton ); + } + void cleanupSingletons() { + auto& singletons = getSingletons(); + for( auto singleton : *singletons ) + delete singleton; + delete singletons; + singletons = nullptr; + } + +} // namespace Catch +// end catch_singletons.cpp +// start catch_startup_exception_registry.cpp + +namespace Catch { +void StartupExceptionRegistry::add( std::exception_ptr const& exception ) noexcept { + CATCH_TRY { + m_exceptions.push_back(exception); + } CATCH_CATCH_ALL { + // If we run out of memory during start-up there's really not a lot more we can do about it + std::terminate(); + } + } + + std::vector<std::exception_ptr> const& StartupExceptionRegistry::getExceptions() const noexcept { + return m_exceptions; + } + +} // end namespace Catch +// end catch_startup_exception_registry.cpp +// start catch_stream.cpp + +#include <cstdio> +#include <iostream> +#include <fstream> +#include <sstream> +#include <vector> +#include <memory> + +namespace Catch { + + Catch::IStream::~IStream() = default; + + namespace detail { namespace { + template<typename WriterF, std::size_t bufferSize=256> + class StreamBufImpl : public std::streambuf { + char data[bufferSize]; + WriterF m_writer; + + public: + StreamBufImpl() { + setp( data, data + sizeof(data) ); + } + + ~StreamBufImpl() noexcept { + StreamBufImpl::sync(); + } + + private: + int overflow( int c ) override { + sync(); + + if( c != EOF ) { + if( pbase() == epptr() ) + m_writer( std::string( 1, static_cast<char>( c ) ) ); + else + sputc( static_cast<char>( c ) ); + } + return 0; + } + + int sync() override { + if( pbase() != pptr() ) { + m_writer( std::string( pbase(), static_cast<std::string::size_type>( pptr() - pbase() ) ) ); + setp( pbase(), epptr() ); + } + return 0; + } + }; + + /////////////////////////////////////////////////////////////////////////// + + struct OutputDebugWriter { + + void operator()( std::string const&str ) { + writeToDebugConsole( str ); + } + }; + + /////////////////////////////////////////////////////////////////////////// + + class FileStream : public IStream { + mutable std::ofstream m_ofs; + public: + FileStream( StringRef filename ) { + m_ofs.open( filename.c_str() ); + CATCH_ENFORCE( !m_ofs.fail(), "Unable to open file: '" << filename << "'" ); + } + ~FileStream() override = default; + public: // IStream + std::ostream& stream() const override { + return m_ofs; + } + }; + + /////////////////////////////////////////////////////////////////////////// + + class CoutStream : public IStream { + mutable std::ostream m_os; + public: + // Store the streambuf from cout up-front because + // cout may get redirected when running tests + CoutStream() : m_os( Catch::cout().rdbuf() ) {} + ~CoutStream() override = default; + + public: // IStream + std::ostream& stream() const override { return m_os; } + }; + + /////////////////////////////////////////////////////////////////////////// + + class DebugOutStream : public IStream { + std::unique_ptr<StreamBufImpl<OutputDebugWriter>> m_streamBuf; + mutable std::ostream m_os; + public: + DebugOutStream() + : m_streamBuf( new StreamBufImpl<OutputDebugWriter>() ), + m_os( m_streamBuf.get() ) + {} + + ~DebugOutStream() override = default; + + public: // IStream + std::ostream& stream() const override { return m_os; } + }; + + }} // namespace anon::detail + + /////////////////////////////////////////////////////////////////////////// + + auto makeStream( StringRef const &filename ) -> IStream const* { + if( filename.empty() ) + return new detail::CoutStream(); + else if( filename[0] == '%' ) { + if( filename == "%debug" ) + return new detail::DebugOutStream(); + else + CATCH_ERROR( "Unrecognised stream: '" << filename << "'" ); + } + else + return new detail::FileStream( filename ); + } + + // This class encapsulates the idea of a pool of ostringstreams that can be reused. + struct StringStreams { + std::vector<std::unique_ptr<std::ostringstream>> m_streams; + std::vector<std::size_t> m_unused; + std::ostringstream m_referenceStream; // Used for copy state/ flags from + + auto add() -> std::size_t { + if( m_unused.empty() ) { + m_streams.push_back( std::unique_ptr<std::ostringstream>( new std::ostringstream ) ); + return m_streams.size()-1; + } + else { + auto index = m_unused.back(); + m_unused.pop_back(); + return index; + } + } + + void release( std::size_t index ) { + m_streams[index]->copyfmt( m_referenceStream ); // Restore initial flags and other state + m_unused.push_back(index); + } + }; + + ReusableStringStream::ReusableStringStream() + : m_index( Singleton<StringStreams>::getMutable().add() ), + m_oss( Singleton<StringStreams>::getMutable().m_streams[m_index].get() ) + {} + + ReusableStringStream::~ReusableStringStream() { + static_cast<std::ostringstream*>( m_oss )->str(""); + m_oss->clear(); + Singleton<StringStreams>::getMutable().release( m_index ); + } + + auto ReusableStringStream::str() const -> std::string { + return static_cast<std::ostringstream*>( m_oss )->str(); + } + + /////////////////////////////////////////////////////////////////////////// + +#ifndef CATCH_CONFIG_NOSTDOUT // If you #define this you must implement these functions + std::ostream& cout() { return std::cout; } + std::ostream& cerr() { return std::cerr; } + std::ostream& clog() { return std::clog; } +#endif +} +// end catch_stream.cpp +// start catch_string_manip.cpp + +#include <algorithm> +#include <ostream> +#include <cstring> +#include <cctype> + +namespace Catch { + + namespace { + char toLowerCh(char c) { + return static_cast<char>( std::tolower( c ) ); + } + } + + bool startsWith( std::string const& s, std::string const& prefix ) { + return s.size() >= prefix.size() && std::equal(prefix.begin(), prefix.end(), s.begin()); + } + bool startsWith( std::string const& s, char prefix ) { + return !s.empty() && s[0] == prefix; + } + bool endsWith( std::string const& s, std::string const& suffix ) { + return s.size() >= suffix.size() && std::equal(suffix.rbegin(), suffix.rend(), s.rbegin()); + } + bool endsWith( std::string const& s, char suffix ) { + return !s.empty() && s[s.size()-1] == suffix; + } + bool contains( std::string const& s, std::string const& infix ) { + return s.find( infix ) != std::string::npos; + } + void toLowerInPlace( std::string& s ) { + std::transform( s.begin(), s.end(), s.begin(), toLowerCh ); + } + std::string toLower( std::string const& s ) { + std::string lc = s; + toLowerInPlace( lc ); + return lc; + } + std::string trim( std::string const& str ) { + static char const* whitespaceChars = "\n\r\t "; + std::string::size_type start = str.find_first_not_of( whitespaceChars ); + std::string::size_type end = str.find_last_not_of( whitespaceChars ); + + return start != std::string::npos ? str.substr( start, 1+end-start ) : std::string(); + } + + bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis ) { + bool replaced = false; + std::size_t i = str.find( replaceThis ); + while( i != std::string::npos ) { + replaced = true; + str = str.substr( 0, i ) + withThis + str.substr( i+replaceThis.size() ); + if( i < str.size()-withThis.size() ) + i = str.find( replaceThis, i+withThis.size() ); + else + i = std::string::npos; + } + return replaced; + } + + pluralise::pluralise( std::size_t count, std::string const& label ) + : m_count( count ), + m_label( label ) + {} + + std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser ) { + os << pluraliser.m_count << ' ' << pluraliser.m_label; + if( pluraliser.m_count != 1 ) + os << 's'; + return os; + } + +} +// end catch_string_manip.cpp +// start catch_stringref.cpp + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +#endif + +#include <ostream> +#include <cstring> +#include <cstdint> + +namespace { + const uint32_t byte_2_lead = 0xC0; + const uint32_t byte_3_lead = 0xE0; + const uint32_t byte_4_lead = 0xF0; +} + +namespace Catch { + StringRef::StringRef( char const* rawChars ) noexcept + : StringRef( rawChars, static_cast<StringRef::size_type>(std::strlen(rawChars) ) ) + {} + + StringRef::operator std::string() const { + return std::string( m_start, m_size ); + } + + void StringRef::swap( StringRef& other ) noexcept { + std::swap( m_start, other.m_start ); + std::swap( m_size, other.m_size ); + std::swap( m_data, other.m_data ); + } + + auto StringRef::c_str() const -> char const* { + if( isSubstring() ) + const_cast<StringRef*>( this )->takeOwnership(); + return m_start; + } + auto StringRef::currentData() const noexcept -> char const* { + return m_start; + } + + auto StringRef::isOwned() const noexcept -> bool { + return m_data != nullptr; + } + auto StringRef::isSubstring() const noexcept -> bool { + return m_start[m_size] != '\0'; + } + + void StringRef::takeOwnership() { + if( !isOwned() ) { + m_data = new char[m_size+1]; + memcpy( m_data, m_start, m_size ); + m_data[m_size] = '\0'; + m_start = m_data; + } + } + auto StringRef::substr( size_type start, size_type size ) const noexcept -> StringRef { + if( start < m_size ) + return StringRef( m_start+start, size ); + else + return StringRef(); + } + auto StringRef::operator == ( StringRef const& other ) const noexcept -> bool { + return + size() == other.size() && + (std::strncmp( m_start, other.m_start, size() ) == 0); + } + auto StringRef::operator != ( StringRef const& other ) const noexcept -> bool { + return !operator==( other ); + } + + auto StringRef::operator[](size_type index) const noexcept -> char { + return m_start[index]; + } + + auto StringRef::numberOfCharacters() const noexcept -> size_type { + size_type noChars = m_size; + // Make adjustments for uft encodings + for( size_type i=0; i < m_size; ++i ) { + char c = m_start[i]; + if( ( c & byte_2_lead ) == byte_2_lead ) { + noChars--; + if (( c & byte_3_lead ) == byte_3_lead ) + noChars--; + if( ( c & byte_4_lead ) == byte_4_lead ) + noChars--; + } + } + return noChars; + } + + auto operator + ( StringRef const& lhs, StringRef const& rhs ) -> std::string { + std::string str; + str.reserve( lhs.size() + rhs.size() ); + str += lhs; + str += rhs; + return str; + } + auto operator + ( StringRef const& lhs, const char* rhs ) -> std::string { + return std::string( lhs ) + std::string( rhs ); + } + auto operator + ( char const* lhs, StringRef const& rhs ) -> std::string { + return std::string( lhs ) + std::string( rhs ); + } + + auto operator << ( std::ostream& os, StringRef const& str ) -> std::ostream& { + return os.write(str.currentData(), str.size()); + } + + auto operator+=( std::string& lhs, StringRef const& rhs ) -> std::string& { + lhs.append(rhs.currentData(), rhs.size()); + return lhs; + } + +} // namespace Catch + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif +// end catch_stringref.cpp +// start catch_tag_alias.cpp + +namespace Catch { + TagAlias::TagAlias(std::string const & _tag, SourceLineInfo _lineInfo): tag(_tag), lineInfo(_lineInfo) {} +} +// end catch_tag_alias.cpp +// start catch_tag_alias_autoregistrar.cpp + +namespace Catch { + + RegistrarForTagAliases::RegistrarForTagAliases(char const* alias, char const* tag, SourceLineInfo const& lineInfo) { + CATCH_TRY { + getMutableRegistryHub().registerTagAlias(alias, tag, lineInfo); + } CATCH_CATCH_ALL { + // Do not throw when constructing global objects, instead register the exception to be processed later + getMutableRegistryHub().registerStartupException(); + } + } + +} +// end catch_tag_alias_autoregistrar.cpp +// start catch_tag_alias_registry.cpp + +#include <sstream> + +namespace Catch { + + TagAliasRegistry::~TagAliasRegistry() {} + + TagAlias const* TagAliasRegistry::find( std::string const& alias ) const { + auto it = m_registry.find( alias ); + if( it != m_registry.end() ) + return &(it->second); + else + return nullptr; + } + + std::string TagAliasRegistry::expandAliases( std::string const& unexpandedTestSpec ) const { + std::string expandedTestSpec = unexpandedTestSpec; + for( auto const& registryKvp : m_registry ) { + std::size_t pos = expandedTestSpec.find( registryKvp.first ); + if( pos != std::string::npos ) { + expandedTestSpec = expandedTestSpec.substr( 0, pos ) + + registryKvp.second.tag + + expandedTestSpec.substr( pos + registryKvp.first.size() ); + } + } + return expandedTestSpec; + } + + void TagAliasRegistry::add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) { + CATCH_ENFORCE( startsWith(alias, "[@") && endsWith(alias, ']'), + "error: tag alias, '" << alias << "' is not of the form [@alias name].\n" << lineInfo ); + + CATCH_ENFORCE( m_registry.insert(std::make_pair(alias, TagAlias(tag, lineInfo))).second, + "error: tag alias, '" << alias << "' already registered.\n" + << "\tFirst seen at: " << find(alias)->lineInfo << "\n" + << "\tRedefined at: " << lineInfo ); + } + + ITagAliasRegistry::~ITagAliasRegistry() {} + + ITagAliasRegistry const& ITagAliasRegistry::get() { + return getRegistryHub().getTagAliasRegistry(); + } + +} // end namespace Catch +// end catch_tag_alias_registry.cpp +// start catch_test_case_info.cpp + +#include <cctype> +#include <exception> +#include <algorithm> +#include <sstream> + +namespace Catch { + + namespace { + TestCaseInfo::SpecialProperties parseSpecialTag( std::string const& tag ) { + if( startsWith( tag, '.' ) || + tag == "!hide" ) + return TestCaseInfo::IsHidden; + else if( tag == "!throws" ) + return TestCaseInfo::Throws; + else if( tag == "!shouldfail" ) + return TestCaseInfo::ShouldFail; + else if( tag == "!mayfail" ) + return TestCaseInfo::MayFail; + else if( tag == "!nonportable" ) + return TestCaseInfo::NonPortable; + else if( tag == "!benchmark" ) + return static_cast<TestCaseInfo::SpecialProperties>( TestCaseInfo::Benchmark | TestCaseInfo::IsHidden ); + else + return TestCaseInfo::None; + } + bool isReservedTag( std::string const& tag ) { + return parseSpecialTag( tag ) == TestCaseInfo::None && tag.size() > 0 && !std::isalnum( static_cast<unsigned char>(tag[0]) ); + } + void enforceNotReservedTag( std::string const& tag, SourceLineInfo const& _lineInfo ) { + CATCH_ENFORCE( !isReservedTag(tag), + "Tag name: [" << tag << "] is not allowed.\n" + << "Tag names starting with non alpha-numeric characters are reserved\n" + << _lineInfo ); + } + } + + TestCase makeTestCase( ITestInvoker* _testCase, + std::string const& _className, + NameAndTags const& nameAndTags, + SourceLineInfo const& _lineInfo ) + { + bool isHidden = false; + + // Parse out tags + std::vector<std::string> tags; + std::string desc, tag; + bool inTag = false; + std::string _descOrTags = nameAndTags.tags; + for (char c : _descOrTags) { + if( !inTag ) { + if( c == '[' ) + inTag = true; + else + desc += c; + } + else { + if( c == ']' ) { + TestCaseInfo::SpecialProperties prop = parseSpecialTag( tag ); + if( ( prop & TestCaseInfo::IsHidden ) != 0 ) + isHidden = true; + else if( prop == TestCaseInfo::None ) + enforceNotReservedTag( tag, _lineInfo ); + + tags.push_back( tag ); + tag.clear(); + inTag = false; + } + else + tag += c; + } + } + if( isHidden ) { + tags.push_back( "." ); + } + + TestCaseInfo info( nameAndTags.name, _className, desc, tags, _lineInfo ); + return TestCase( _testCase, std::move(info) ); + } + + void setTags( TestCaseInfo& testCaseInfo, std::vector<std::string> tags ) { + std::sort(begin(tags), end(tags)); + tags.erase(std::unique(begin(tags), end(tags)), end(tags)); + testCaseInfo.lcaseTags.clear(); + + for( auto const& tag : tags ) { + std::string lcaseTag = toLower( tag ); + testCaseInfo.properties = static_cast<TestCaseInfo::SpecialProperties>( testCaseInfo.properties | parseSpecialTag( lcaseTag ) ); + testCaseInfo.lcaseTags.push_back( lcaseTag ); + } + testCaseInfo.tags = std::move(tags); + } + + TestCaseInfo::TestCaseInfo( std::string const& _name, + std::string const& _className, + std::string const& _description, + std::vector<std::string> const& _tags, + SourceLineInfo const& _lineInfo ) + : name( _name ), + className( _className ), + description( _description ), + lineInfo( _lineInfo ), + properties( None ) + { + setTags( *this, _tags ); + } + + bool TestCaseInfo::isHidden() const { + return ( properties & IsHidden ) != 0; + } + bool TestCaseInfo::throws() const { + return ( properties & Throws ) != 0; + } + bool TestCaseInfo::okToFail() const { + return ( properties & (ShouldFail | MayFail ) ) != 0; + } + bool TestCaseInfo::expectedToFail() const { + return ( properties & (ShouldFail ) ) != 0; + } + + std::string TestCaseInfo::tagsAsString() const { + std::string ret; + // '[' and ']' per tag + std::size_t full_size = 2 * tags.size(); + for (const auto& tag : tags) { + full_size += tag.size(); + } + ret.reserve(full_size); + for (const auto& tag : tags) { + ret.push_back('['); + ret.append(tag); + ret.push_back(']'); + } + + return ret; + } + + TestCase::TestCase( ITestInvoker* testCase, TestCaseInfo&& info ) : TestCaseInfo( std::move(info) ), test( testCase ) {} + + TestCase TestCase::withName( std::string const& _newName ) const { + TestCase other( *this ); + other.name = _newName; + return other; + } + + void TestCase::invoke() const { + test->invoke(); + } + + bool TestCase::operator == ( TestCase const& other ) const { + return test.get() == other.test.get() && + name == other.name && + className == other.className; + } + + bool TestCase::operator < ( TestCase const& other ) const { + return name < other.name; + } + + TestCaseInfo const& TestCase::getTestCaseInfo() const + { + return *this; + } + +} // end namespace Catch +// end catch_test_case_info.cpp +// start catch_test_case_registry_impl.cpp + +#include <sstream> + +namespace Catch { + + std::vector<TestCase> sortTests( IConfig const& config, std::vector<TestCase> const& unsortedTestCases ) { + + std::vector<TestCase> sorted = unsortedTestCases; + + switch( config.runOrder() ) { + case RunTests::InLexicographicalOrder: + std::sort( sorted.begin(), sorted.end() ); + break; + case RunTests::InRandomOrder: + seedRng( config ); + std::shuffle( sorted.begin(), sorted.end(), rng() ); + break; + case RunTests::InDeclarationOrder: + // already in declaration order + break; + } + return sorted; + } + bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ) { + return testSpec.matches( testCase ) && ( config.allowThrows() || !testCase.throws() ); + } + + void enforceNoDuplicateTestCases( std::vector<TestCase> const& functions ) { + std::set<TestCase> seenFunctions; + for( auto const& function : functions ) { + auto prev = seenFunctions.insert( function ); + CATCH_ENFORCE( prev.second, + "error: TEST_CASE( \"" << function.name << "\" ) already defined.\n" + << "\tFirst seen at " << prev.first->getTestCaseInfo().lineInfo << "\n" + << "\tRedefined at " << function.getTestCaseInfo().lineInfo ); + } + } + + std::vector<TestCase> filterTests( std::vector<TestCase> const& testCases, TestSpec const& testSpec, IConfig const& config ) { + std::vector<TestCase> filtered; + filtered.reserve( testCases.size() ); + for( auto const& testCase : testCases ) + if( matchTest( testCase, testSpec, config ) ) + filtered.push_back( testCase ); + return filtered; + } + std::vector<TestCase> const& getAllTestCasesSorted( IConfig const& config ) { + return getRegistryHub().getTestCaseRegistry().getAllTestsSorted( config ); + } + + void TestRegistry::registerTest( TestCase const& testCase ) { + std::string name = testCase.getTestCaseInfo().name; + if( name.empty() ) { + ReusableStringStream rss; + rss << "Anonymous test case " << ++m_unnamedCount; + return registerTest( testCase.withName( rss.str() ) ); + } + m_functions.push_back( testCase ); + } + + std::vector<TestCase> const& TestRegistry::getAllTests() const { + return m_functions; + } + std::vector<TestCase> const& TestRegistry::getAllTestsSorted( IConfig const& config ) const { + if( m_sortedFunctions.empty() ) + enforceNoDuplicateTestCases( m_functions ); + + if( m_currentSortOrder != config.runOrder() || m_sortedFunctions.empty() ) { + m_sortedFunctions = sortTests( config, m_functions ); + m_currentSortOrder = config.runOrder(); + } + return m_sortedFunctions; + } + + /////////////////////////////////////////////////////////////////////////// + TestInvokerAsFunction::TestInvokerAsFunction( void(*testAsFunction)() ) noexcept : m_testAsFunction( testAsFunction ) {} + + void TestInvokerAsFunction::invoke() const { + m_testAsFunction(); + } + + std::string extractClassName( StringRef const& classOrQualifiedMethodName ) { + std::string className = classOrQualifiedMethodName; + if( startsWith( className, '&' ) ) + { + std::size_t lastColons = className.rfind( "::" ); + std::size_t penultimateColons = className.rfind( "::", lastColons-1 ); + if( penultimateColons == std::string::npos ) + penultimateColons = 1; + className = className.substr( penultimateColons, lastColons-penultimateColons ); + } + return className; + } + +} // end namespace Catch +// end catch_test_case_registry_impl.cpp +// start catch_test_case_tracker.cpp + +#include <algorithm> +#include <cassert> +#include <stdexcept> +#include <memory> +#include <sstream> + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +#endif + +namespace Catch { +namespace TestCaseTracking { + + NameAndLocation::NameAndLocation( std::string const& _name, SourceLineInfo const& _location ) + : name( _name ), + location( _location ) + {} + + ITracker::~ITracker() = default; + + TrackerContext& TrackerContext::instance() { + static TrackerContext s_instance; + return s_instance; + } + + ITracker& TrackerContext::startRun() { + m_rootTracker = std::make_shared<SectionTracker>( NameAndLocation( "{root}", CATCH_INTERNAL_LINEINFO ), *this, nullptr ); + m_currentTracker = nullptr; + m_runState = Executing; + return *m_rootTracker; + } + + void TrackerContext::endRun() { + m_rootTracker.reset(); + m_currentTracker = nullptr; + m_runState = NotStarted; + } + + void TrackerContext::startCycle() { + m_currentTracker = m_rootTracker.get(); + m_runState = Executing; + } + void TrackerContext::completeCycle() { + m_runState = CompletedCycle; + } + + bool TrackerContext::completedCycle() const { + return m_runState == CompletedCycle; + } + ITracker& TrackerContext::currentTracker() { + return *m_currentTracker; + } + void TrackerContext::setCurrentTracker( ITracker* tracker ) { + m_currentTracker = tracker; + } + + TrackerBase::TrackerBase( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ) + : m_nameAndLocation( nameAndLocation ), + m_ctx( ctx ), + m_parent( parent ) + {} + + NameAndLocation const& TrackerBase::nameAndLocation() const { + return m_nameAndLocation; + } + bool TrackerBase::isComplete() const { + return m_runState == CompletedSuccessfully || m_runState == Failed; + } + bool TrackerBase::isSuccessfullyCompleted() const { + return m_runState == CompletedSuccessfully; + } + bool TrackerBase::isOpen() const { + return m_runState != NotStarted && !isComplete(); + } + bool TrackerBase::hasChildren() const { + return !m_children.empty(); + } + + void TrackerBase::addChild( ITrackerPtr const& child ) { + m_children.push_back( child ); + } + + ITrackerPtr TrackerBase::findChild( NameAndLocation const& nameAndLocation ) { + auto it = std::find_if( m_children.begin(), m_children.end(), + [&nameAndLocation]( ITrackerPtr const& tracker ){ + return + tracker->nameAndLocation().location == nameAndLocation.location && + tracker->nameAndLocation().name == nameAndLocation.name; + } ); + return( it != m_children.end() ) + ? *it + : nullptr; + } + ITracker& TrackerBase::parent() { + assert( m_parent ); // Should always be non-null except for root + return *m_parent; + } + + void TrackerBase::openChild() { + if( m_runState != ExecutingChildren ) { + m_runState = ExecutingChildren; + if( m_parent ) + m_parent->openChild(); + } + } + + bool TrackerBase::isSectionTracker() const { return false; } + bool TrackerBase::isGeneratorTracker() const { return false; } + + void TrackerBase::open() { + m_runState = Executing; + moveToThis(); + if( m_parent ) + m_parent->openChild(); + } + + void TrackerBase::close() { + + // Close any still open children (e.g. generators) + while( &m_ctx.currentTracker() != this ) + m_ctx.currentTracker().close(); + + switch( m_runState ) { + case NeedsAnotherRun: + break; + + case Executing: + m_runState = CompletedSuccessfully; + break; + case ExecutingChildren: + if( m_children.empty() || m_children.back()->isComplete() ) + m_runState = CompletedSuccessfully; + break; + + case NotStarted: + case CompletedSuccessfully: + case Failed: + CATCH_INTERNAL_ERROR( "Illogical state: " << m_runState ); + + default: + CATCH_INTERNAL_ERROR( "Unknown state: " << m_runState ); + } + moveToParent(); + m_ctx.completeCycle(); + } + void TrackerBase::fail() { + m_runState = Failed; + if( m_parent ) + m_parent->markAsNeedingAnotherRun(); + moveToParent(); + m_ctx.completeCycle(); + } + void TrackerBase::markAsNeedingAnotherRun() { + m_runState = NeedsAnotherRun; + } + + void TrackerBase::moveToParent() { + assert( m_parent ); + m_ctx.setCurrentTracker( m_parent ); + } + void TrackerBase::moveToThis() { + m_ctx.setCurrentTracker( this ); + } + + SectionTracker::SectionTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ) + : TrackerBase( nameAndLocation, ctx, parent ) + { + if( parent ) { + while( !parent->isSectionTracker() ) + parent = &parent->parent(); + + SectionTracker& parentSection = static_cast<SectionTracker&>( *parent ); + addNextFilters( parentSection.m_filters ); + } + } + + bool SectionTracker::isComplete() const { + bool complete = true; + + if ((m_filters.empty() || m_filters[0] == "") || + std::find(m_filters.begin(), m_filters.end(), + m_nameAndLocation.name) != m_filters.end()) + complete = TrackerBase::isComplete(); + return complete; + + } + + bool SectionTracker::isSectionTracker() const { return true; } + + SectionTracker& SectionTracker::acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation ) { + std::shared_ptr<SectionTracker> section; + + ITracker& currentTracker = ctx.currentTracker(); + if( ITrackerPtr childTracker = currentTracker.findChild( nameAndLocation ) ) { + assert( childTracker ); + assert( childTracker->isSectionTracker() ); + section = std::static_pointer_cast<SectionTracker>( childTracker ); + } + else { + section = std::make_shared<SectionTracker>( nameAndLocation, ctx, ¤tTracker ); + currentTracker.addChild( section ); + } + if( !ctx.completedCycle() ) + section->tryOpen(); + return *section; + } + + void SectionTracker::tryOpen() { + if( !isComplete() && (m_filters.empty() || m_filters[0].empty() || m_filters[0] == m_nameAndLocation.name ) ) + open(); + } + + void SectionTracker::addInitialFilters( std::vector<std::string> const& filters ) { + if( !filters.empty() ) { + m_filters.push_back(""); // Root - should never be consulted + m_filters.push_back(""); // Test Case - not a section filter + m_filters.insert( m_filters.end(), filters.begin(), filters.end() ); + } + } + void SectionTracker::addNextFilters( std::vector<std::string> const& filters ) { + if( filters.size() > 1 ) + m_filters.insert( m_filters.end(), ++filters.begin(), filters.end() ); + } + +} // namespace TestCaseTracking + +using TestCaseTracking::ITracker; +using TestCaseTracking::TrackerContext; +using TestCaseTracking::SectionTracker; + +} // namespace Catch + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif +// end catch_test_case_tracker.cpp +// start catch_test_registry.cpp + +namespace Catch { + + auto makeTestInvoker( void(*testAsFunction)() ) noexcept -> ITestInvoker* { + return new(std::nothrow) TestInvokerAsFunction( testAsFunction ); + } + + NameAndTags::NameAndTags( StringRef const& name_ , StringRef const& tags_ ) noexcept : name( name_ ), tags( tags_ ) {} + + AutoReg::AutoReg( ITestInvoker* invoker, SourceLineInfo const& lineInfo, StringRef const& classOrMethod, NameAndTags const& nameAndTags ) noexcept { + CATCH_TRY { + getMutableRegistryHub() + .registerTest( + makeTestCase( + invoker, + extractClassName( classOrMethod ), + nameAndTags, + lineInfo)); + } CATCH_CATCH_ALL { + // Do not throw when constructing global objects, instead register the exception to be processed later + getMutableRegistryHub().registerStartupException(); + } + } + + AutoReg::~AutoReg() = default; +} +// end catch_test_registry.cpp +// start catch_test_spec.cpp + +#include <algorithm> +#include <string> +#include <vector> +#include <memory> + +namespace Catch { + + TestSpec::Pattern::~Pattern() = default; + TestSpec::NamePattern::~NamePattern() = default; + TestSpec::TagPattern::~TagPattern() = default; + TestSpec::ExcludedPattern::~ExcludedPattern() = default; + + TestSpec::NamePattern::NamePattern( std::string const& name ) + : m_wildcardPattern( toLower( name ), CaseSensitive::No ) + {} + bool TestSpec::NamePattern::matches( TestCaseInfo const& testCase ) const { + return m_wildcardPattern.matches( toLower( testCase.name ) ); + } + + TestSpec::TagPattern::TagPattern( std::string const& tag ) : m_tag( toLower( tag ) ) {} + bool TestSpec::TagPattern::matches( TestCaseInfo const& testCase ) const { + return std::find(begin(testCase.lcaseTags), + end(testCase.lcaseTags), + m_tag) != end(testCase.lcaseTags); + } + + TestSpec::ExcludedPattern::ExcludedPattern( PatternPtr const& underlyingPattern ) : m_underlyingPattern( underlyingPattern ) {} + bool TestSpec::ExcludedPattern::matches( TestCaseInfo const& testCase ) const { return !m_underlyingPattern->matches( testCase ); } + + bool TestSpec::Filter::matches( TestCaseInfo const& testCase ) const { + // All patterns in a filter must match for the filter to be a match + for( auto const& pattern : m_patterns ) { + if( !pattern->matches( testCase ) ) + return false; + } + return true; + } + + bool TestSpec::hasFilters() const { + return !m_filters.empty(); + } + bool TestSpec::matches( TestCaseInfo const& testCase ) const { + // A TestSpec matches if any filter matches + for( auto const& filter : m_filters ) + if( filter.matches( testCase ) ) + return true; + return false; + } +} +// end catch_test_spec.cpp +// start catch_test_spec_parser.cpp + +namespace Catch { + + TestSpecParser::TestSpecParser( ITagAliasRegistry const& tagAliases ) : m_tagAliases( &tagAliases ) {} + + TestSpecParser& TestSpecParser::parse( std::string const& arg ) { + m_mode = None; + m_exclusion = false; + m_start = std::string::npos; + m_arg = m_tagAliases->expandAliases( arg ); + m_escapeChars.clear(); + for( m_pos = 0; m_pos < m_arg.size(); ++m_pos ) + visitChar( m_arg[m_pos] ); + if( m_mode == Name ) + addPattern<TestSpec::NamePattern>(); + return *this; + } + TestSpec TestSpecParser::testSpec() { + addFilter(); + return m_testSpec; + } + + void TestSpecParser::visitChar( char c ) { + if( m_mode == None ) { + switch( c ) { + case ' ': return; + case '~': m_exclusion = true; return; + case '[': return startNewMode( Tag, ++m_pos ); + case '"': return startNewMode( QuotedName, ++m_pos ); + case '\\': return escape(); + default: startNewMode( Name, m_pos ); break; + } + } + if( m_mode == Name ) { + if( c == ',' ) { + addPattern<TestSpec::NamePattern>(); + addFilter(); + } + else if( c == '[' ) { + if( subString() == "exclude:" ) + m_exclusion = true; + else + addPattern<TestSpec::NamePattern>(); + startNewMode( Tag, ++m_pos ); + } + else if( c == '\\' ) + escape(); + } + else if( m_mode == EscapedName ) + m_mode = Name; + else if( m_mode == QuotedName && c == '"' ) + addPattern<TestSpec::NamePattern>(); + else if( m_mode == Tag && c == ']' ) + addPattern<TestSpec::TagPattern>(); + } + void TestSpecParser::startNewMode( Mode mode, std::size_t start ) { + m_mode = mode; + m_start = start; + } + void TestSpecParser::escape() { + if( m_mode == None ) + m_start = m_pos; + m_mode = EscapedName; + m_escapeChars.push_back( m_pos ); + } + std::string TestSpecParser::subString() const { return m_arg.substr( m_start, m_pos - m_start ); } + + void TestSpecParser::addFilter() { + if( !m_currentFilter.m_patterns.empty() ) { + m_testSpec.m_filters.push_back( m_currentFilter ); + m_currentFilter = TestSpec::Filter(); + } + } + + TestSpec parseTestSpec( std::string const& arg ) { + return TestSpecParser( ITagAliasRegistry::get() ).parse( arg ).testSpec(); + } + +} // namespace Catch +// end catch_test_spec_parser.cpp +// start catch_timer.cpp + +#include <chrono> + +static const uint64_t nanosecondsInSecond = 1000000000; + +namespace Catch { + + auto getCurrentNanosecondsSinceEpoch() -> uint64_t { + return std::chrono::duration_cast<std::chrono::nanoseconds>( std::chrono::high_resolution_clock::now().time_since_epoch() ).count(); + } + + namespace { + auto estimateClockResolution() -> uint64_t { + uint64_t sum = 0; + static const uint64_t iterations = 1000000; + + auto startTime = getCurrentNanosecondsSinceEpoch(); + + for( std::size_t i = 0; i < iterations; ++i ) { + + uint64_t ticks; + uint64_t baseTicks = getCurrentNanosecondsSinceEpoch(); + do { + ticks = getCurrentNanosecondsSinceEpoch(); + } while( ticks == baseTicks ); + + auto delta = ticks - baseTicks; + sum += delta; + + // If we have been calibrating for over 3 seconds -- the clock + // is terrible and we should move on. + // TBD: How to signal that the measured resolution is probably wrong? + if (ticks > startTime + 3 * nanosecondsInSecond) { + return sum / ( i + 1u ); + } + } + + // We're just taking the mean, here. To do better we could take the std. dev and exclude outliers + // - and potentially do more iterations if there's a high variance. + return sum/iterations; + } + } + auto getEstimatedClockResolution() -> uint64_t { + static auto s_resolution = estimateClockResolution(); + return s_resolution; + } + + void Timer::start() { + m_nanoseconds = getCurrentNanosecondsSinceEpoch(); + } + auto Timer::getElapsedNanoseconds() const -> uint64_t { + return getCurrentNanosecondsSinceEpoch() - m_nanoseconds; + } + auto Timer::getElapsedMicroseconds() const -> uint64_t { + return getElapsedNanoseconds()/1000; + } + auto Timer::getElapsedMilliseconds() const -> unsigned int { + return static_cast<unsigned int>(getElapsedMicroseconds()/1000); + } + auto Timer::getElapsedSeconds() const -> double { + return getElapsedMicroseconds()/1000000.0; + } + +} // namespace Catch +// end catch_timer.cpp +// start catch_tostring.cpp + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +# pragma clang diagnostic ignored "-Wglobal-constructors" +#endif + +// Enable specific decls locally +#if !defined(CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER) +#define CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER +#endif + +#include <cmath> +#include <iomanip> + +namespace Catch { + +namespace Detail { + + const std::string unprintableString = "{?}"; + + namespace { + const int hexThreshold = 255; + + struct Endianness { + enum Arch { Big, Little }; + + static Arch which() { + union _{ + int asInt; + char asChar[sizeof (int)]; + } u; + + u.asInt = 1; + return ( u.asChar[sizeof(int)-1] == 1 ) ? Big : Little; + } + }; + } + + std::string rawMemoryToString( const void *object, std::size_t size ) { + // Reverse order for little endian architectures + int i = 0, end = static_cast<int>( size ), inc = 1; + if( Endianness::which() == Endianness::Little ) { + i = end-1; + end = inc = -1; + } + + unsigned char const *bytes = static_cast<unsigned char const *>(object); + ReusableStringStream rss; + rss << "0x" << std::setfill('0') << std::hex; + for( ; i != end; i += inc ) + rss << std::setw(2) << static_cast<unsigned>(bytes[i]); + return rss.str(); + } +} + +template<typename T> +std::string fpToString( T value, int precision ) { + if (Catch::isnan(value)) { + return "nan"; + } + + ReusableStringStream rss; + rss << std::setprecision( precision ) + << std::fixed + << value; + std::string d = rss.str(); + std::size_t i = d.find_last_not_of( '0' ); + if( i != std::string::npos && i != d.size()-1 ) { + if( d[i] == '.' ) + i++; + d = d.substr( 0, i+1 ); + } + return d; +} + +//// ======================================================= //// +// +// Out-of-line defs for full specialization of StringMaker +// +//// ======================================================= //// + +std::string StringMaker<std::string>::convert(const std::string& str) { + if (!getCurrentContext().getConfig()->showInvisibles()) { + return '"' + str + '"'; + } + + std::string s("\""); + for (char c : str) { + switch (c) { + case '\n': + s.append("\\n"); + break; + case '\t': + s.append("\\t"); + break; + default: + s.push_back(c); + break; + } + } + s.append("\""); + return s; +} + +#ifdef CATCH_CONFIG_CPP17_STRING_VIEW +std::string StringMaker<std::string_view>::convert(std::string_view str) { + return ::Catch::Detail::stringify(std::string{ str }); +} +#endif + +std::string StringMaker<char const*>::convert(char const* str) { + if (str) { + return ::Catch::Detail::stringify(std::string{ str }); + } else { + return{ "{null string}" }; + } +} +std::string StringMaker<char*>::convert(char* str) { + if (str) { + return ::Catch::Detail::stringify(std::string{ str }); + } else { + return{ "{null string}" }; + } +} + +#ifdef CATCH_CONFIG_WCHAR +std::string StringMaker<std::wstring>::convert(const std::wstring& wstr) { + std::string s; + s.reserve(wstr.size()); + for (auto c : wstr) { + s += (c <= 0xff) ? static_cast<char>(c) : '?'; + } + return ::Catch::Detail::stringify(s); +} + +# ifdef CATCH_CONFIG_CPP17_STRING_VIEW +std::string StringMaker<std::wstring_view>::convert(std::wstring_view str) { + return StringMaker<std::wstring>::convert(std::wstring(str)); +} +# endif + +std::string StringMaker<wchar_t const*>::convert(wchar_t const * str) { + if (str) { + return ::Catch::Detail::stringify(std::wstring{ str }); + } else { + return{ "{null string}" }; + } +} +std::string StringMaker<wchar_t *>::convert(wchar_t * str) { + if (str) { + return ::Catch::Detail::stringify(std::wstring{ str }); + } else { + return{ "{null string}" }; + } +} +#endif + +std::string StringMaker<int>::convert(int value) { + return ::Catch::Detail::stringify(static_cast<long long>(value)); +} +std::string StringMaker<long>::convert(long value) { + return ::Catch::Detail::stringify(static_cast<long long>(value)); +} +std::string StringMaker<long long>::convert(long long value) { + ReusableStringStream rss; + rss << value; + if (value > Detail::hexThreshold) { + rss << " (0x" << std::hex << value << ')'; + } + return rss.str(); +} + +std::string StringMaker<unsigned int>::convert(unsigned int value) { + return ::Catch::Detail::stringify(static_cast<unsigned long long>(value)); +} +std::string StringMaker<unsigned long>::convert(unsigned long value) { + return ::Catch::Detail::stringify(static_cast<unsigned long long>(value)); +} +std::string StringMaker<unsigned long long>::convert(unsigned long long value) { + ReusableStringStream rss; + rss << value; + if (value > Detail::hexThreshold) { + rss << " (0x" << std::hex << value << ')'; + } + return rss.str(); +} + +std::string StringMaker<bool>::convert(bool b) { + return b ? "true" : "false"; +} + +std::string StringMaker<signed char>::convert(signed char value) { + if (value == '\r') { + return "'\\r'"; + } else if (value == '\f') { + return "'\\f'"; + } else if (value == '\n') { + return "'\\n'"; + } else if (value == '\t') { + return "'\\t'"; + } else if ('\0' <= value && value < ' ') { + return ::Catch::Detail::stringify(static_cast<unsigned int>(value)); + } else { + char chstr[] = "' '"; + chstr[1] = value; + return chstr; + } +} +std::string StringMaker<char>::convert(char c) { + return ::Catch::Detail::stringify(static_cast<signed char>(c)); +} +std::string StringMaker<unsigned char>::convert(unsigned char c) { + return ::Catch::Detail::stringify(static_cast<char>(c)); +} + +std::string StringMaker<std::nullptr_t>::convert(std::nullptr_t) { + return "nullptr"; +} + +std::string StringMaker<float>::convert(float value) { + return fpToString(value, 5) + 'f'; +} +std::string StringMaker<double>::convert(double value) { + return fpToString(value, 10); +} + +std::string ratio_string<std::atto>::symbol() { return "a"; } +std::string ratio_string<std::femto>::symbol() { return "f"; } +std::string ratio_string<std::pico>::symbol() { return "p"; } +std::string ratio_string<std::nano>::symbol() { return "n"; } +std::string ratio_string<std::micro>::symbol() { return "u"; } +std::string ratio_string<std::milli>::symbol() { return "m"; } + +} // end namespace Catch + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif + +// end catch_tostring.cpp +// start catch_totals.cpp + +namespace Catch { + + Counts Counts::operator - ( Counts const& other ) const { + Counts diff; + diff.passed = passed - other.passed; + diff.failed = failed - other.failed; + diff.failedButOk = failedButOk - other.failedButOk; + return diff; + } + + Counts& Counts::operator += ( Counts const& other ) { + passed += other.passed; + failed += other.failed; + failedButOk += other.failedButOk; + return *this; + } + + std::size_t Counts::total() const { + return passed + failed + failedButOk; + } + bool Counts::allPassed() const { + return failed == 0 && failedButOk == 0; + } + bool Counts::allOk() const { + return failed == 0; + } + + Totals Totals::operator - ( Totals const& other ) const { + Totals diff; + diff.assertions = assertions - other.assertions; + diff.testCases = testCases - other.testCases; + return diff; + } + + Totals& Totals::operator += ( Totals const& other ) { + assertions += other.assertions; + testCases += other.testCases; + return *this; + } + + Totals Totals::delta( Totals const& prevTotals ) const { + Totals diff = *this - prevTotals; + if( diff.assertions.failed > 0 ) + ++diff.testCases.failed; + else if( diff.assertions.failedButOk > 0 ) + ++diff.testCases.failedButOk; + else + ++diff.testCases.passed; + return diff; + } + +} +// end catch_totals.cpp +// start catch_uncaught_exceptions.cpp + +#include <exception> + +namespace Catch { + bool uncaught_exceptions() { +#if defined(CATCH_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS) + return std::uncaught_exceptions() > 0; +#else + return std::uncaught_exception(); +#endif + } +} // end namespace Catch +// end catch_uncaught_exceptions.cpp +// start catch_version.cpp + +#include <ostream> + +namespace Catch { + + Version::Version + ( unsigned int _majorVersion, + unsigned int _minorVersion, + unsigned int _patchNumber, + char const * const _branchName, + unsigned int _buildNumber ) + : majorVersion( _majorVersion ), + minorVersion( _minorVersion ), + patchNumber( _patchNumber ), + branchName( _branchName ), + buildNumber( _buildNumber ) + {} + + std::ostream& operator << ( std::ostream& os, Version const& version ) { + os << version.majorVersion << '.' + << version.minorVersion << '.' + << version.patchNumber; + // branchName is never null -> 0th char is \0 if it is empty + if (version.branchName[0]) { + os << '-' << version.branchName + << '.' << version.buildNumber; + } + return os; + } + + Version const& libraryVersion() { + static Version version( 2, 7, 0, "", 0 ); + return version; + } + +} +// end catch_version.cpp +// start catch_wildcard_pattern.cpp + +#include <sstream> + +namespace Catch { + + WildcardPattern::WildcardPattern( std::string const& pattern, + CaseSensitive::Choice caseSensitivity ) + : m_caseSensitivity( caseSensitivity ), + m_pattern( adjustCase( pattern ) ) + { + if( startsWith( m_pattern, '*' ) ) { + m_pattern = m_pattern.substr( 1 ); + m_wildcard = WildcardAtStart; + } + if( endsWith( m_pattern, '*' ) ) { + m_pattern = m_pattern.substr( 0, m_pattern.size()-1 ); + m_wildcard = static_cast<WildcardPosition>( m_wildcard | WildcardAtEnd ); + } + } + + bool WildcardPattern::matches( std::string const& str ) const { + switch( m_wildcard ) { + case NoWildcard: + return m_pattern == adjustCase( str ); + case WildcardAtStart: + return endsWith( adjustCase( str ), m_pattern ); + case WildcardAtEnd: + return startsWith( adjustCase( str ), m_pattern ); + case WildcardAtBothEnds: + return contains( adjustCase( str ), m_pattern ); + default: + CATCH_INTERNAL_ERROR( "Unknown enum" ); + } + } + + std::string WildcardPattern::adjustCase( std::string const& str ) const { + return m_caseSensitivity == CaseSensitive::No ? toLower( str ) : str; + } +} +// end catch_wildcard_pattern.cpp +// start catch_xmlwriter.cpp + +#include <iomanip> + +using uchar = unsigned char; + +namespace Catch { + +namespace { + + size_t trailingBytes(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return 2; + } + if ((c & 0xF0) == 0xE0) { + return 3; + } + if ((c & 0xF8) == 0xF0) { + return 4; + } + CATCH_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + uint32_t headerValue(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return c & 0x1F; + } + if ((c & 0xF0) == 0xE0) { + return c & 0x0F; + } + if ((c & 0xF8) == 0xF0) { + return c & 0x07; + } + CATCH_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + void hexEscapeChar(std::ostream& os, unsigned char c) { + std::ios_base::fmtflags f(os.flags()); + os << "\\x" + << std::uppercase << std::hex << std::setfill('0') << std::setw(2) + << static_cast<int>(c); + os.flags(f); + } + +} // anonymous namespace + + XmlEncode::XmlEncode( std::string const& str, ForWhat forWhat ) + : m_str( str ), + m_forWhat( forWhat ) + {} + + void XmlEncode::encodeTo( std::ostream& os ) const { + // Apostrophe escaping not necessary if we always use " to write attributes + // (see: http://www.w3.org/TR/xml/#syntax) + + for( std::size_t idx = 0; idx < m_str.size(); ++ idx ) { + uchar c = m_str[idx]; + switch (c) { + case '<': os << "<"; break; + case '&': os << "&"; break; + + case '>': + // See: http://www.w3.org/TR/xml/#syntax + if (idx > 2 && m_str[idx - 1] == ']' && m_str[idx - 2] == ']') + os << ">"; + else + os << c; + break; + + case '\"': + if (m_forWhat == ForAttributes) + os << """; + else + os << c; + break; + + default: + // Check for control characters and invalid utf-8 + + // Escape control characters in standard ascii + // see http://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0 + if (c < 0x09 || (c > 0x0D && c < 0x20) || c == 0x7F) { + hexEscapeChar(os, c); + break; + } + + // Plain ASCII: Write it to stream + if (c < 0x7F) { + os << c; + break; + } + + // UTF-8 territory + // Check if the encoding is valid and if it is not, hex escape bytes. + // Important: We do not check the exact decoded values for validity, only the encoding format + // First check that this bytes is a valid lead byte: + // This means that it is not encoded as 1111 1XXX + // Or as 10XX XXXX + if (c < 0xC0 || + c >= 0xF8) { + hexEscapeChar(os, c); + break; + } + + auto encBytes = trailingBytes(c); + // Are there enough bytes left to avoid accessing out-of-bounds memory? + if (idx + encBytes - 1 >= m_str.size()) { + hexEscapeChar(os, c); + break; + } + // The header is valid, check data + // The next encBytes bytes must together be a valid utf-8 + // This means: bitpattern 10XX XXXX and the extracted value is sane (ish) + bool valid = true; + uint32_t value = headerValue(c); + for (std::size_t n = 1; n < encBytes; ++n) { + uchar nc = m_str[idx + n]; + valid &= ((nc & 0xC0) == 0x80); + value = (value << 6) | (nc & 0x3F); + } + + if ( + // Wrong bit pattern of following bytes + (!valid) || + // Overlong encodings + (value < 0x80) || + (0x80 <= value && value < 0x800 && encBytes > 2) || + (0x800 < value && value < 0x10000 && encBytes > 3) || + // Encoded value out of range + (value >= 0x110000) + ) { + hexEscapeChar(os, c); + break; + } + + // If we got here, this is in fact a valid(ish) utf-8 sequence + for (std::size_t n = 0; n < encBytes; ++n) { + os << m_str[idx + n]; + } + idx += encBytes - 1; + break; + } + } + } + + std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) { + xmlEncode.encodeTo( os ); + return os; + } + + XmlWriter::ScopedElement::ScopedElement( XmlWriter* writer ) + : m_writer( writer ) + {} + + XmlWriter::ScopedElement::ScopedElement( ScopedElement&& other ) noexcept + : m_writer( other.m_writer ){ + other.m_writer = nullptr; + } + XmlWriter::ScopedElement& XmlWriter::ScopedElement::operator=( ScopedElement&& other ) noexcept { + if ( m_writer ) { + m_writer->endElement(); + } + m_writer = other.m_writer; + other.m_writer = nullptr; + return *this; + } + + XmlWriter::ScopedElement::~ScopedElement() { + if( m_writer ) + m_writer->endElement(); + } + + XmlWriter::ScopedElement& XmlWriter::ScopedElement::writeText( std::string const& text, bool indent ) { + m_writer->writeText( text, indent ); + return *this; + } + + XmlWriter::XmlWriter( std::ostream& os ) : m_os( os ) + { + writeDeclaration(); + } + + XmlWriter::~XmlWriter() { + while( !m_tags.empty() ) + endElement(); + } + + XmlWriter& XmlWriter::startElement( std::string const& name ) { + ensureTagClosed(); + newlineIfNecessary(); + m_os << m_indent << '<' << name; + m_tags.push_back( name ); + m_indent += " "; + m_tagIsOpen = true; + return *this; + } + + XmlWriter::ScopedElement XmlWriter::scopedElement( std::string const& name ) { + ScopedElement scoped( this ); + startElement( name ); + return scoped; + } + + XmlWriter& XmlWriter::endElement() { + newlineIfNecessary(); + m_indent = m_indent.substr( 0, m_indent.size()-2 ); + if( m_tagIsOpen ) { + m_os << "/>"; + m_tagIsOpen = false; + } + else { + m_os << m_indent << "</" << m_tags.back() << ">"; + } + m_os << std::endl; + m_tags.pop_back(); + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, std::string const& attribute ) { + if( !name.empty() && !attribute.empty() ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, bool attribute ) { + m_os << ' ' << name << "=\"" << ( attribute ? "true" : "false" ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeText( std::string const& text, bool indent ) { + if( !text.empty() ){ + bool tagWasOpen = m_tagIsOpen; + ensureTagClosed(); + if( tagWasOpen && indent ) + m_os << m_indent; + m_os << XmlEncode( text ); + m_needsNewline = true; + } + return *this; + } + + XmlWriter& XmlWriter::writeComment( std::string const& text ) { + ensureTagClosed(); + m_os << m_indent << "<!--" << text << "-->"; + m_needsNewline = true; + return *this; + } + + void XmlWriter::writeStylesheetRef( std::string const& url ) { + m_os << "<?xml-stylesheet type=\"text/xsl\" href=\"" << url << "\"?>\n"; + } + + XmlWriter& XmlWriter::writeBlankLine() { + ensureTagClosed(); + m_os << '\n'; + return *this; + } + + void XmlWriter::ensureTagClosed() { + if( m_tagIsOpen ) { + m_os << ">" << std::endl; + m_tagIsOpen = false; + } + } + + void XmlWriter::writeDeclaration() { + m_os << "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"; + } + + void XmlWriter::newlineIfNecessary() { + if( m_needsNewline ) { + m_os << std::endl; + m_needsNewline = false; + } + } +} +// end catch_xmlwriter.cpp +// start catch_reporter_bases.cpp + +#include <cstring> +#include <cfloat> +#include <cstdio> +#include <cassert> +#include <memory> + +namespace Catch { + void prepareExpandedExpression(AssertionResult& result) { + result.getExpandedExpression(); + } + + // Because formatting using c++ streams is stateful, drop down to C is required + // Alternatively we could use stringstream, but its performance is... not good. + std::string getFormattedDuration( double duration ) { + // Max exponent + 1 is required to represent the whole part + // + 1 for decimal point + // + 3 for the 3 decimal places + // + 1 for null terminator + const std::size_t maxDoubleSize = DBL_MAX_10_EXP + 1 + 1 + 3 + 1; + char buffer[maxDoubleSize]; + + // Save previous errno, to prevent sprintf from overwriting it + ErrnoGuard guard; +#ifdef _MSC_VER + sprintf_s(buffer, "%.3f", duration); +#else + std::sprintf(buffer, "%.3f", duration); +#endif + return std::string(buffer); + } + + TestEventListenerBase::TestEventListenerBase(ReporterConfig const & _config) + :StreamingReporterBase(_config) {} + + std::set<Verbosity> TestEventListenerBase::getSupportedVerbosities() { + return { Verbosity::Quiet, Verbosity::Normal, Verbosity::High }; + } + + void TestEventListenerBase::assertionStarting(AssertionInfo const &) {} + + bool TestEventListenerBase::assertionEnded(AssertionStats const &) { + return false; + } + +} // end namespace Catch +// end catch_reporter_bases.cpp +// start catch_reporter_compact.cpp + +namespace { + +#ifdef CATCH_PLATFORM_MAC + const char* failedString() { return "FAILED"; } + const char* passedString() { return "PASSED"; } +#else + const char* failedString() { return "failed"; } + const char* passedString() { return "passed"; } +#endif + + // Colour::LightGrey + Catch::Colour::Code dimColour() { return Catch::Colour::FileName; } + + std::string bothOrAll( std::size_t count ) { + return count == 1 ? std::string() : + count == 2 ? "both " : "all " ; + } + +} // anon namespace + +namespace Catch { +namespace { +// Colour, message variants: +// - white: No tests ran. +// - red: Failed [both/all] N test cases, failed [both/all] M assertions. +// - white: Passed [both/all] N test cases (no assertions). +// - red: Failed N tests cases, failed M assertions. +// - green: Passed [both/all] N tests cases with M assertions. +void printTotals(std::ostream& out, const Totals& totals) { + if (totals.testCases.total() == 0) { + out << "No tests ran."; + } else if (totals.testCases.failed == totals.testCases.total()) { + Colour colour(Colour::ResultError); + const std::string qualify_assertions_failed = + totals.assertions.failed == totals.assertions.total() ? + bothOrAll(totals.assertions.failed) : std::string(); + out << + "Failed " << bothOrAll(totals.testCases.failed) + << pluralise(totals.testCases.failed, "test case") << ", " + "failed " << qualify_assertions_failed << + pluralise(totals.assertions.failed, "assertion") << '.'; + } else if (totals.assertions.total() == 0) { + out << + "Passed " << bothOrAll(totals.testCases.total()) + << pluralise(totals.testCases.total(), "test case") + << " (no assertions)."; + } else if (totals.assertions.failed) { + Colour colour(Colour::ResultError); + out << + "Failed " << pluralise(totals.testCases.failed, "test case") << ", " + "failed " << pluralise(totals.assertions.failed, "assertion") << '.'; + } else { + Colour colour(Colour::ResultSuccess); + out << + "Passed " << bothOrAll(totals.testCases.passed) + << pluralise(totals.testCases.passed, "test case") << + " with " << pluralise(totals.assertions.passed, "assertion") << '.'; + } +} + +// Implementation of CompactReporter formatting +class AssertionPrinter { +public: + AssertionPrinter& operator= (AssertionPrinter const&) = delete; + AssertionPrinter(AssertionPrinter const&) = delete; + AssertionPrinter(std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages) + : stream(_stream) + , result(_stats.assertionResult) + , messages(_stats.infoMessages) + , itMessage(_stats.infoMessages.begin()) + , printInfoMessages(_printInfoMessages) {} + + void print() { + printSourceInfo(); + + itMessage = messages.begin(); + + switch (result.getResultType()) { + case ResultWas::Ok: + printResultType(Colour::ResultSuccess, passedString()); + printOriginalExpression(); + printReconstructedExpression(); + if (!result.hasExpression()) + printRemainingMessages(Colour::None); + else + printRemainingMessages(); + break; + case ResultWas::ExpressionFailed: + if (result.isOk()) + printResultType(Colour::ResultSuccess, failedString() + std::string(" - but was ok")); + else + printResultType(Colour::Error, failedString()); + printOriginalExpression(); + printReconstructedExpression(); + printRemainingMessages(); + break; + case ResultWas::ThrewException: + printResultType(Colour::Error, failedString()); + printIssue("unexpected exception with message:"); + printMessage(); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::FatalErrorCondition: + printResultType(Colour::Error, failedString()); + printIssue("fatal error condition with message:"); + printMessage(); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::DidntThrowException: + printResultType(Colour::Error, failedString()); + printIssue("expected exception, got none"); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::Info: + printResultType(Colour::None, "info"); + printMessage(); + printRemainingMessages(); + break; + case ResultWas::Warning: + printResultType(Colour::None, "warning"); + printMessage(); + printRemainingMessages(); + break; + case ResultWas::ExplicitFailure: + printResultType(Colour::Error, failedString()); + printIssue("explicitly"); + printRemainingMessages(Colour::None); + break; + // These cases are here to prevent compiler warnings + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + printResultType(Colour::Error, "** internal error **"); + break; + } + } + +private: + void printSourceInfo() const { + Colour colourGuard(Colour::FileName); + stream << result.getSourceInfo() << ':'; + } + + void printResultType(Colour::Code colour, std::string const& passOrFail) const { + if (!passOrFail.empty()) { + { + Colour colourGuard(colour); + stream << ' ' << passOrFail; + } + stream << ':'; + } + } + + void printIssue(std::string const& issue) const { + stream << ' ' << issue; + } + + void printExpressionWas() { + if (result.hasExpression()) { + stream << ';'; + { + Colour colour(dimColour()); + stream << " expression was:"; + } + printOriginalExpression(); + } + } + + void printOriginalExpression() const { + if (result.hasExpression()) { + stream << ' ' << result.getExpression(); + } + } + + void printReconstructedExpression() const { + if (result.hasExpandedExpression()) { + { + Colour colour(dimColour()); + stream << " for: "; + } + stream << result.getExpandedExpression(); + } + } + + void printMessage() { + if (itMessage != messages.end()) { + stream << " '" << itMessage->message << '\''; + ++itMessage; + } + } + + void printRemainingMessages(Colour::Code colour = dimColour()) { + if (itMessage == messages.end()) + return; + + // using messages.end() directly yields (or auto) compilation error: + std::vector<MessageInfo>::const_iterator itEnd = messages.end(); + const std::size_t N = static_cast<std::size_t>(std::distance(itMessage, itEnd)); + + { + Colour colourGuard(colour); + stream << " with " << pluralise(N, "message") << ':'; + } + + for (; itMessage != itEnd; ) { + // If this assertion is a warning ignore any INFO messages + if (printInfoMessages || itMessage->type != ResultWas::Info) { + stream << " '" << itMessage->message << '\''; + if (++itMessage != itEnd) { + Colour colourGuard(dimColour()); + stream << " and"; + } + } + } + } + +private: + std::ostream& stream; + AssertionResult const& result; + std::vector<MessageInfo> messages; + std::vector<MessageInfo>::const_iterator itMessage; + bool printInfoMessages; +}; + +} // anon namespace + + std::string CompactReporter::getDescription() { + return "Reports test results on a single line, suitable for IDEs"; + } + + ReporterPreferences CompactReporter::getPreferences() const { + return m_reporterPrefs; + } + + void CompactReporter::noMatchingTestCases( std::string const& spec ) { + stream << "No test cases matched '" << spec << '\'' << std::endl; + } + + void CompactReporter::assertionStarting( AssertionInfo const& ) {} + + bool CompactReporter::assertionEnded( AssertionStats const& _assertionStats ) { + AssertionResult const& result = _assertionStats.assertionResult; + + bool printInfoMessages = true; + + // Drop out if result was successful and we're not printing those + if( !m_config->includeSuccessfulResults() && result.isOk() ) { + if( result.getResultType() != ResultWas::Warning ) + return false; + printInfoMessages = false; + } + + AssertionPrinter printer( stream, _assertionStats, printInfoMessages ); + printer.print(); + + stream << std::endl; + return true; + } + + void CompactReporter::sectionEnded(SectionStats const& _sectionStats) { + if (m_config->showDurations() == ShowDurations::Always) { + stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl; + } + } + + void CompactReporter::testRunEnded( TestRunStats const& _testRunStats ) { + printTotals( stream, _testRunStats.totals ); + stream << '\n' << std::endl; + StreamingReporterBase::testRunEnded( _testRunStats ); + } + + CompactReporter::~CompactReporter() {} + + CATCH_REGISTER_REPORTER( "compact", CompactReporter ) + +} // end namespace Catch +// end catch_reporter_compact.cpp +// start catch_reporter_console.cpp + +#include <cfloat> +#include <cstdio> + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable:4061) // Not all labels are EXPLICITLY handled in switch + // Note that 4062 (not all labels are handled + // and default is missing) is enabled +#endif + +namespace Catch { + +namespace { + +// Formatter impl for ConsoleReporter +class ConsoleAssertionPrinter { +public: + ConsoleAssertionPrinter& operator= (ConsoleAssertionPrinter const&) = delete; + ConsoleAssertionPrinter(ConsoleAssertionPrinter const&) = delete; + ConsoleAssertionPrinter(std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages) + : stream(_stream), + stats(_stats), + result(_stats.assertionResult), + colour(Colour::None), + message(result.getMessage()), + messages(_stats.infoMessages), + printInfoMessages(_printInfoMessages) { + switch (result.getResultType()) { + case ResultWas::Ok: + colour = Colour::Success; + passOrFail = "PASSED"; + //if( result.hasMessage() ) + if (_stats.infoMessages.size() == 1) + messageLabel = "with message"; + if (_stats.infoMessages.size() > 1) + messageLabel = "with messages"; + break; + case ResultWas::ExpressionFailed: + if (result.isOk()) { + colour = Colour::Success; + passOrFail = "FAILED - but was ok"; + } else { + colour = Colour::Error; + passOrFail = "FAILED"; + } + if (_stats.infoMessages.size() == 1) + messageLabel = "with message"; + if (_stats.infoMessages.size() > 1) + messageLabel = "with messages"; + break; + case ResultWas::ThrewException: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "due to unexpected exception with "; + if (_stats.infoMessages.size() == 1) + messageLabel += "message"; + if (_stats.infoMessages.size() > 1) + messageLabel += "messages"; + break; + case ResultWas::FatalErrorCondition: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "due to a fatal error condition"; + break; + case ResultWas::DidntThrowException: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "because no exception was thrown where one was expected"; + break; + case ResultWas::Info: + messageLabel = "info"; + break; + case ResultWas::Warning: + messageLabel = "warning"; + break; + case ResultWas::ExplicitFailure: + passOrFail = "FAILED"; + colour = Colour::Error; + if (_stats.infoMessages.size() == 1) + messageLabel = "explicitly with message"; + if (_stats.infoMessages.size() > 1) + messageLabel = "explicitly with messages"; + break; + // These cases are here to prevent compiler warnings + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + passOrFail = "** internal error **"; + colour = Colour::Error; + break; + } + } + + void print() const { + printSourceInfo(); + if (stats.totals.assertions.total() > 0) { + printResultType(); + printOriginalExpression(); + printReconstructedExpression(); + } else { + stream << '\n'; + } + printMessage(); + } + +private: + void printResultType() const { + if (!passOrFail.empty()) { + Colour colourGuard(colour); + stream << passOrFail << ":\n"; + } + } + void printOriginalExpression() const { + if (result.hasExpression()) { + Colour colourGuard(Colour::OriginalExpression); + stream << " "; + stream << result.getExpressionInMacro(); + stream << '\n'; + } + } + void printReconstructedExpression() const { + if (result.hasExpandedExpression()) { + stream << "with expansion:\n"; + Colour colourGuard(Colour::ReconstructedExpression); + stream << Column(result.getExpandedExpression()).indent(2) << '\n'; + } + } + void printMessage() const { + if (!messageLabel.empty()) + stream << messageLabel << ':' << '\n'; + for (auto const& msg : messages) { + // If this assertion is a warning ignore any INFO messages + if (printInfoMessages || msg.type != ResultWas::Info) + stream << Column(msg.message).indent(2) << '\n'; + } + } + void printSourceInfo() const { + Colour colourGuard(Colour::FileName); + stream << result.getSourceInfo() << ": "; + } + + std::ostream& stream; + AssertionStats const& stats; + AssertionResult const& result; + Colour::Code colour; + std::string passOrFail; + std::string messageLabel; + std::string message; + std::vector<MessageInfo> messages; + bool printInfoMessages; +}; + +std::size_t makeRatio(std::size_t number, std::size_t total) { + std::size_t ratio = total > 0 ? CATCH_CONFIG_CONSOLE_WIDTH * number / total : 0; + return (ratio == 0 && number > 0) ? 1 : ratio; +} + +std::size_t& findMax(std::size_t& i, std::size_t& j, std::size_t& k) { + if (i > j && i > k) + return i; + else if (j > k) + return j; + else + return k; +} + +struct ColumnInfo { + enum Justification { Left, Right }; + std::string name; + int width; + Justification justification; +}; +struct ColumnBreak {}; +struct RowBreak {}; + +class Duration { + enum class Unit { + Auto, + Nanoseconds, + Microseconds, + Milliseconds, + Seconds, + Minutes + }; + static const uint64_t s_nanosecondsInAMicrosecond = 1000; + static const uint64_t s_nanosecondsInAMillisecond = 1000 * s_nanosecondsInAMicrosecond; + static const uint64_t s_nanosecondsInASecond = 1000 * s_nanosecondsInAMillisecond; + static const uint64_t s_nanosecondsInAMinute = 60 * s_nanosecondsInASecond; + + uint64_t m_inNanoseconds; + Unit m_units; + +public: + explicit Duration(uint64_t inNanoseconds, Unit units = Unit::Auto) + : m_inNanoseconds(inNanoseconds), + m_units(units) { + if (m_units == Unit::Auto) { + if (m_inNanoseconds < s_nanosecondsInAMicrosecond) + m_units = Unit::Nanoseconds; + else if (m_inNanoseconds < s_nanosecondsInAMillisecond) + m_units = Unit::Microseconds; + else if (m_inNanoseconds < s_nanosecondsInASecond) + m_units = Unit::Milliseconds; + else if (m_inNanoseconds < s_nanosecondsInAMinute) + m_units = Unit::Seconds; + else + m_units = Unit::Minutes; + } + + } + + auto value() const -> double { + switch (m_units) { + case Unit::Microseconds: + return m_inNanoseconds / static_cast<double>(s_nanosecondsInAMicrosecond); + case Unit::Milliseconds: + return m_inNanoseconds / static_cast<double>(s_nanosecondsInAMillisecond); + case Unit::Seconds: + return m_inNanoseconds / static_cast<double>(s_nanosecondsInASecond); + case Unit::Minutes: + return m_inNanoseconds / static_cast<double>(s_nanosecondsInAMinute); + default: + return static_cast<double>(m_inNanoseconds); + } + } + auto unitsAsString() const -> std::string { + switch (m_units) { + case Unit::Nanoseconds: + return "ns"; + case Unit::Microseconds: + return "us"; + case Unit::Milliseconds: + return "ms"; + case Unit::Seconds: + return "s"; + case Unit::Minutes: + return "m"; + default: + return "** internal error **"; + } + + } + friend auto operator << (std::ostream& os, Duration const& duration) -> std::ostream& { + return os << duration.value() << " " << duration.unitsAsString(); + } +}; +} // end anon namespace + +class TablePrinter { + std::ostream& m_os; + std::vector<ColumnInfo> m_columnInfos; + std::ostringstream m_oss; + int m_currentColumn = -1; + bool m_isOpen = false; + +public: + TablePrinter( std::ostream& os, std::vector<ColumnInfo> columnInfos ) + : m_os( os ), + m_columnInfos( std::move( columnInfos ) ) {} + + auto columnInfos() const -> std::vector<ColumnInfo> const& { + return m_columnInfos; + } + + void open() { + if (!m_isOpen) { + m_isOpen = true; + *this << RowBreak(); + for (auto const& info : m_columnInfos) + *this << info.name << ColumnBreak(); + *this << RowBreak(); + m_os << Catch::getLineOfChars<'-'>() << "\n"; + } + } + void close() { + if (m_isOpen) { + *this << RowBreak(); + m_os << std::endl; + m_isOpen = false; + } + } + + template<typename T> + friend TablePrinter& operator << (TablePrinter& tp, T const& value) { + tp.m_oss << value; + return tp; + } + + friend TablePrinter& operator << (TablePrinter& tp, ColumnBreak) { + auto colStr = tp.m_oss.str(); + // This takes account of utf8 encodings + auto strSize = Catch::StringRef(colStr).numberOfCharacters(); + tp.m_oss.str(""); + tp.open(); + if (tp.m_currentColumn == static_cast<int>(tp.m_columnInfos.size() - 1)) { + tp.m_currentColumn = -1; + tp.m_os << "\n"; + } + tp.m_currentColumn++; + + auto colInfo = tp.m_columnInfos[tp.m_currentColumn]; + auto padding = (strSize + 2 < static_cast<std::size_t>(colInfo.width)) + ? std::string(colInfo.width - (strSize + 2), ' ') + : std::string(); + if (colInfo.justification == ColumnInfo::Left) + tp.m_os << colStr << padding << " "; + else + tp.m_os << padding << colStr << " "; + return tp; + } + + friend TablePrinter& operator << (TablePrinter& tp, RowBreak) { + if (tp.m_currentColumn > 0) { + tp.m_os << "\n"; + tp.m_currentColumn = -1; + } + return tp; + } +}; + +ConsoleReporter::ConsoleReporter(ReporterConfig const& config) + : StreamingReporterBase(config), + m_tablePrinter(new TablePrinter(config.stream(), + { + { "benchmark name", CATCH_CONFIG_CONSOLE_WIDTH - 32, ColumnInfo::Left }, + { "iters", 8, ColumnInfo::Right }, + { "elapsed ns", 14, ColumnInfo::Right }, + { "average", 14, ColumnInfo::Right } + })) {} +ConsoleReporter::~ConsoleReporter() = default; + +std::string ConsoleReporter::getDescription() { + return "Reports test results as plain lines of text"; +} + +void ConsoleReporter::noMatchingTestCases(std::string const& spec) { + stream << "No test cases matched '" << spec << '\'' << std::endl; +} + +void ConsoleReporter::assertionStarting(AssertionInfo const&) {} + +bool ConsoleReporter::assertionEnded(AssertionStats const& _assertionStats) { + AssertionResult const& result = _assertionStats.assertionResult; + + bool includeResults = m_config->includeSuccessfulResults() || !result.isOk(); + + // Drop out if result was successful but we're not printing them. + if (!includeResults && result.getResultType() != ResultWas::Warning) + return false; + + lazyPrint(); + + ConsoleAssertionPrinter printer(stream, _assertionStats, includeResults); + printer.print(); + stream << std::endl; + return true; +} + +void ConsoleReporter::sectionStarting(SectionInfo const& _sectionInfo) { + m_headerPrinted = false; + StreamingReporterBase::sectionStarting(_sectionInfo); +} +void ConsoleReporter::sectionEnded(SectionStats const& _sectionStats) { + m_tablePrinter->close(); + if (_sectionStats.missingAssertions) { + lazyPrint(); + Colour colour(Colour::ResultError); + if (m_sectionStack.size() > 1) + stream << "\nNo assertions in section"; + else + stream << "\nNo assertions in test case"; + stream << " '" << _sectionStats.sectionInfo.name << "'\n" << std::endl; + } + if (m_config->showDurations() == ShowDurations::Always) { + stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl; + } + if (m_headerPrinted) { + m_headerPrinted = false; + } + StreamingReporterBase::sectionEnded(_sectionStats); +} + +void ConsoleReporter::benchmarkStarting(BenchmarkInfo const& info) { + lazyPrintWithoutClosingBenchmarkTable(); + + auto nameCol = Column( info.name ).width( static_cast<std::size_t>( m_tablePrinter->columnInfos()[0].width - 2 ) ); + + bool firstLine = true; + for (auto line : nameCol) { + if (!firstLine) + (*m_tablePrinter) << ColumnBreak() << ColumnBreak() << ColumnBreak(); + else + firstLine = false; + + (*m_tablePrinter) << line << ColumnBreak(); + } +} +void ConsoleReporter::benchmarkEnded(BenchmarkStats const& stats) { + Duration average(stats.elapsedTimeInNanoseconds / stats.iterations); + (*m_tablePrinter) + << stats.iterations << ColumnBreak() + << stats.elapsedTimeInNanoseconds << ColumnBreak() + << average << ColumnBreak(); +} + +void ConsoleReporter::testCaseEnded(TestCaseStats const& _testCaseStats) { + m_tablePrinter->close(); + StreamingReporterBase::testCaseEnded(_testCaseStats); + m_headerPrinted = false; +} +void ConsoleReporter::testGroupEnded(TestGroupStats const& _testGroupStats) { + if (currentGroupInfo.used) { + printSummaryDivider(); + stream << "Summary for group '" << _testGroupStats.groupInfo.name << "':\n"; + printTotals(_testGroupStats.totals); + stream << '\n' << std::endl; + } + StreamingReporterBase::testGroupEnded(_testGroupStats); +} +void ConsoleReporter::testRunEnded(TestRunStats const& _testRunStats) { + printTotalsDivider(_testRunStats.totals); + printTotals(_testRunStats.totals); + stream << std::endl; + StreamingReporterBase::testRunEnded(_testRunStats); +} + +void ConsoleReporter::lazyPrint() { + + m_tablePrinter->close(); + lazyPrintWithoutClosingBenchmarkTable(); +} + +void ConsoleReporter::lazyPrintWithoutClosingBenchmarkTable() { + + if (!currentTestRunInfo.used) + lazyPrintRunInfo(); + if (!currentGroupInfo.used) + lazyPrintGroupInfo(); + + if (!m_headerPrinted) { + printTestCaseAndSectionHeader(); + m_headerPrinted = true; + } +} +void ConsoleReporter::lazyPrintRunInfo() { + stream << '\n' << getLineOfChars<'~'>() << '\n'; + Colour colour(Colour::SecondaryText); + stream << currentTestRunInfo->name + << " is a Catch v" << libraryVersion() << " host application.\n" + << "Run with -? for options\n\n"; + + if (m_config->rngSeed() != 0) + stream << "Randomness seeded to: " << m_config->rngSeed() << "\n\n"; + + currentTestRunInfo.used = true; +} +void ConsoleReporter::lazyPrintGroupInfo() { + if (!currentGroupInfo->name.empty() && currentGroupInfo->groupsCounts > 1) { + printClosedHeader("Group: " + currentGroupInfo->name); + currentGroupInfo.used = true; + } +} +void ConsoleReporter::printTestCaseAndSectionHeader() { + assert(!m_sectionStack.empty()); + printOpenHeader(currentTestCaseInfo->name); + + if (m_sectionStack.size() > 1) { + Colour colourGuard(Colour::Headers); + + auto + it = m_sectionStack.begin() + 1, // Skip first section (test case) + itEnd = m_sectionStack.end(); + for (; it != itEnd; ++it) + printHeaderString(it->name, 2); + } + + SourceLineInfo lineInfo = m_sectionStack.back().lineInfo; + + if (!lineInfo.empty()) { + stream << getLineOfChars<'-'>() << '\n'; + Colour colourGuard(Colour::FileName); + stream << lineInfo << '\n'; + } + stream << getLineOfChars<'.'>() << '\n' << std::endl; +} + +void ConsoleReporter::printClosedHeader(std::string const& _name) { + printOpenHeader(_name); + stream << getLineOfChars<'.'>() << '\n'; +} +void ConsoleReporter::printOpenHeader(std::string const& _name) { + stream << getLineOfChars<'-'>() << '\n'; + { + Colour colourGuard(Colour::Headers); + printHeaderString(_name); + } +} + +// if string has a : in first line will set indent to follow it on +// subsequent lines +void ConsoleReporter::printHeaderString(std::string const& _string, std::size_t indent) { + std::size_t i = _string.find(": "); + if (i != std::string::npos) + i += 2; + else + i = 0; + stream << Column(_string).indent(indent + i).initialIndent(indent) << '\n'; +} + +struct SummaryColumn { + + SummaryColumn( std::string _label, Colour::Code _colour ) + : label( std::move( _label ) ), + colour( _colour ) {} + SummaryColumn addRow( std::size_t count ) { + ReusableStringStream rss; + rss << count; + std::string row = rss.str(); + for (auto& oldRow : rows) { + while (oldRow.size() < row.size()) + oldRow = ' ' + oldRow; + while (oldRow.size() > row.size()) + row = ' ' + row; + } + rows.push_back(row); + return *this; + } + + std::string label; + Colour::Code colour; + std::vector<std::string> rows; + +}; + +void ConsoleReporter::printTotals( Totals const& totals ) { + if (totals.testCases.total() == 0) { + stream << Colour(Colour::Warning) << "No tests ran\n"; + } else if (totals.assertions.total() > 0 && totals.testCases.allPassed()) { + stream << Colour(Colour::ResultSuccess) << "All tests passed"; + stream << " (" + << pluralise(totals.assertions.passed, "assertion") << " in " + << pluralise(totals.testCases.passed, "test case") << ')' + << '\n'; + } else { + + std::vector<SummaryColumn> columns; + columns.push_back(SummaryColumn("", Colour::None) + .addRow(totals.testCases.total()) + .addRow(totals.assertions.total())); + columns.push_back(SummaryColumn("passed", Colour::Success) + .addRow(totals.testCases.passed) + .addRow(totals.assertions.passed)); + columns.push_back(SummaryColumn("failed", Colour::ResultError) + .addRow(totals.testCases.failed) + .addRow(totals.assertions.failed)); + columns.push_back(SummaryColumn("failed as expected", Colour::ResultExpectedFailure) + .addRow(totals.testCases.failedButOk) + .addRow(totals.assertions.failedButOk)); + + printSummaryRow("test cases", columns, 0); + printSummaryRow("assertions", columns, 1); + } +} +void ConsoleReporter::printSummaryRow(std::string const& label, std::vector<SummaryColumn> const& cols, std::size_t row) { + for (auto col : cols) { + std::string value = col.rows[row]; + if (col.label.empty()) { + stream << label << ": "; + if (value != "0") + stream << value; + else + stream << Colour(Colour::Warning) << "- none -"; + } else if (value != "0") { + stream << Colour(Colour::LightGrey) << " | "; + stream << Colour(col.colour) + << value << ' ' << col.label; + } + } + stream << '\n'; +} + +void ConsoleReporter::printTotalsDivider(Totals const& totals) { + if (totals.testCases.total() > 0) { + std::size_t failedRatio = makeRatio(totals.testCases.failed, totals.testCases.total()); + std::size_t failedButOkRatio = makeRatio(totals.testCases.failedButOk, totals.testCases.total()); + std::size_t passedRatio = makeRatio(totals.testCases.passed, totals.testCases.total()); + while (failedRatio + failedButOkRatio + passedRatio < CATCH_CONFIG_CONSOLE_WIDTH - 1) + findMax(failedRatio, failedButOkRatio, passedRatio)++; + while (failedRatio + failedButOkRatio + passedRatio > CATCH_CONFIG_CONSOLE_WIDTH - 1) + findMax(failedRatio, failedButOkRatio, passedRatio)--; + + stream << Colour(Colour::Error) << std::string(failedRatio, '='); + stream << Colour(Colour::ResultExpectedFailure) << std::string(failedButOkRatio, '='); + if (totals.testCases.allPassed()) + stream << Colour(Colour::ResultSuccess) << std::string(passedRatio, '='); + else + stream << Colour(Colour::Success) << std::string(passedRatio, '='); + } else { + stream << Colour(Colour::Warning) << std::string(CATCH_CONFIG_CONSOLE_WIDTH - 1, '='); + } + stream << '\n'; +} +void ConsoleReporter::printSummaryDivider() { + stream << getLineOfChars<'-'>() << '\n'; +} + +CATCH_REGISTER_REPORTER("console", ConsoleReporter) + +} // end namespace Catch + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +// end catch_reporter_console.cpp +// start catch_reporter_junit.cpp + +#include <cassert> +#include <sstream> +#include <ctime> +#include <algorithm> + +namespace Catch { + + namespace { + std::string getCurrentTimestamp() { + // Beware, this is not reentrant because of backward compatibility issues + // Also, UTC only, again because of backward compatibility (%z is C++11) + time_t rawtime; + std::time(&rawtime); + auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); + +#ifdef _MSC_VER + std::tm timeInfo = {}; + gmtime_s(&timeInfo, &rawtime); +#else + std::tm* timeInfo; + timeInfo = std::gmtime(&rawtime); +#endif + + char timeStamp[timeStampSize]; + const char * const fmt = "%Y-%m-%dT%H:%M:%SZ"; + +#ifdef _MSC_VER + std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); +#else + std::strftime(timeStamp, timeStampSize, fmt, timeInfo); +#endif + return std::string(timeStamp); + } + + std::string fileNameTag(const std::vector<std::string> &tags) { + auto it = std::find_if(begin(tags), + end(tags), + [] (std::string const& tag) {return tag.front() == '#'; }); + if (it != tags.end()) + return it->substr(1); + return std::string(); + } + } // anonymous namespace + + JunitReporter::JunitReporter( ReporterConfig const& _config ) + : CumulativeReporterBase( _config ), + xml( _config.stream() ) + { + m_reporterPrefs.shouldRedirectStdOut = true; + m_reporterPrefs.shouldReportAllAssertions = true; + } + + JunitReporter::~JunitReporter() {} + + std::string JunitReporter::getDescription() { + return "Reports test results in an XML format that looks like Ant's junitreport target"; + } + + void JunitReporter::noMatchingTestCases( std::string const& /*spec*/ ) {} + + void JunitReporter::testRunStarting( TestRunInfo const& runInfo ) { + CumulativeReporterBase::testRunStarting( runInfo ); + xml.startElement( "testsuites" ); + if( m_config->rngSeed() != 0 ) { + xml.startElement( "properties" ); + xml.scopedElement( "property" ) + .writeAttribute( "name", "random-seed" ) + .writeAttribute( "value", m_config->rngSeed() ); + xml.endElement(); + } + } + + void JunitReporter::testGroupStarting( GroupInfo const& groupInfo ) { + suiteTimer.start(); + stdOutForSuite.clear(); + stdErrForSuite.clear(); + unexpectedExceptions = 0; + CumulativeReporterBase::testGroupStarting( groupInfo ); + } + + void JunitReporter::testCaseStarting( TestCaseInfo const& testCaseInfo ) { + m_okToFail = testCaseInfo.okToFail(); + } + + bool JunitReporter::assertionEnded( AssertionStats const& assertionStats ) { + if( assertionStats.assertionResult.getResultType() == ResultWas::ThrewException && !m_okToFail ) + unexpectedExceptions++; + return CumulativeReporterBase::assertionEnded( assertionStats ); + } + + void JunitReporter::testCaseEnded( TestCaseStats const& testCaseStats ) { + stdOutForSuite += testCaseStats.stdOut; + stdErrForSuite += testCaseStats.stdErr; + CumulativeReporterBase::testCaseEnded( testCaseStats ); + } + + void JunitReporter::testGroupEnded( TestGroupStats const& testGroupStats ) { + double suiteTime = suiteTimer.getElapsedSeconds(); + CumulativeReporterBase::testGroupEnded( testGroupStats ); + writeGroup( *m_testGroups.back(), suiteTime ); + } + + void JunitReporter::testRunEndedCumulative() { + xml.endElement(); + } + + void JunitReporter::writeGroup( TestGroupNode const& groupNode, double suiteTime ) { + XmlWriter::ScopedElement e = xml.scopedElement( "testsuite" ); + TestGroupStats const& stats = groupNode.value; + xml.writeAttribute( "name", stats.groupInfo.name ); + xml.writeAttribute( "errors", unexpectedExceptions ); + xml.writeAttribute( "failures", stats.totals.assertions.failed-unexpectedExceptions ); + xml.writeAttribute( "tests", stats.totals.assertions.total() ); + xml.writeAttribute( "hostname", "tbd" ); // !TBD + if( m_config->showDurations() == ShowDurations::Never ) + xml.writeAttribute( "time", "" ); + else + xml.writeAttribute( "time", suiteTime ); + xml.writeAttribute( "timestamp", getCurrentTimestamp() ); + + // Write test cases + for( auto const& child : groupNode.children ) + writeTestCase( *child ); + + xml.scopedElement( "system-out" ).writeText( trim( stdOutForSuite ), false ); + xml.scopedElement( "system-err" ).writeText( trim( stdErrForSuite ), false ); + } + + void JunitReporter::writeTestCase( TestCaseNode const& testCaseNode ) { + TestCaseStats const& stats = testCaseNode.value; + + // All test cases have exactly one section - which represents the + // test case itself. That section may have 0-n nested sections + assert( testCaseNode.children.size() == 1 ); + SectionNode const& rootSection = *testCaseNode.children.front(); + + std::string className = stats.testInfo.className; + + if( className.empty() ) { + className = fileNameTag(stats.testInfo.tags); + if ( className.empty() ) + className = "global"; + } + + if ( !m_config->name().empty() ) + className = m_config->name() + "." + className; + + writeSection( className, "", rootSection ); + } + + void JunitReporter::writeSection( std::string const& className, + std::string const& rootName, + SectionNode const& sectionNode ) { + std::string name = trim( sectionNode.stats.sectionInfo.name ); + if( !rootName.empty() ) + name = rootName + '/' + name; + + if( !sectionNode.assertions.empty() || + !sectionNode.stdOut.empty() || + !sectionNode.stdErr.empty() ) { + XmlWriter::ScopedElement e = xml.scopedElement( "testcase" ); + if( className.empty() ) { + xml.writeAttribute( "classname", name ); + xml.writeAttribute( "name", "root" ); + } + else { + xml.writeAttribute( "classname", className ); + xml.writeAttribute( "name", name ); + } + xml.writeAttribute( "time", ::Catch::Detail::stringify( sectionNode.stats.durationInSeconds ) ); + + writeAssertions( sectionNode ); + + if( !sectionNode.stdOut.empty() ) + xml.scopedElement( "system-out" ).writeText( trim( sectionNode.stdOut ), false ); + if( !sectionNode.stdErr.empty() ) + xml.scopedElement( "system-err" ).writeText( trim( sectionNode.stdErr ), false ); + } + for( auto const& childNode : sectionNode.childSections ) + if( className.empty() ) + writeSection( name, "", *childNode ); + else + writeSection( className, name, *childNode ); + } + + void JunitReporter::writeAssertions( SectionNode const& sectionNode ) { + for( auto const& assertion : sectionNode.assertions ) + writeAssertion( assertion ); + } + + void JunitReporter::writeAssertion( AssertionStats const& stats ) { + AssertionResult const& result = stats.assertionResult; + if( !result.isOk() ) { + std::string elementName; + switch( result.getResultType() ) { + case ResultWas::ThrewException: + case ResultWas::FatalErrorCondition: + elementName = "error"; + break; + case ResultWas::ExplicitFailure: + elementName = "failure"; + break; + case ResultWas::ExpressionFailed: + elementName = "failure"; + break; + case ResultWas::DidntThrowException: + elementName = "failure"; + break; + + // We should never see these here: + case ResultWas::Info: + case ResultWas::Warning: + case ResultWas::Ok: + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + elementName = "internalError"; + break; + } + + XmlWriter::ScopedElement e = xml.scopedElement( elementName ); + + xml.writeAttribute( "message", result.getExpandedExpression() ); + xml.writeAttribute( "type", result.getTestMacroName() ); + + ReusableStringStream rss; + if( !result.getMessage().empty() ) + rss << result.getMessage() << '\n'; + for( auto const& msg : stats.infoMessages ) + if( msg.type == ResultWas::Info ) + rss << msg.message << '\n'; + + rss << "at " << result.getSourceInfo(); + xml.writeText( rss.str(), false ); + } + } + + CATCH_REGISTER_REPORTER( "junit", JunitReporter ) + +} // end namespace Catch +// end catch_reporter_junit.cpp +// start catch_reporter_listening.cpp + +#include <cassert> + +namespace Catch { + + ListeningReporter::ListeningReporter() { + // We will assume that listeners will always want all assertions + m_preferences.shouldReportAllAssertions = true; + } + + void ListeningReporter::addListener( IStreamingReporterPtr&& listener ) { + m_listeners.push_back( std::move( listener ) ); + } + + void ListeningReporter::addReporter(IStreamingReporterPtr&& reporter) { + assert(!m_reporter && "Listening reporter can wrap only 1 real reporter"); + m_reporter = std::move( reporter ); + m_preferences.shouldRedirectStdOut = m_reporter->getPreferences().shouldRedirectStdOut; + } + + ReporterPreferences ListeningReporter::getPreferences() const { + return m_preferences; + } + + std::set<Verbosity> ListeningReporter::getSupportedVerbosities() { + return std::set<Verbosity>{ }; + } + + void ListeningReporter::noMatchingTestCases( std::string const& spec ) { + for ( auto const& listener : m_listeners ) { + listener->noMatchingTestCases( spec ); + } + m_reporter->noMatchingTestCases( spec ); + } + + void ListeningReporter::benchmarkStarting( BenchmarkInfo const& benchmarkInfo ) { + for ( auto const& listener : m_listeners ) { + listener->benchmarkStarting( benchmarkInfo ); + } + m_reporter->benchmarkStarting( benchmarkInfo ); + } + void ListeningReporter::benchmarkEnded( BenchmarkStats const& benchmarkStats ) { + for ( auto const& listener : m_listeners ) { + listener->benchmarkEnded( benchmarkStats ); + } + m_reporter->benchmarkEnded( benchmarkStats ); + } + + void ListeningReporter::testRunStarting( TestRunInfo const& testRunInfo ) { + for ( auto const& listener : m_listeners ) { + listener->testRunStarting( testRunInfo ); + } + m_reporter->testRunStarting( testRunInfo ); + } + + void ListeningReporter::testGroupStarting( GroupInfo const& groupInfo ) { + for ( auto const& listener : m_listeners ) { + listener->testGroupStarting( groupInfo ); + } + m_reporter->testGroupStarting( groupInfo ); + } + + void ListeningReporter::testCaseStarting( TestCaseInfo const& testInfo ) { + for ( auto const& listener : m_listeners ) { + listener->testCaseStarting( testInfo ); + } + m_reporter->testCaseStarting( testInfo ); + } + + void ListeningReporter::sectionStarting( SectionInfo const& sectionInfo ) { + for ( auto const& listener : m_listeners ) { + listener->sectionStarting( sectionInfo ); + } + m_reporter->sectionStarting( sectionInfo ); + } + + void ListeningReporter::assertionStarting( AssertionInfo const& assertionInfo ) { + for ( auto const& listener : m_listeners ) { + listener->assertionStarting( assertionInfo ); + } + m_reporter->assertionStarting( assertionInfo ); + } + + // The return value indicates if the messages buffer should be cleared: + bool ListeningReporter::assertionEnded( AssertionStats const& assertionStats ) { + for( auto const& listener : m_listeners ) { + static_cast<void>( listener->assertionEnded( assertionStats ) ); + } + return m_reporter->assertionEnded( assertionStats ); + } + + void ListeningReporter::sectionEnded( SectionStats const& sectionStats ) { + for ( auto const& listener : m_listeners ) { + listener->sectionEnded( sectionStats ); + } + m_reporter->sectionEnded( sectionStats ); + } + + void ListeningReporter::testCaseEnded( TestCaseStats const& testCaseStats ) { + for ( auto const& listener : m_listeners ) { + listener->testCaseEnded( testCaseStats ); + } + m_reporter->testCaseEnded( testCaseStats ); + } + + void ListeningReporter::testGroupEnded( TestGroupStats const& testGroupStats ) { + for ( auto const& listener : m_listeners ) { + listener->testGroupEnded( testGroupStats ); + } + m_reporter->testGroupEnded( testGroupStats ); + } + + void ListeningReporter::testRunEnded( TestRunStats const& testRunStats ) { + for ( auto const& listener : m_listeners ) { + listener->testRunEnded( testRunStats ); + } + m_reporter->testRunEnded( testRunStats ); + } + + void ListeningReporter::skipTest( TestCaseInfo const& testInfo ) { + for ( auto const& listener : m_listeners ) { + listener->skipTest( testInfo ); + } + m_reporter->skipTest( testInfo ); + } + + bool ListeningReporter::isMulti() const { + return true; + } + +} // end namespace Catch +// end catch_reporter_listening.cpp +// start catch_reporter_xml.cpp + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable:4061) // Not all labels are EXPLICITLY handled in switch + // Note that 4062 (not all labels are handled + // and default is missing) is enabled +#endif + +namespace Catch { + XmlReporter::XmlReporter( ReporterConfig const& _config ) + : StreamingReporterBase( _config ), + m_xml(_config.stream()) + { + m_reporterPrefs.shouldRedirectStdOut = true; + m_reporterPrefs.shouldReportAllAssertions = true; + } + + XmlReporter::~XmlReporter() = default; + + std::string XmlReporter::getDescription() { + return "Reports test results as an XML document"; + } + + std::string XmlReporter::getStylesheetRef() const { + return std::string(); + } + + void XmlReporter::writeSourceInfo( SourceLineInfo const& sourceInfo ) { + m_xml + .writeAttribute( "filename", sourceInfo.file ) + .writeAttribute( "line", sourceInfo.line ); + } + + void XmlReporter::noMatchingTestCases( std::string const& s ) { + StreamingReporterBase::noMatchingTestCases( s ); + } + + void XmlReporter::testRunStarting( TestRunInfo const& testInfo ) { + StreamingReporterBase::testRunStarting( testInfo ); + std::string stylesheetRef = getStylesheetRef(); + if( !stylesheetRef.empty() ) + m_xml.writeStylesheetRef( stylesheetRef ); + m_xml.startElement( "Catch" ); + if( !m_config->name().empty() ) + m_xml.writeAttribute( "name", m_config->name() ); + if( m_config->rngSeed() != 0 ) + m_xml.scopedElement( "Randomness" ) + .writeAttribute( "seed", m_config->rngSeed() ); + } + + void XmlReporter::testGroupStarting( GroupInfo const& groupInfo ) { + StreamingReporterBase::testGroupStarting( groupInfo ); + m_xml.startElement( "Group" ) + .writeAttribute( "name", groupInfo.name ); + } + + void XmlReporter::testCaseStarting( TestCaseInfo const& testInfo ) { + StreamingReporterBase::testCaseStarting(testInfo); + m_xml.startElement( "TestCase" ) + .writeAttribute( "name", trim( testInfo.name ) ) + .writeAttribute( "description", testInfo.description ) + .writeAttribute( "tags", testInfo.tagsAsString() ); + + writeSourceInfo( testInfo.lineInfo ); + + if ( m_config->showDurations() == ShowDurations::Always ) + m_testCaseTimer.start(); + m_xml.ensureTagClosed(); + } + + void XmlReporter::sectionStarting( SectionInfo const& sectionInfo ) { + StreamingReporterBase::sectionStarting( sectionInfo ); + if( m_sectionDepth++ > 0 ) { + m_xml.startElement( "Section" ) + .writeAttribute( "name", trim( sectionInfo.name ) ); + writeSourceInfo( sectionInfo.lineInfo ); + m_xml.ensureTagClosed(); + } + } + + void XmlReporter::assertionStarting( AssertionInfo const& ) { } + + bool XmlReporter::assertionEnded( AssertionStats const& assertionStats ) { + + AssertionResult const& result = assertionStats.assertionResult; + + bool includeResults = m_config->includeSuccessfulResults() || !result.isOk(); + + if( includeResults || result.getResultType() == ResultWas::Warning ) { + // Print any info messages in <Info> tags. + for( auto const& msg : assertionStats.infoMessages ) { + if( msg.type == ResultWas::Info && includeResults ) { + m_xml.scopedElement( "Info" ) + .writeText( msg.message ); + } else if ( msg.type == ResultWas::Warning ) { + m_xml.scopedElement( "Warning" ) + .writeText( msg.message ); + } + } + } + + // Drop out if result was successful but we're not printing them. + if( !includeResults && result.getResultType() != ResultWas::Warning ) + return true; + + // Print the expression if there is one. + if( result.hasExpression() ) { + m_xml.startElement( "Expression" ) + .writeAttribute( "success", result.succeeded() ) + .writeAttribute( "type", result.getTestMacroName() ); + + writeSourceInfo( result.getSourceInfo() ); + + m_xml.scopedElement( "Original" ) + .writeText( result.getExpression() ); + m_xml.scopedElement( "Expanded" ) + .writeText( result.getExpandedExpression() ); + } + + // And... Print a result applicable to each result type. + switch( result.getResultType() ) { + case ResultWas::ThrewException: + m_xml.startElement( "Exception" ); + writeSourceInfo( result.getSourceInfo() ); + m_xml.writeText( result.getMessage() ); + m_xml.endElement(); + break; + case ResultWas::FatalErrorCondition: + m_xml.startElement( "FatalErrorCondition" ); + writeSourceInfo( result.getSourceInfo() ); + m_xml.writeText( result.getMessage() ); + m_xml.endElement(); + break; + case ResultWas::Info: + m_xml.scopedElement( "Info" ) + .writeText( result.getMessage() ); + break; + case ResultWas::Warning: + // Warning will already have been written + break; + case ResultWas::ExplicitFailure: + m_xml.startElement( "Failure" ); + writeSourceInfo( result.getSourceInfo() ); + m_xml.writeText( result.getMessage() ); + m_xml.endElement(); + break; + default: + break; + } + + if( result.hasExpression() ) + m_xml.endElement(); + + return true; + } + + void XmlReporter::sectionEnded( SectionStats const& sectionStats ) { + StreamingReporterBase::sectionEnded( sectionStats ); + if( --m_sectionDepth > 0 ) { + XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResults" ); + e.writeAttribute( "successes", sectionStats.assertions.passed ); + e.writeAttribute( "failures", sectionStats.assertions.failed ); + e.writeAttribute( "expectedFailures", sectionStats.assertions.failedButOk ); + + if ( m_config->showDurations() == ShowDurations::Always ) + e.writeAttribute( "durationInSeconds", sectionStats.durationInSeconds ); + + m_xml.endElement(); + } + } + + void XmlReporter::testCaseEnded( TestCaseStats const& testCaseStats ) { + StreamingReporterBase::testCaseEnded( testCaseStats ); + XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResult" ); + e.writeAttribute( "success", testCaseStats.totals.assertions.allOk() ); + + if ( m_config->showDurations() == ShowDurations::Always ) + e.writeAttribute( "durationInSeconds", m_testCaseTimer.getElapsedSeconds() ); + + if( !testCaseStats.stdOut.empty() ) + m_xml.scopedElement( "StdOut" ).writeText( trim( testCaseStats.stdOut ), false ); + if( !testCaseStats.stdErr.empty() ) + m_xml.scopedElement( "StdErr" ).writeText( trim( testCaseStats.stdErr ), false ); + + m_xml.endElement(); + } + + void XmlReporter::testGroupEnded( TestGroupStats const& testGroupStats ) { + StreamingReporterBase::testGroupEnded( testGroupStats ); + // TODO: Check testGroupStats.aborting and act accordingly. + m_xml.scopedElement( "OverallResults" ) + .writeAttribute( "successes", testGroupStats.totals.assertions.passed ) + .writeAttribute( "failures", testGroupStats.totals.assertions.failed ) + .writeAttribute( "expectedFailures", testGroupStats.totals.assertions.failedButOk ); + m_xml.endElement(); + } + + void XmlReporter::testRunEnded( TestRunStats const& testRunStats ) { + StreamingReporterBase::testRunEnded( testRunStats ); + m_xml.scopedElement( "OverallResults" ) + .writeAttribute( "successes", testRunStats.totals.assertions.passed ) + .writeAttribute( "failures", testRunStats.totals.assertions.failed ) + .writeAttribute( "expectedFailures", testRunStats.totals.assertions.failedButOk ); + m_xml.endElement(); + } + + CATCH_REGISTER_REPORTER( "xml", XmlReporter ) + +} // end namespace Catch + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +// end catch_reporter_xml.cpp + +namespace Catch { + LeakDetector leakDetector; +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// end catch_impl.hpp +#endif + +#ifdef CATCH_CONFIG_MAIN +// start catch_default_main.hpp + +#ifndef __OBJC__ + +#if defined(CATCH_CONFIG_WCHAR) && defined(WIN32) && defined(_UNICODE) && !defined(DO_NOT_USE_WMAIN) +// Standard C/C++ Win32 Unicode wmain entry point +extern "C" int wmain (int argc, wchar_t * argv[], wchar_t * []) { +#else +// Standard C/C++ main entry point +int main (int argc, char * argv[]) { +#endif + + return Catch::Session().run( argc, argv ); +} + +#else // __OBJC__ + +// Objective-C entry point +int main (int argc, char * const argv[]) { +#if !CATCH_ARC_ENABLED + NSAutoreleasePool * pool = [[NSAutoreleasePool alloc] init]; +#endif + + Catch::registerTestMethods(); + int result = Catch::Session().run( argc, (char**)argv ); + +#if !CATCH_ARC_ENABLED + [pool drain]; +#endif + + return result; +} + +#endif // __OBJC__ + +// end catch_default_main.hpp +#endif + +#if !defined(CATCH_CONFIG_IMPL_ONLY) + +#ifdef CLARA_CONFIG_MAIN_NOT_DEFINED +# undef CLARA_CONFIG_MAIN +#endif + +#if !defined(CATCH_CONFIG_DISABLE) +////// +// If this config identifier is defined then all CATCH macros are prefixed with CATCH_ +#ifdef CATCH_CONFIG_PREFIX_ALL + +#define CATCH_REQUIRE( ... ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE", Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define CATCH_REQUIRE_FALSE( ... ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, __VA_ARGS__ ) + +#define CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define CATCH_REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CATCH_REQUIRE_THROWS_AS", exceptionType, Catch::ResultDisposition::Normal, expr ) +#define CATCH_REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "CATCH_REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr ) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "CATCH_REQUIRE_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::Normal, matcher, expr ) +#endif// CATCH_CONFIG_DISABLE_MATCHERS +#define CATCH_REQUIRE_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "CATCH_REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, __VA_ARGS__ ) + +#define CATCH_CHECK( ... ) INTERNAL_CATCH_TEST( "CATCH_CHECK", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_CHECK_FALSE( ... ) INTERNAL_CATCH_TEST( "CATCH_CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, __VA_ARGS__ ) +#define CATCH_CHECKED_IF( ... ) INTERNAL_CATCH_IF( "CATCH_CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_CHECKED_ELSE( ... ) INTERNAL_CATCH_ELSE( "CATCH_CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_CHECK_NOFAIL( ... ) INTERNAL_CATCH_TEST( "CATCH_CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, __VA_ARGS__ ) + +#define CATCH_CHECK_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CATCH_CHECK_THROWS_AS", exceptionType, Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CATCH_CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "CATCH_CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "CATCH_CHECK_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define CATCH_CHECK_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "CATCH_CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CATCH_CHECK_THAT", matcher, Catch::ResultDisposition::ContinueOnFailure, arg ) + +#define CATCH_REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CATCH_REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +#define CATCH_INFO( msg ) INTERNAL_CATCH_INFO( "CATCH_INFO", msg ) +#define CATCH_WARN( msg ) INTERNAL_CATCH_MSG( "CATCH_WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg ) +#define CATCH_CAPTURE( ... ) INTERNAL_CATCH_CAPTURE( INTERNAL_CATCH_UNIQUE_NAME(capturer), "CATCH_CAPTURE",__VA_ARGS__ ) + +#define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ ) +#define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#define CATCH_METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ ) +#define CATCH_REGISTER_TEST_CASE( Function, ... ) INTERNAL_CATCH_REGISTER_TESTCASE( Function, __VA_ARGS__ ) +#define CATCH_SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ ) +#define CATCH_DYNAMIC_SECTION( ... ) INTERNAL_CATCH_DYNAMIC_SECTION( __VA_ARGS__ ) +#define CATCH_FAIL( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define CATCH_FAIL_CHECK( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_SUCCEED( ... ) INTERNAL_CATCH_MSG( "CATCH_SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) + +#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#else +#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) ) +#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) ) +#endif + +#if !defined(CATCH_CONFIG_RUNTIME_STATIC_REQUIRE) +#define CATCH_STATIC_REQUIRE( ... ) static_assert( __VA_ARGS__ , #__VA_ARGS__ ); CATCH_SUCCEED( #__VA_ARGS__ ) +#define CATCH_STATIC_REQUIRE_FALSE( ... ) static_assert( !(__VA_ARGS__), "!(" #__VA_ARGS__ ")" ); CATCH_SUCCEED( #__VA_ARGS__ ) +#else +#define CATCH_STATIC_REQUIRE( ... ) CATCH_REQUIRE( __VA_ARGS__ ) +#define CATCH_STATIC_REQUIRE_FALSE( ... ) CATCH_REQUIRE_FALSE( __VA_ARGS__ ) +#endif + +// "BDD-style" convenience wrappers +#define CATCH_SCENARIO( ... ) CATCH_TEST_CASE( "Scenario: " __VA_ARGS__ ) +#define CATCH_SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " __VA_ARGS__ ) +#define CATCH_GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Given: " << desc ) +#define CATCH_AND_GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( "And given: " << desc ) +#define CATCH_WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " When: " << desc ) +#define CATCH_AND_WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And when: " << desc ) +#define CATCH_THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Then: " << desc ) +#define CATCH_AND_THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And: " << desc ) + +// If CATCH_CONFIG_PREFIX_ALL is not defined then the CATCH_ prefix is not required +#else + +#define REQUIRE( ... ) INTERNAL_CATCH_TEST( "REQUIRE", Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define REQUIRE_FALSE( ... ) INTERNAL_CATCH_TEST( "REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, __VA_ARGS__ ) + +#define REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "REQUIRE_THROWS_AS", exceptionType, Catch::ResultDisposition::Normal, expr ) +#define REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr ) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "REQUIRE_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::Normal, matcher, expr ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define REQUIRE_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, __VA_ARGS__ ) + +#define CHECK( ... ) INTERNAL_CATCH_TEST( "CHECK", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CHECK_FALSE( ... ) INTERNAL_CATCH_TEST( "CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, __VA_ARGS__ ) +#define CHECKED_IF( ... ) INTERNAL_CATCH_IF( "CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CHECKED_ELSE( ... ) INTERNAL_CATCH_ELSE( "CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CHECK_NOFAIL( ... ) INTERNAL_CATCH_TEST( "CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, __VA_ARGS__ ) + +#define CHECK_THROWS( ... ) INTERNAL_CATCH_THROWS( "CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CHECK_THROWS_AS", exceptionType, Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "CHECK_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define CHECK_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CHECK_THAT", matcher, Catch::ResultDisposition::ContinueOnFailure, arg ) + +#define REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +#define INFO( msg ) INTERNAL_CATCH_INFO( "INFO", msg ) +#define UNSCOPED_INFO( msg ) INTERNAL_CATCH_UNSCOPED_INFO( "UNSCOPED_INFO", msg ) +#define WARN( msg ) INTERNAL_CATCH_MSG( "WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg ) +#define CAPTURE( ... ) INTERNAL_CATCH_CAPTURE( INTERNAL_CATCH_UNIQUE_NAME(capturer), "CAPTURE",__VA_ARGS__ ) + +#define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ ) +#define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#define METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ ) +#define REGISTER_TEST_CASE( Function, ... ) INTERNAL_CATCH_REGISTER_TESTCASE( Function, __VA_ARGS__ ) +#define SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ ) +#define DYNAMIC_SECTION( ... ) INTERNAL_CATCH_DYNAMIC_SECTION( __VA_ARGS__ ) +#define FAIL( ... ) INTERNAL_CATCH_MSG( "FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define FAIL_CHECK( ... ) INTERNAL_CATCH_MSG( "FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define SUCCEED( ... ) INTERNAL_CATCH_MSG( "SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#define TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) +#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#else +#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) ) +#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) ) +#define TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) ) +#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) ) +#endif + +#if !defined(CATCH_CONFIG_RUNTIME_STATIC_REQUIRE) +#define STATIC_REQUIRE( ... ) static_assert( __VA_ARGS__, #__VA_ARGS__ ); SUCCEED( #__VA_ARGS__ ) +#define STATIC_REQUIRE_FALSE( ... ) static_assert( !(__VA_ARGS__), "!(" #__VA_ARGS__ ")" ); SUCCEED( "!(" #__VA_ARGS__ ")" ) +#else +#define STATIC_REQUIRE( ... ) REQUIRE( __VA_ARGS__ ) +#define STATIC_REQUIRE_FALSE( ... ) REQUIRE_FALSE( __VA_ARGS__ ) +#endif + +#endif + +#define CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature ) + +// "BDD-style" convenience wrappers +#define SCENARIO( ... ) TEST_CASE( "Scenario: " __VA_ARGS__ ) +#define SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " __VA_ARGS__ ) + +#define GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Given: " << desc ) +#define AND_GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( "And given: " << desc ) +#define WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " When: " << desc ) +#define AND_WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And when: " << desc ) +#define THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Then: " << desc ) +#define AND_THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And: " << desc ) + +using Catch::Detail::Approx; + +#else // CATCH_CONFIG_DISABLE + +////// +// If this config identifier is defined then all CATCH macros are prefixed with CATCH_ +#ifdef CATCH_CONFIG_PREFIX_ALL + +#define CATCH_REQUIRE( ... ) (void)(0) +#define CATCH_REQUIRE_FALSE( ... ) (void)(0) + +#define CATCH_REQUIRE_THROWS( ... ) (void)(0) +#define CATCH_REQUIRE_THROWS_AS( expr, exceptionType ) (void)(0) +#define CATCH_REQUIRE_THROWS_WITH( expr, matcher ) (void)(0) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0) +#endif// CATCH_CONFIG_DISABLE_MATCHERS +#define CATCH_REQUIRE_NOTHROW( ... ) (void)(0) + +#define CATCH_CHECK( ... ) (void)(0) +#define CATCH_CHECK_FALSE( ... ) (void)(0) +#define CATCH_CHECKED_IF( ... ) if (__VA_ARGS__) +#define CATCH_CHECKED_ELSE( ... ) if (!(__VA_ARGS__)) +#define CATCH_CHECK_NOFAIL( ... ) (void)(0) + +#define CATCH_CHECK_THROWS( ... ) (void)(0) +#define CATCH_CHECK_THROWS_AS( expr, exceptionType ) (void)(0) +#define CATCH_CHECK_THROWS_WITH( expr, matcher ) (void)(0) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define CATCH_CHECK_NOTHROW( ... ) (void)(0) + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_CHECK_THAT( arg, matcher ) (void)(0) + +#define CATCH_REQUIRE_THAT( arg, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +#define CATCH_INFO( msg ) (void)(0) +#define CATCH_WARN( msg ) (void)(0) +#define CATCH_CAPTURE( msg ) (void)(0) + +#define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define CATCH_METHOD_AS_TEST_CASE( method, ... ) +#define CATCH_REGISTER_TEST_CASE( Function, ... ) (void)(0) +#define CATCH_SECTION( ... ) +#define CATCH_DYNAMIC_SECTION( ... ) +#define CATCH_FAIL( ... ) (void)(0) +#define CATCH_FAIL_CHECK( ... ) (void)(0) +#define CATCH_SUCCEED( ... ) (void)(0) + +#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) ) +#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), className ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#else +#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) ) ) +#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), className ) ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#endif + +// "BDD-style" convenience wrappers +#define CATCH_SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define CATCH_SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), className ) +#define CATCH_GIVEN( desc ) +#define CATCH_AND_GIVEN( desc ) +#define CATCH_WHEN( desc ) +#define CATCH_AND_WHEN( desc ) +#define CATCH_THEN( desc ) +#define CATCH_AND_THEN( desc ) + +#define CATCH_STATIC_REQUIRE( ... ) (void)(0) +#define CATCH_STATIC_REQUIRE_FALSE( ... ) (void)(0) + +// If CATCH_CONFIG_PREFIX_ALL is not defined then the CATCH_ prefix is not required +#else + +#define REQUIRE( ... ) (void)(0) +#define REQUIRE_FALSE( ... ) (void)(0) + +#define REQUIRE_THROWS( ... ) (void)(0) +#define REQUIRE_THROWS_AS( expr, exceptionType ) (void)(0) +#define REQUIRE_THROWS_WITH( expr, matcher ) (void)(0) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define REQUIRE_NOTHROW( ... ) (void)(0) + +#define CHECK( ... ) (void)(0) +#define CHECK_FALSE( ... ) (void)(0) +#define CHECKED_IF( ... ) if (__VA_ARGS__) +#define CHECKED_ELSE( ... ) if (!(__VA_ARGS__)) +#define CHECK_NOFAIL( ... ) (void)(0) + +#define CHECK_THROWS( ... ) (void)(0) +#define CHECK_THROWS_AS( expr, exceptionType ) (void)(0) +#define CHECK_THROWS_WITH( expr, matcher ) (void)(0) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define CHECK_NOTHROW( ... ) (void)(0) + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CHECK_THAT( arg, matcher ) (void)(0) + +#define REQUIRE_THAT( arg, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +#define INFO( msg ) (void)(0) +#define WARN( msg ) (void)(0) +#define CAPTURE( msg ) (void)(0) + +#define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define METHOD_AS_TEST_CASE( method, ... ) +#define REGISTER_TEST_CASE( Function, ... ) (void)(0) +#define SECTION( ... ) +#define DYNAMIC_SECTION( ... ) +#define FAIL( ... ) (void)(0) +#define FAIL_CHECK( ... ) (void)(0) +#define SUCCEED( ... ) (void)(0) +#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) ) +#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), className ) +#define TEMPLATE_PRODUCT_TEST_CASE( ... ) TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#else +#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) ) ) +#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), className ) ) +#define TEMPLATE_PRODUCT_TEST_CASE( ... ) TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#endif + +#define STATIC_REQUIRE( ... ) (void)(0) +#define STATIC_REQUIRE_FALSE( ... ) (void)(0) + +#endif + +#define CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION_NO_REG( INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator ), signature ) + +// "BDD-style" convenience wrappers +#define SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ) ) +#define SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), className ) + +#define GIVEN( desc ) +#define AND_GIVEN( desc ) +#define WHEN( desc ) +#define AND_WHEN( desc ) +#define THEN( desc ) +#define AND_THEN( desc ) + +using Catch::Detail::Approx; + +#endif + +#endif // ! CATCH_CONFIG_IMPL_ONLY + +// start catch_reenable_warnings.h + + +#ifdef __clang__ +# ifdef __ICC // icpc defines the __clang__ macro +# pragma warning(pop) +# else +# pragma clang diagnostic pop +# endif +#elif defined __GNUC__ +# pragma GCC diagnostic pop +#endif + +// end catch_reenable_warnings.h +// end catch.hpp +#endif // TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED + diff --git a/third_party/intgemm/test/add127_test.cc b/third_party/intgemm/test/add127_test.cc new file mode 100644 index 0000000000..c31732c56e --- /dev/null +++ b/third_party/intgemm/test/add127_test.cc @@ -0,0 +1,492 @@ +#include "test.h" + +namespace intgemm { +namespace { + +void CompareAs(int8_t * output_old, uint8_t * output_new, Index rows, Index cols) { + for (Index r = 0; r<rows; r++) { + for (Index c = 0; c<cols; c++) { + int a = int(output_old[rows*c + r]); + int b = int(output_new[rows*c + r]); + INFO("Inaccurate at row: " << r << " column " << c << ' ' + << a << ' ' << b); + CHECK(a+127 == b); + } + } +} + +template <class Routine> void TestPrepareA(Index rows, Index cols) { + std::mt19937 gen; + // Go somewhat out of range too. + std::uniform_real_distribution<float> dist(-2, 2); + // Create array. + AlignedVector<float> inputA(rows * cols); + for (auto& it : inputA) { + it = dist(gen); + } + AlignedVector<int8_t> oldA(rows * cols); + AlignedVector<uint8_t> newA(rows * cols); + float quant_mult = 64; //From example + Routine::PrepareA(inputA.begin(), oldA.begin(), quant_mult, rows, cols); + Routine::PrepareA(inputA.begin(), newA.begin(), quant_mult, rows, cols); + CompareAs(oldA.begin(), newA.begin(), rows, cols); +} + +template <class Routine> void TestPrepareBias(Index rows, Index cols) { + std::mt19937 gen; + // Go somewhat out of range too. + std::uniform_real_distribution<float> dist(-30.0, 30.0); + // Create array. + AlignedVector<float> inputB(rows * cols); + for (auto& it : inputB) { + it = dist(gen); + } + + float alpha = 25; + float quant_mult = 127/alpha; + + AlignedVector<int8_t> B_prep(inputB.size()); + AlignedVector<int8_t> B_quant(inputB.size()); + Routine::PrepareB(inputB.begin(), B_prep.begin(), quant_mult, rows, cols); + Routine::Quantize(inputB.begin(), B_quant.begin(), quant_mult, static_cast<intgemm::Index>(inputB.size())); + + + AlignedVector<float> inputBias(cols); + AlignedVector<float> goldBias(cols); + + for (auto& it : goldBias) { + it = dist(gen); + } + int i = 0; + for (auto& it : inputBias) { + it = goldBias[i]; + i++; + } + + float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); + + Routine::PrepareBias(B_prep.begin(), rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin())); + + int A_rows = 1; + AlignedVector<int8_t> A_prep2(A_rows*rows); + for (auto& it : A_prep2) { + it =1; + } + //Routine::Multiply(A_prep2.begin(), B_prep.begin(), A_rows, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, goldBias.begin(), goldBias.begin())); + //CompareEps(goldBias.begin(), inputBias.begin(), cols, 0.0001f); + AlignedVector<float> slowint_C(cols); + references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, rows, cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult_forprep + goldBias[info.col_idx]; + }); + CompareEps(slowint_C.begin(), inputBias.begin(), cols, 0.0001f); +} + +template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Index B_cols, + float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) { + std::ostringstream info; + info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n'; + + // Initialize A and B. + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = dist(gen); + } + + float alpha = 2.0f; + float quant_mult = 127.0f / alpha; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<uint8_t> A_prep(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + /*REFERENCE MULTIPLICATION + * + * + */ + AlignedVector<int8_t> B_quant(B.size()); + Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size())); + AlignedVector<float> slowint_C(test_C.size()); + // Taking the original A_preparation which means A would be int8_t + AlignedVector<int8_t> A_prep2(A.size()); + Routine::PrepareA(A.begin(), A_prep2.begin(), quant_mult, A_rows, width); + references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult + bias[info.col_idx]; + }); + + AlignedVector<float> float_C(test_C.size()); + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) { + return static_cast<float>(sum) + bias[info.col_idx]; + }); + + /*ACTUAL MULTIPLICATION + * + */ + float unquant_mult_forprep = (-1.0f)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on + Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin())); + //Routine::PrepareBias(B.begin(), bias.begin(), alpha, width, B_cols); + Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); +} + +template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index width, Index B_cols, + float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) { + std::ostringstream info; + info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n'; + + // Initialize A and B. + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = 0; + } + + float alpha = 2.0f; + float quant_mult = 127.0f / alpha; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<uint8_t> A_prep(A.size()); + AlignedVector<int8_t> A_prep_old(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); //Non shited version + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + /* + * Reference non shift multiplication instead of slowint + */ + AlignedVector<float> slowint_C(test_C.size()); + Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin())); + + AlignedVector<float> float_C(test_C.size()); + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) { + return static_cast<float>(sum) + bias[info.col_idx]; + }); + + /* + * Multiply8 shift multiplication + */ + float unquant_mult_forprep = (-1.0f)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on + Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin())); + Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); +} + +template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, Index B_cols, + float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) { + std::ostringstream info; + info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n'; + + // Initialize A and B. + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = 0; + } + + float alpha = 2.0f; + float quant_mult = 127.0f / alpha; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<uint8_t> A_prep(A.size()); + AlignedVector<int8_t> A_prep_old(A.size()); + AlignedVector<int8_t> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); //Non shited version + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + /* + * Reference float multiplication + */ + AlignedVector<int8_t> B_quant(B.size()); + Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size())); + AlignedVector<float> slowint_C(test_C.size()); + // Taking the original A_preparation which means A would be int8_t + // references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + // return sum * unquant_mult + bias[info.col_idx]; + // }); + + AlignedVector<float> float_C(test_C.size()); + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) { + return static_cast<float>(sum) + bias[info.col_idx]; + }); + /* + * Multiply8 shift multiplication + */ + //First prepare SlowInteger Bias: + AlignedVector<int8_t> A_prep2(1*width); + for (auto& it : A_prep2) { + it = 1; + } + AlignedVector<float> ShiftedBias(B_cols); + float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on + references::Multiply(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), 1, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult_forprep + bias[info.col_idx]; + }); + + + //Now prepare Fast integer Bias + Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin())); + Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + + // Reference INT VERSION HERE with ADD127 + // Taking the original A_preparation which means A would be int8_t + references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult + ShiftedBias[info.col_idx]; + }); + + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); +} + + +// Bias +TEST_CASE("PrepareBias SSSE3", "[Add127]") { + if (kCPU < CPUType::SSSE3) return; + TestPrepareBias<SSSE3::Kernels8>(256,256); + TestPrepareBias<SSSE3::Kernels8>(2048,256); + TestPrepareBias<SSSE3::Kernels8>(512,512); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE("PrepareBias AVX2", "[Add127]") { + if (kCPU < CPUType::AVX2) return; + TestPrepareBias<AVX2::Kernels8>(256,256); + TestPrepareBias<AVX2::Kernels8>(2048,256); + TestPrepareBias<AVX2::Kernels8>(512,512); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("PrepareBias AVX512F", "[Add127]") { + if (kCPU < CPUType::AVX512BW) return; + TestPrepareBias<AVX512BW::Kernels8>(256,256); + TestPrepareBias<AVX512BW::Kernels8>(2048,256); + TestPrepareBias<AVX512BW::Kernels8>(512,512); +} +#endif + +//A +TEST_CASE("PrepareA SSSE3", "[Add127]") { + if (kCPU < CPUType::SSSE3) return; + TestPrepareA<SSSE3::Kernels8>(64,64); + TestPrepareA<SSSE3::Kernels8>(256,256); + TestPrepareA<SSSE3::Kernels8>(512,512); + TestPrepareA<SSSE3::Kernels8>(2048,256); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE("PrepareA AVX2", "[Add127]") { + if (kCPU < CPUType::AVX2) return; + TestPrepareA<AVX2::Kernels8>(64,64); + TestPrepareA<AVX2::Kernels8>(256,256); + TestPrepareA<AVX2::Kernels8>(512,512); + TestPrepareA<AVX2::Kernels8>(2048,256); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("PrepareA AVX512F", "[Add127]") { + if (kCPU < CPUType::AVX512BW) return; + TestPrepareA<AVX512BW::Kernels8>(64,64); + TestPrepareA<AVX512BW::Kernels8>(256,256); + TestPrepareA<AVX512BW::Kernels8>(512,512); + TestPrepareA<AVX512BW::Kernels8>(2048,256); +} +#endif + +// Multiply + +TEST_CASE ("Multiply SSSE3 8bit Shift with bias", "[Add127]") { + if (kCPU < CPUType::SSSE3) return; + TestMultiplyBiasNew<SSSE3::Kernels8>(1, 64, 8, 0.11f, 0.1f, 0.06f, 0.05f); + TestMultiplyBiasNew<SSSE3::Kernels8>(8, 256, 256, 0.45f, 0.54f, 0.17f, 0.16f); + TestMultiplyBiasNew<SSSE3::Kernels8>(8, 2048, 256, 1.7f, 1.7f, 0.46f, 0.43f); + TestMultiplyBiasNew<SSSE3::Kernels8>(320, 256, 256, 0.56f, 0.64f, 0.16f, 0.15f); + TestMultiplyBiasNew<SSSE3::Kernels8>(472, 256, 256, 0.46f, 0.62f, 0.17f, 0.16f); + TestMultiplyBiasNew<SSSE3::Kernels8>(248, 256, 256, 0.48f, 0.64f, 0.16f, 0.15f); + TestMultiplyBiasNew<SSSE3::Kernels8>(200, 256, 256, 0.55f, 0.74f, 0.17f, 0.16f); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE ("Multiply AVX2 8bit Shift with bias", "[Add127]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBiasNew<AVX2::Kernels8>(1, 64, 8, 0.11f, 0.11f, 0.06f, 0.05f); + TestMultiplyBiasNew<AVX2::Kernels8>(8, 256, 256, 0.49f, 0.54f, 0.17f, 0.16f); + TestMultiplyBiasNew<AVX2::Kernels8>(8, 2048, 256, 1.57f, 1.66f, 0.46f, 0.46f); + TestMultiplyBiasNew<AVX2::Kernels8>(320, 256, 256, 0.49f, 0.64f, 0.16f, 0.15f); + TestMultiplyBiasNew<AVX2::Kernels8>(472, 256, 256, 0.46f, 0.62f, 0.17f, 0.16f); + TestMultiplyBiasNew<AVX2::Kernels8>(248, 256, 256, 0.48f, 0.64f, 0.16f, 0.15f); + TestMultiplyBiasNew<AVX2::Kernels8>(200, 256, 256, 0.55f, 0.74f, 0.17f, 0.16f); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE ("Multiply AVX512F 8bit Shift with bias", "[Add127]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyBiasNew<AVX512BW::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f); + TestMultiplyBiasNew<AVX512BW::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f); + TestMultiplyBiasNew<AVX512BW::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.001f); + TestMultiplyBiasNew<AVX512BW::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f); + TestMultiplyBiasNew<AVX512BW::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f); + TestMultiplyBiasNew<AVX512BW::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f); + TestMultiplyBiasNew<AVX512BW::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI + TEST_CASE ("Multiply AVX512VNNI 8bit Shift with bias", "[Add127]") { + if (kCPU < CPUType::AVX512VNNI) return; + TestMultiplyBiasNew<AVX512VNNI::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f); + TestMultiplyBiasNew<AVX512VNNI::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f); + TestMultiplyBiasNew<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.001f); + TestMultiplyBiasNew<AVX512VNNI::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f); + TestMultiplyBiasNew<AVX512VNNI::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f); + TestMultiplyBiasNew<AVX512VNNI::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f); + TestMultiplyBiasNew<AVX512VNNI::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f); + } +#endif + +//Multiply old vs new +TEST_CASE ("Multiply SSSE3 8bit Shift vs nonshift", "[Add127]") { + if (kCPU < CPUType::SSSE3) return; + TestMultiplyShiftNonShift<SSSE3::Kernels8>(1, 64, 8, 0.00001f, 0.1f, 0.06f, 0.00001f); + TestMultiplyShiftNonShift<SSSE3::Kernels8>(8, 256, 256, 0.00001f, 0.54f, 0.17f, 0.00001f); + TestMultiplyShiftNonShift<SSSE3::Kernels8>(8, 2048, 256, 17.9f, 1.7f, 0.46f, 4.2f); //Big difference here because the non-shift version is very bad + TestMultiplyShiftNonShift<SSSE3::Kernels8>(320, 256, 256, 1.2f, 0.64f, 0.16f, 0.006f); + TestMultiplyShiftNonShift<SSSE3::Kernels8>(472, 256, 256, 1.1f, 0.62f, 0.17f, 0.006f); + TestMultiplyShiftNonShift<SSSE3::Kernels8>(248, 256, 256, 0.9f, 0.64f, 0.16f, 0.007f); + TestMultiplyShiftNonShift<SSSE3::Kernels8>(200, 256, 256, 1, 0.74f, 0.17f, 0.006f); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE ("Multiply AVX2 8bit Shift vs nonshift", "[Add127]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyShiftNonShift<AVX2::Kernels8>(1, 64, 8, 0.00001f, 0.11f, 0.06f, 0.00001f); + TestMultiplyShiftNonShift<AVX2::Kernels8>(8, 256, 256, 0.00001f, 0.54f, 0.17f, 0.00001f); + TestMultiplyShiftNonShift<AVX2::Kernels8>(8, 2048, 256, 9.4f, 1.66f, 0.46f, 1.67f); //Big difference here because the non-shift version is very bad + TestMultiplyShiftNonShift<AVX2::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f); + TestMultiplyShiftNonShift<AVX2::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f); + TestMultiplyShiftNonShift<AVX2::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f); + TestMultiplyShiftNonShift<AVX2::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE ("Multiply AVX512F 8bit Shift vs nonshift", "[Add127]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyShiftNonShift<AVX512BW::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f); + TestMultiplyShiftNonShift<AVX512BW::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f); + TestMultiplyShiftNonShift<AVX512BW::Kernels8>(8, 2048, 256, 3.51f, 0.61f, 0.17f, 0.3f); + TestMultiplyShiftNonShift<AVX512BW::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f); + TestMultiplyShiftNonShift<AVX512BW::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f); + TestMultiplyShiftNonShift<AVX512BW::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f); + TestMultiplyShiftNonShift<AVX512BW::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI + TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs nonshift", "[Add127]") { + if (kCPU < CPUType::AVX512VNNI) return; + TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(1, 64, 8, 0.00001f, 0.05f, 0.03f, 0.00001f); + TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(8, 256, 256, 0.00001f, 0.22f, 0.06f, 0.00001f); + TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f); + TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(320, 256, 256, 0.00001f, 0.27f, 0.06f, 0.00001f); + TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(472, 256, 256, 0.00001f, 0.33f, 0.06f, 0.00001f); + TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(248, 256, 256, 0.00001f, 0.27f, 0.06f, 0.00001f); + TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(200, 256, 256, 0.00001f, 0.28f, 0.06f, 0.00001f); + } +#endif + +//Multiply Shift vs int shift implementation +TEST_CASE ("Multiply SSSE3 8bit Shift vs Int", "[Add127]") { + if (kCPU < CPUType::SSSE3) return; + TestMultiplyShiftInt<SSSE3::Kernels8>(1, 64, 8, 0.0001f, 0.1f, 0.06f, 0.0001f); + TestMultiplyShiftInt<SSSE3::Kernels8>(8, 256, 256, 0.0001f, 0.54f, 0.17f, 0.0001f); + TestMultiplyShiftInt<SSSE3::Kernels8>(8, 2048, 256, 0.0001f, 1.7f, 0.46f, 0.0001f); + TestMultiplyShiftInt<SSSE3::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f); + TestMultiplyShiftInt<SSSE3::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f); + TestMultiplyShiftInt<SSSE3::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f); + TestMultiplyShiftInt<SSSE3::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE ("Multiply AVX2 8bit Shift vs Int", "[Add127]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyShiftInt<AVX2::Kernels8>(1, 64, 8, 0.0001f, 0.11f, 0.06f, 0.0001f); + TestMultiplyShiftInt<AVX2::Kernels8>(8, 256, 256, 0.0001f, 0.54f, 0.17f, 0.0001f); + TestMultiplyShiftInt<AVX2::Kernels8>(8, 2048, 256, 0.0001f, 1.66f, 0.46f, 0.0001f); + TestMultiplyShiftInt<AVX2::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f); + TestMultiplyShiftInt<AVX2::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f); + TestMultiplyShiftInt<AVX2::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f); + TestMultiplyShiftInt<AVX2::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE ("Multiply AVX512F 8bit Shift vs Int", "[Add127]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyShiftInt<AVX512BW::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.0001f); + TestMultiplyShiftInt<AVX512BW::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.0001f); + TestMultiplyShiftInt<AVX512BW::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f); + TestMultiplyShiftInt<AVX512BW::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f); + TestMultiplyShiftInt<AVX512BW::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.0001f); + TestMultiplyShiftInt<AVX512BW::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f); + TestMultiplyShiftInt<AVX512BW::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.0001f); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI +TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs Int", "[Add127]") { + if (kCPU < CPUType::AVX512VNNI) return; + TestMultiplyShiftInt<AVX512VNNI::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.0001f); + TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.0001f); + TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f); + TestMultiplyShiftInt<AVX512VNNI::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f); + TestMultiplyShiftInt<AVX512VNNI::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.0001f); + TestMultiplyShiftInt<AVX512VNNI::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f); + TestMultiplyShiftInt<AVX512VNNI::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.0001f); +} +#endif + +} // namespace +} // namespace intgemm diff --git a/third_party/intgemm/test/kernels/add_bias_test.cc b/third_party/intgemm/test/kernels/add_bias_test.cc new file mode 100644 index 0000000000..b9e5fd95f5 --- /dev/null +++ b/third_party/intgemm/test/kernels/add_bias_test.cc @@ -0,0 +1,66 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_, typename ElemType_> +void kernel_add_bias_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, ElemType_>; + constexpr static auto VECTOR_LENGTH = sizeof(vec_t) / sizeof(ElemType_); + + AlignedVector<ElemType_> input(VECTOR_LENGTH); + AlignedVector<ElemType_> bias(VECTOR_LENGTH); + AlignedVector<ElemType_> output(VECTOR_LENGTH); + + std::iota(input.begin(), input.end(), static_cast<ElemType_>(0)); + std::fill(bias.begin(), bias.end(), static_cast<ElemType_>(100)); + + *output.template as<vec_t>() = kernels::add_bias(*input.template as<vec_t>(), bias.begin(), 0); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == ElemType_(100 + i)); +} + +template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, int8_t>(); +template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, int16_t>(); +template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, int>(); +template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, float>(); +template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, double>(); +KERNEL_TEST_CASE("add_bias/int8 SSE2") { return kernel_add_bias_test<CPUType::SSE2, int8_t>(); } +KERNEL_TEST_CASE("add_bias/int16 SSE2") { return kernel_add_bias_test<CPUType::SSE2, int16_t>(); } +KERNEL_TEST_CASE("add_bias/int SSE2") { return kernel_add_bias_test<CPUType::SSE2, int>(); } +KERNEL_TEST_CASE("add_bias/float SSE2") { return kernel_add_bias_test<CPUType::SSE2, float>(); } +KERNEL_TEST_CASE("add_bias/double SSE2") { return kernel_add_bias_test<CPUType::SSE2, double>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int8_t>(); +template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int16_t>(); +template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int>(); +template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, float>(); +template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, double>(); +KERNEL_TEST_CASE("add_bias/int8 AVX2") { return kernel_add_bias_test<CPUType::AVX2, int8_t>(); } +KERNEL_TEST_CASE("add_bias/int16 AVX2") { return kernel_add_bias_test<CPUType::AVX2, int16_t>(); } +KERNEL_TEST_CASE("add_bias/int AVX2") { return kernel_add_bias_test<CPUType::AVX2, int>(); } +KERNEL_TEST_CASE("add_bias/float AVX2") { return kernel_add_bias_test<CPUType::AVX2, float>(); } +KERNEL_TEST_CASE("add_bias/double AVX2") { return kernel_add_bias_test<CPUType::AVX2, double>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int8_t>(); +template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int16_t>(); +template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int>(); +template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, float>(); +template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, double>(); +KERNEL_TEST_CASE("add_bias/int8 AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, int8_t>(); } +KERNEL_TEST_CASE("add_bias/int16 AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, int16_t>(); } +KERNEL_TEST_CASE("add_bias/int AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, int>(); } +KERNEL_TEST_CASE("add_bias/float AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, float>(); } +KERNEL_TEST_CASE("add_bias/double AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, double>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/bitwise_not_test.cc b/third_party/intgemm/test/kernels/bitwise_not_test.cc new file mode 100644 index 0000000000..6c28c9554d --- /dev/null +++ b/third_party/intgemm/test/kernels/bitwise_not_test.cc @@ -0,0 +1,41 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstdlib> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_> +void kernel_bitwise_not_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, int>; + constexpr static std::size_t VECTOR_LENGTH = sizeof(vec_t) / sizeof(int); + + AlignedVector<int> input(VECTOR_LENGTH); + AlignedVector<int> output(VECTOR_LENGTH); + + std::iota(input.begin(), input.end(), 0); + + *output.template as<vec_t>() = kernels::bitwise_not(*input.template as<vec_t>()); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == ~input[i]); +} + +template INTGEMM_SSE2 void kernel_bitwise_not_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("bitwise_not SSE2") { return kernel_bitwise_not_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_bitwise_not_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("bitwise_not AVX2") { return kernel_bitwise_not_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_bitwise_not_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("bitwise_not AVX512BW") { return kernel_bitwise_not_test<CPUType::AVX512BW>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/downcast_test.cc b/third_party/intgemm/test/kernels/downcast_test.cc new file mode 100644 index 0000000000..0f2ccd0edc --- /dev/null +++ b/third_party/intgemm/test/kernels/downcast_test.cc @@ -0,0 +1,107 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstddef> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_> +void kernel_downcast32to8_test() { + if (kCPU < CPUType_) + return; + + using vi = vector_t<CPUType_, int>; + constexpr int LENGTH = sizeof(vi) / sizeof(int8_t); + + AlignedVector<int32_t> input(LENGTH); + AlignedVector<int8_t> output(LENGTH); + + std::iota(input.begin(), input.end(), static_cast<int32_t>(-LENGTH / 2)); + + *output.template as<vi>() = kernels::downcast32to8( + input.template as<vi>()[0], input.template as<vi>()[1], + input.template as<vi>()[2], input.template as<vi>()[3]); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == int8_t(input[i])); +} + +template INTGEMM_SSE2 void kernel_downcast32to8_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("downcast32to8 SSE2") { return kernel_downcast32to8_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_downcast32to8_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("downcast32to8 AVX2") { return kernel_downcast32to8_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_downcast32to8_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("downcast32to8 AVX512BW") { return kernel_downcast32to8_test<CPUType::AVX512BW>(); } +#endif + +template <CPUType CPUType_> +void kernel_downcast32to16_test() { + if (kCPU < CPUType_) + return; + + using vi = vector_t<CPUType_, int>; + constexpr int LENGTH = sizeof(vi) / sizeof(int16_t); + + AlignedVector<int32_t> input(LENGTH); + AlignedVector<int16_t> output(LENGTH); + + std::iota(input.begin(), input.end(), static_cast<int32_t>(-LENGTH / 2)); + + *output.template as<vi>() = kernels::downcast32to16( + input.template as<vi>()[0], input.template as<vi>()[1]); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == int16_t(input[i])); +} + +template INTGEMM_SSE2 void kernel_downcast32to16_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("downcast32to16 SSE2") { return kernel_downcast32to16_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_downcast32to16_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("downcast32to16 AVX2") { return kernel_downcast32to16_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_downcast32to16_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("downcast32to16 AVX512BW") { return kernel_downcast32to16_test<CPUType::AVX512BW>(); } +#endif + +template <CPUType CPUType_> +void kernel_downcast16to8_test() { + if (kCPU < CPUType_) + return; + + using vi = vector_t<CPUType_, int>; + constexpr int LENGTH = sizeof(vi) / sizeof(int8_t); + + AlignedVector<int16_t> input(LENGTH); + AlignedVector<int8_t> output(LENGTH); + + std::iota(input.begin(), input.end(), static_cast<int16_t>(-LENGTH / 2)); + + *output.template as<vi>() = kernels::downcast16to8( + input.template as<vi>()[0], input.template as<vi>()[1]); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == int8_t(input[i])); +} + +template INTGEMM_SSE2 void kernel_downcast16to8_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("downcast16to8 SSE2") { return kernel_downcast16to8_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_downcast16to8_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("downcast16to8 AVX2") { return kernel_downcast16to8_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_downcast16to8_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("downcast16to8 AVX512BW") { return kernel_downcast16to8_test<CPUType::AVX512BW>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/exp_test.cc b/third_party/intgemm/test/kernels/exp_test.cc new file mode 100644 index 0000000000..9f535f25b9 --- /dev/null +++ b/third_party/intgemm/test/kernels/exp_test.cc @@ -0,0 +1,38 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstddef> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_> +void kernel_exp_approx_taylor_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, float>; + constexpr static std::size_t VECTOR_LENGTH = sizeof(vec_t) / sizeof(float); + + AlignedVector<float> input(VECTOR_LENGTH); + AlignedVector<float> output(VECTOR_LENGTH); + + std::iota(input.begin(), input.end(), -static_cast<float>(VECTOR_LENGTH / 2)); + + *output.template as<vec_t>() = kernels::exp_approx_taylor(*input.template as<vec_t>()); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK_EPS(output[i], exp(input[i]), 0.001f); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_exp_approx_taylor_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("exp_approx_taylor AVX2") { return kernel_exp_approx_taylor_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_exp_approx_taylor_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("exp_approx_taylor AVX512BW") { return kernel_exp_approx_taylor_test<CPUType::AVX512BW>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/floor_test.cc b/third_party/intgemm/test/kernels/floor_test.cc new file mode 100644 index 0000000000..9b7a214134 --- /dev/null +++ b/third_party/intgemm/test/kernels/floor_test.cc @@ -0,0 +1,41 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstddef> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_> +void kernel_floor_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, float>; + constexpr static std::size_t VECTOR_LENGTH = sizeof(vec_t) / sizeof(float); + + AlignedVector<float> input(VECTOR_LENGTH); + AlignedVector<float> output(VECTOR_LENGTH); + + std::iota(input.begin(), input.end(), -static_cast<float>(VECTOR_LENGTH / 2)); + + *output.template as<vec_t>() = kernels::floor(*input.template as<vec_t>()); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == std::floor(input[i])); +} + +template INTGEMM_SSE2 void kernel_floor_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("floor SSE2") { return kernel_floor_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_floor_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("floor AVX2") { return kernel_floor_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_floor_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("floor AVX512BW") { return kernel_floor_test<CPUType::AVX512BW>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/multiply_test.cc b/third_party/intgemm/test/kernels/multiply_test.cc new file mode 100644 index 0000000000..fc1a51eeb1 --- /dev/null +++ b/third_party/intgemm/test/kernels/multiply_test.cc @@ -0,0 +1,67 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstdint> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_, typename Type_> +void kernel_multiply_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, Type_>; + constexpr int VECTOR_LENGTH = sizeof(vec_t) / sizeof(Type_); + + AlignedVector<Type_> input1(VECTOR_LENGTH); + AlignedVector<Type_> input2(VECTOR_LENGTH); + AlignedVector<Type_> output(VECTOR_LENGTH); + + std::iota(input1.begin(), input1.end(), static_cast<Type_>(-VECTOR_LENGTH / 2)); + std::iota(input2.begin(), input2.end(), static_cast<Type_>(-VECTOR_LENGTH / 3)); + + *output.template as<vec_t>() = kernels::multiply<Type_>(*input1.template as<vec_t>(), *input2.template as<vec_t>()); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == Type_(input1[i] * input2[i])); +} + +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int8_t>(); +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int16_t>(); +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int>(); +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, float>(); +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, double>(); +KERNEL_TEST_CASE("multiply/int8 SSE2") { return kernel_multiply_test<CPUType::SSE2, int8_t>(); } +KERNEL_TEST_CASE("multiply/int16 SSE2") { return kernel_multiply_test<CPUType::SSE2, int16_t>(); } +KERNEL_TEST_CASE("multiply/int SSE2") { return kernel_multiply_test<CPUType::SSE2, int>(); } +KERNEL_TEST_CASE("multiply/float SSE2") { return kernel_multiply_test<CPUType::SSE2, float>(); } +KERNEL_TEST_CASE("multiply/double SSE2") { return kernel_multiply_test<CPUType::SSE2, double>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int8_t>(); +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int16_t>(); +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int>(); +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, float>(); +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, double>(); +KERNEL_TEST_CASE("multiply/int8 AVX2") { return kernel_multiply_test<CPUType::AVX2, int8_t>(); } +KERNEL_TEST_CASE("multiply/int16 AVX2") { return kernel_multiply_test<CPUType::AVX2, int16_t>(); } +KERNEL_TEST_CASE("multiply/int AVX2") { return kernel_multiply_test<CPUType::AVX2, int>(); } +KERNEL_TEST_CASE("multiply/float AVX2") { return kernel_multiply_test<CPUType::AVX2, float>(); } +KERNEL_TEST_CASE("multiply/double AVX2") { return kernel_multiply_test<CPUType::AVX2, double>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int8_t>(); +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int16_t>(); +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int>(); +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, float>(); +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, double>(); +KERNEL_TEST_CASE("multiply/int8 AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int8_t>(); } +KERNEL_TEST_CASE("multiply/int16 AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int16_t>(); } +KERNEL_TEST_CASE("multiply/int AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int>(); } +KERNEL_TEST_CASE("multiply/float AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, float>(); } +KERNEL_TEST_CASE("multiply/double AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, double>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/quantize_test.cc b/third_party/intgemm/test/kernels/quantize_test.cc new file mode 100644 index 0000000000..93280f7e27 --- /dev/null +++ b/third_party/intgemm/test/kernels/quantize_test.cc @@ -0,0 +1,41 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_> +void kernel_quantize_test() { + if (kCPU < CPUType_) + return; + + using input_vec_t = vector_t<CPUType_, float>; + using output_vec_t = vector_t<CPUType_, int>; + + AlignedVector<float> input(sizeof(input_vec_t) / sizeof(float)); + AlignedVector<int> output(sizeof(output_vec_t) / sizeof(int)); + + std::iota(input.begin(), input.end(), 0.0f); + auto quant_mult = set1_ps<input_vec_t>(2.f); + + *output.template as<output_vec_t>() = kernels::quantize(*input.template as<input_vec_t>(), quant_mult); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == int(i*2.f)); +} + +template INTGEMM_SSE2 void kernel_quantize_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("quantize SSE2") { return kernel_quantize_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_quantize_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("quantize AVX2") { return kernel_quantize_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_quantize_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("quantize AVX512BW") { return kernel_quantize_test<CPUType::AVX512BW>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/relu_test.cc b/third_party/intgemm/test/kernels/relu_test.cc new file mode 100644 index 0000000000..8fd30ae25a --- /dev/null +++ b/third_party/intgemm/test/kernels/relu_test.cc @@ -0,0 +1,65 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstdint> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_, typename ElemType_> +void kernel_relu_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, ElemType_>; + constexpr int VECTOR_LENGTH = sizeof(vec_t) / sizeof(ElemType_); + + AlignedVector<ElemType_> input(VECTOR_LENGTH); + AlignedVector<ElemType_> output(VECTOR_LENGTH); + + std::iota(input.begin(), input.end(), static_cast<ElemType_>(-VECTOR_LENGTH / 2)); + + *output.template as<vec_t>() = kernels::relu<ElemType_>(*input.template as<vec_t>()); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == (input[i] < 0 ? 0 : input[i])); +} + +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int8_t>(); +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int16_t>(); +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int>(); +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, float>(); +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, double>(); +KERNEL_TEST_CASE("relu/int8 SSE2") { return kernel_relu_test<CPUType::SSE2, int8_t>(); } +KERNEL_TEST_CASE("relu/int16 SSE2") { return kernel_relu_test<CPUType::SSE2, int16_t>(); } +KERNEL_TEST_CASE("relu/int SSE2") { return kernel_relu_test<CPUType::SSE2, int>(); } +KERNEL_TEST_CASE("relu/float SSE2") { return kernel_relu_test<CPUType::SSE2, float>(); } +KERNEL_TEST_CASE("relu/double SSE2") { return kernel_relu_test<CPUType::SSE2, double>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int8_t>(); +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int16_t>(); +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int>(); +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, float>(); +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, double>(); +KERNEL_TEST_CASE("relu/int8 AVX2") { return kernel_relu_test<CPUType::AVX2, int8_t>(); } +KERNEL_TEST_CASE("relu/int16 AVX2") { return kernel_relu_test<CPUType::AVX2, int16_t>(); } +KERNEL_TEST_CASE("relu/int AVX2") { return kernel_relu_test<CPUType::AVX2, int>(); } +KERNEL_TEST_CASE("relu/float AVX2") { return kernel_relu_test<CPUType::AVX2, float>(); } +KERNEL_TEST_CASE("relu/double AVX2") { return kernel_relu_test<CPUType::AVX2, double>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int8_t>(); +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int16_t>(); +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int>(); +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, float>(); +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, double>(); +KERNEL_TEST_CASE("relu/int8 AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int8_t>(); } +KERNEL_TEST_CASE("relu/int16 AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int16_t>(); } +KERNEL_TEST_CASE("relu/int AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int>(); } +KERNEL_TEST_CASE("relu/float AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, float>(); } +KERNEL_TEST_CASE("relu/double AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, double>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/rescale_test.cc b/third_party/intgemm/test/kernels/rescale_test.cc new file mode 100644 index 0000000000..13937eddf0 --- /dev/null +++ b/third_party/intgemm/test/kernels/rescale_test.cc @@ -0,0 +1,43 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstdint> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_> +void kernel_rescale_test() { + if (kCPU < CPUType_) + return; + + using vi = vector_t<CPUType_, int>; + using vf = vector_t<CPUType_, float>; + constexpr int LENGTH = sizeof(vi) / sizeof(int); + + AlignedVector<int32_t> input(LENGTH); + AlignedVector<int32_t> output(LENGTH); + + std::iota(input.begin(), input.end(), static_cast<int32_t>(-LENGTH / 2)); + float scale = 2; + + *output.template as<vi>() = kernels::rescale(*input.template as<vi>(), intgemm::set1_ps<vf>(scale)); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == std::round(input[i] * scale)); +} + +template INTGEMM_SSE2 void kernel_rescale_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("rescale SSE2") { return kernel_rescale_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_rescale_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("rescale AVX2") { return kernel_rescale_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_rescale_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("rescale AVX512BW") { return kernel_rescale_test<CPUType::AVX512BW>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/sigmoid_test.cc b/third_party/intgemm/test/kernels/sigmoid_test.cc new file mode 100644 index 0000000000..7827593b0c --- /dev/null +++ b/third_party/intgemm/test/kernels/sigmoid_test.cc @@ -0,0 +1,45 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstddef> +#include <numeric> + +namespace intgemm { + +float sigmoid_ref(float x) { + if (x < 0) + return exp(x) / (1 + exp(x)); + else + return 1 / (1 + exp(-x)); +} + +template <CPUType CPUType_> +void kernel_sigmoid_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, float>; + constexpr static std::size_t VECTOR_LENGTH = sizeof(vec_t) / sizeof(float); + + AlignedVector<float> input(VECTOR_LENGTH); + AlignedVector<float> output(VECTOR_LENGTH); + + std::iota(input.begin(), input.end(), -static_cast<float>(VECTOR_LENGTH / 2)); + + *output.template as<vec_t>() = kernels::sigmoid(*input.template as<vec_t>()); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK_EPS(output[i], sigmoid_ref(input[i]), 0.001f); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_sigmoid_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("sigmoid AVX2") { return kernel_sigmoid_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_sigmoid_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("sigmoid AVX512BW") { return kernel_sigmoid_test<CPUType::AVX512BW>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/tanh_test.cc b/third_party/intgemm/test/kernels/tanh_test.cc new file mode 100644 index 0000000000..1d00042b8d --- /dev/null +++ b/third_party/intgemm/test/kernels/tanh_test.cc @@ -0,0 +1,38 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstddef> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_> +void kernel_tanh_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, float>; + constexpr static std::size_t VECTOR_LENGTH = sizeof(vec_t) / sizeof(float); + + AlignedVector<float> input(VECTOR_LENGTH); + AlignedVector<float> output(VECTOR_LENGTH); + + std::generate(input.begin(), input.end(), [] () { static int n = -int(VECTOR_LENGTH / 2); return n++ / float(VECTOR_LENGTH / 2); }); + + *output.template as<vec_t>() = kernels::tanh(*input.template as<vec_t>()); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK_EPS(output[i], tanh(input[i]), 0.001f); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_tanh_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("tanh AVX2") { return kernel_tanh_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_tanh_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("tanh AVX512BW") { return kernel_tanh_test<CPUType::AVX512BW>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/unquantize_test.cc b/third_party/intgemm/test/kernels/unquantize_test.cc new file mode 100644 index 0000000000..edfafa5b2e --- /dev/null +++ b/third_party/intgemm/test/kernels/unquantize_test.cc @@ -0,0 +1,41 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_> +void kernel_unquantize_test() { + if (kCPU < CPUType_) + return; + + using input_vec_t = vector_t<CPUType_, int>; + using output_vec_t = vector_t<CPUType_, float>; + + AlignedVector<int> input(sizeof(input_vec_t) / sizeof(int)); + AlignedVector<float> output(sizeof(output_vec_t) / sizeof(float)); + + std::iota(input.begin(), input.end(), 0); + auto unquant_mult = set1_ps<output_vec_t>(0.5f); + + *output.template as<output_vec_t>() = kernels::unquantize(*input.template as<input_vec_t>(), unquant_mult); + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == i * 0.5f); +} + +template INTGEMM_SSE2 void kernel_unquantize_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("unquantize SSE2") { return kernel_unquantize_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_unquantize_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("unquantize AVX2") { return kernel_unquantize_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_unquantize_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("unquantize AVX512BW") { return kernel_unquantize_test<CPUType::AVX512BW>(); } +#endif + +} diff --git a/third_party/intgemm/test/kernels/upcast_test.cc b/third_party/intgemm/test/kernels/upcast_test.cc new file mode 100644 index 0000000000..0733922ff0 --- /dev/null +++ b/third_party/intgemm/test/kernels/upcast_test.cc @@ -0,0 +1,118 @@ +// This test triggers an internal compiler error in gcc 5. +#if defined(__OPTIMIZE__) || defined(__clang__) || defined(__INTEL_COMPILER) || !defined(__GNUC__) || (__GNUC__ != 5) +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstdint> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_> +void kernel_upcast8to16_test() { + if (kCPU < CPUType_) + return; + + using vi = vector_t<CPUType_, int>; + constexpr int LENGTH = sizeof(vi) / sizeof(int8_t); + + AlignedVector<int8_t> input(LENGTH); + AlignedVector<int16_t> output(LENGTH); + + std::iota(input.begin(), input.end(), static_cast<int8_t>(-LENGTH / 2)); + + auto result = kernels::upcast8to16(*input.template as<vi>()); + output.template as<vi>()[0] = result.first; + output.template as<vi>()[1] = result.second; + + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == int16_t(input[i])); +} + +template INTGEMM_SSE2 void kernel_upcast8to16_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("upcast8to16 SSE2") { return kernel_upcast8to16_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_upcast8to16_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("upcast8to16 AVX2") { return kernel_upcast8to16_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_upcast8to16_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("upcast8to16 AVX512BW") { return kernel_upcast8to16_test<CPUType::AVX512BW>(); } +#endif + +template <CPUType CPUType_> +void kernel_upcast16to32_test() { + if (kCPU < CPUType_) + return; + + using vi = vector_t<CPUType_, int>; + constexpr int LENGTH = sizeof(vi) / sizeof(int16_t); + + AlignedVector<int16_t> input(LENGTH); + AlignedVector<int32_t> output(LENGTH); + + std::iota(input.begin(), input.end(), static_cast<int16_t>(-LENGTH / 2)); + + auto result = kernels::upcast16to32(*input.template as<vi>()); + output.template as<vi>()[0] = result.first; + output.template as<vi>()[1] = result.second; + + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == int32_t(input[i])); +} + +template INTGEMM_SSE2 void kernel_upcast16to32_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("upcast16to32 SSE2") { return kernel_upcast16to32_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_upcast16to32_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("upcast16to32 AVX2") { return kernel_upcast16to32_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_upcast16to32_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("upcast16to32 AVX512BW") { return kernel_upcast16to32_test<CPUType::AVX512BW>(); } +#endif + + +template <CPUType CPUType_> +void kernel_upcast8to32_test() { + if (kCPU < CPUType_) + return; + + using vi = vector_t<CPUType_, int>; + constexpr int LENGTH = sizeof(vi) / sizeof(int8_t); + + AlignedVector<int8_t> input(LENGTH); + AlignedVector<int32_t> output(LENGTH); + + std::iota(input.begin(), input.end(), static_cast<int8_t>(-LENGTH / 2)); + + auto result = kernels::upcast8to32(*input.template as<vi>()); + output.template as<vi>()[0] = result.first; + output.template as<vi>()[1] = result.second; + output.template as<vi>()[2] = result.third; + output.template as<vi>()[3] = result.fourth; + + for (std::size_t i = 0; i < output.size(); ++i) + CHECK(output[i] == int32_t(input[i])); +} + +template INTGEMM_SSE2 void kernel_upcast8to32_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("upcast8to32 SSE2") { return kernel_upcast8to32_test<CPUType::SSE2>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_upcast8to32_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("upcast8to32 AVX2") { return kernel_upcast8to32_test<CPUType::AVX2>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_upcast8to32_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("upcast8to32 AVX512BW") { return kernel_upcast8to32_test<CPUType::AVX512BW>(); } +#endif + +} +#endif diff --git a/third_party/intgemm/test/kernels/write_test.cc b/third_party/intgemm/test/kernels/write_test.cc new file mode 100644 index 0000000000..a136a86d3e --- /dev/null +++ b/third_party/intgemm/test/kernels/write_test.cc @@ -0,0 +1,65 @@ +#include "../test.h" +#include "../../intgemm/aligned.h" +#include "../../intgemm/kernels.h" + +#include <cstddef> +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_, typename ElemType_> +void kernel_write_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, ElemType_>; + constexpr static std::size_t VECTOR_LENGTH = sizeof(vec_t) / sizeof(ElemType_); + + AlignedVector<ElemType_> input(VECTOR_LENGTH); + AlignedVector<ElemType_> output(VECTOR_LENGTH); + + std::iota(input.begin(), input.end(), static_cast<ElemType_>(0)); + + kernels::write(*input.template as<vec_t>(), output.begin(), 0); + for (std::size_t i = 0; i < VECTOR_LENGTH; ++i) + CHECK(output[i] == ElemType_(i)); +} + +template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, int8_t>(); +template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, int16_t>(); +template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, int>(); +template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, float>(); +template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, double>(); +KERNEL_TEST_CASE("write/int8 SSE2") { return kernel_write_test<CPUType::SSE2, int8_t>(); } +KERNEL_TEST_CASE("write/int16 SSE2") { return kernel_write_test<CPUType::SSE2, int16_t>(); } +KERNEL_TEST_CASE("write/int SSE2") { return kernel_write_test<CPUType::SSE2, int>(); } +KERNEL_TEST_CASE("write/float SSE2") { return kernel_write_test<CPUType::SSE2, float>(); } +KERNEL_TEST_CASE("write/double SSE2") { return kernel_write_test<CPUType::SSE2, double>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int8_t>(); +template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int16_t>(); +template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int>(); +template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, float>(); +template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, double>(); +KERNEL_TEST_CASE("write/int8 AVX2") { return kernel_write_test<CPUType::AVX2, int8_t>(); } +KERNEL_TEST_CASE("write/int16 AVX2") { return kernel_write_test<CPUType::AVX2, int16_t>(); } +KERNEL_TEST_CASE("write/int AVX2") { return kernel_write_test<CPUType::AVX2, int>(); } +KERNEL_TEST_CASE("write/float AVX2") { return kernel_write_test<CPUType::AVX2, float>(); } +KERNEL_TEST_CASE("write/double AVX2") { return kernel_write_test<CPUType::AVX2, double>(); } +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int8_t>(); +template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int16_t>(); +template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int>(); +template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, float>(); +template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, double>(); +KERNEL_TEST_CASE("write/int8 AVX512BW") { return kernel_write_test<CPUType::AVX512BW, int8_t>(); } +KERNEL_TEST_CASE("write/int16 AVX512BW") { return kernel_write_test<CPUType::AVX512BW, int16_t>(); } +KERNEL_TEST_CASE("write/int AVX512BW") { return kernel_write_test<CPUType::AVX512BW, int>(); } +KERNEL_TEST_CASE("write/float AVX512BW") { return kernel_write_test<CPUType::AVX512BW, float>(); } +KERNEL_TEST_CASE("write/double AVX512BW") { return kernel_write_test<CPUType::AVX512BW, double>(); } +#endif + +} diff --git a/third_party/intgemm/test/multiply_test.cc b/third_party/intgemm/test/multiply_test.cc new file mode 100644 index 0000000000..f72758fe19 --- /dev/null +++ b/third_party/intgemm/test/multiply_test.cc @@ -0,0 +1,761 @@ +#include "test.h" +#include "../intgemm/aligned.h" +#include "../intgemm/callbacks.h" +#include "../intgemm/interleave.h" +#include "../intgemm/intgemm.h" +#include "../intgemm/multiply.h" +#include "../intgemm/stats.h" + +#include <algorithm> +#include <cassert> +#include <cmath> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <iomanip> +#include <iostream> +#include <memory> +#include <numeric> +#include <random> + +namespace intgemm { + +#ifndef __INTEL_COMPILER +INTGEMM_SSE2 +#endif +TEST_CASE("Transpose 16", "[transpose]") { + if (kCPU < CPUType::SSE2) return; + const unsigned N = 8; + AlignedVector<int16_t> input(N * N); + std::iota(input.begin(), input.end(), static_cast<int16_t>(0)); + + AlignedVector<int16_t> ref(N * N); + references::Transpose(input.begin(), ref.begin(), N, N); + + // Overwrite input. + __m128i *t = input.as<__m128i>(); + Transpose16InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]); + + for (std::size_t i = 0; i < input.size(); ++i) { + CHECK_MESSAGE(ref[i] == input[i], "16-bit transpose failure at: " << i << ": " << ref[i] << " != " << input[i]); + } +} + +#ifndef __INTEL_COMPILER +INTGEMM_SSSE3 +#endif +TEST_CASE("Transpose 8", "[transpose]") { + if (kCPU < CPUType::SSSE3) return; + const unsigned N = 16; + AlignedVector<int8_t> input(N * N); + std::iota(input.begin(), input.end(), static_cast<int8_t>(0)); + + AlignedVector<int8_t> ref(input.size()); + references::Transpose(input.begin(), ref.begin(), N, N); + + // Overwrite input. + __m128i *t = input.as<__m128i>(); + Transpose8InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15]); + + for (std::size_t i = 0; i < input.size(); ++i) { + CHECK_MESSAGE(ref[i] == input[i], "8-bit transpose failure at " << i << ": " << (int16_t)ref[i] << " != " << (int16_t)input[i]); + } +} + +template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) { + std::mt19937 gen; + // Go somewhat out of range too. + std::uniform_real_distribution<float> dist(-129.0, 129.0); + // Create array. + AlignedVector<float> input(rows * cols); + for (auto& it : input) { + it = dist(gen); + } + + using Integer = typename Routine::Integer; + // Call Prepare + AlignedVector<Integer> test(input.size()); + Routine::PrepareB(input.begin(), test.begin(), 1, rows, cols); + + // Compute reference output. + AlignedVector<Integer> quantized(input.size()); + Routine::Quantize(input.begin(), quantized.begin(), 1, static_cast<Index>(input.size())); + AlignedVector<Integer> reference(input.size()); + // Note this won't work for Int8/Int16 generic routines because tile sizes vary. + references::Rearragement(quantized.begin(), reference.begin(), Routine::kBTileRow, Routine::kBTileCol, rows, cols); + CHECK_MESSAGE(memcmp(reference.begin(), test.begin(), test.size() * sizeof(Integer)) == 0, Routine::kName << " Mismatch:\n" << + "Quantized Input" << '\n' << PrintMatrix(quantized.begin(), rows, cols) << "Reference" << '\n' << + PrintMatrix(reference.begin(), rows, cols) << "Routine" << '\n' << PrintMatrix(test.begin(), rows, cols)); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("Prepare AVX512", "[prepare]") { + if (kCPU < CPUType::AVX512BW) return; + TestPrepare<AVX512BW::Kernels8>(64, 8); + TestPrepare<AVX512BW::Kernels8>(256, 32); + TestPrepare<AVX512BW::Kernels16>(64, 8); + TestPrepare<AVX512BW::Kernels16>(256, 32); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE("Prepare AVX2", "[prepare]") { + if (kCPU < CPUType::AVX2) return; + TestPrepare<AVX2::Kernels8>(64, 32); + TestPrepare<AVX2::Kernels16>(64, 32); +} +#endif + +TEST_CASE("Prepare SSSE3", "[prepare]") { + if (kCPU < CPUType::SSSE3) return; + TestPrepare<SSSE3::Kernels8>(16, 8); + TestPrepare<SSSE3::Kernels8>(32, 16); + TestPrepare<SSSE3::Kernels8>(32, 32); +} + +TEST_CASE("Prepare SSE2", "[prepare]") { + if (kCPU < CPUType::SSE2) return; + TestPrepare<SSE2::Kernels16>(8, 8); + TestPrepare<SSE2::Kernels16>(32, 32); +} + +template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 16) { + std::mt19937 gen; + // Go somewhat out of range too. + std::uniform_real_distribution<float> dist(-129.0, 129.0); + AlignedVector<float> input(rows * cols); + for (auto& it : input) { + it = dist(gen); + } + using Integer = typename Routine::Integer; + AlignedVector<Integer> prepared(input.size()); + Routine::PrepareB(input.begin(), prepared.begin(), 1, rows, cols); + + const int kSelectCols = 24; + Index select_cols[kSelectCols]; + std::uniform_int_distribution<Index> col_dist(0, cols - 1); + for (auto& it : select_cols) { + it = col_dist(gen); + } + + AlignedVector<Integer> test(rows * kSelectCols); + Routine::SelectColumnsB(prepared.begin(), test.begin(), rows, select_cols, select_cols + kSelectCols); + + // Select columns manually in float space. + AlignedVector<float> selected(rows * kSelectCols); + for (Index r = 0; r < rows; ++r) { + for (int c = 0; c < kSelectCols; ++c) { + assert(c + r * kSelectCols < rows * kSelectCols); + selected[c + r * kSelectCols] = input[select_cols[c] + r * cols]; + } + } + AlignedVector<Integer> ref(rows * kSelectCols); + Routine::PrepareB(selected.begin(), ref.begin(), 1, rows, kSelectCols); + CHECK_MESSAGE(memcmp(ref.begin(), test.begin(), sizeof(Integer) * rows * kSelectCols) == 0, "Reference:\n" << + PrintMatrix(ref.begin(), rows, kSelectCols) << PrintMatrix(test.begin(), rows, kSelectCols)); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("SelectColumnsB AVX512", "[select]") { + if (kCPU < CPUType::AVX512BW) return; + TestSelectColumnsB<AVX512BW::Kernels8>(); + TestSelectColumnsB<AVX512BW::Kernels16>(256, 256); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("SelectColumnsB AVX2", "[select]") { + if (kCPU < CPUType::AVX2) return; + TestSelectColumnsB<AVX2::Kernels8>(256, 256); + TestSelectColumnsB<AVX2::Kernels16>(256, 256); +} +#endif + +TEST_CASE("SelectColumnsB SSSE3", "[select]") { + if (kCPU < CPUType::SSSE3) return; + TestSelectColumnsB<SSSE3::Kernels8>(); + TestSelectColumnsB<SSSE3::Kernels8>(256, 256); +} + +TEST_CASE("SelectColumnsB SSE2", "[select]") { + if (kCPU < CPUType::SSE2) return; + TestSelectColumnsB<SSE2::Kernels16>(); + TestSelectColumnsB<SSE2::Kernels16>(256, 256); +} + +template <class Register> void TestMax() { + Register r = set1_ps<Register>(-2.0); + for (std::size_t i = 0; i < sizeof(Register) / sizeof(float); ++i) { + Register c = r; + reinterpret_cast<float*>(&c)[i] = -1.0; + CHECK_MESSAGE((MaxFloat32(c) == -1.0), "MaxFloat32 produced " << MaxFloat32(c)); + } +} + +TEST_CASE("Max", "[max]") { + TestMax<__m128>(); +} + +void CompareMaxAbs(const float *begin, const float *end, float test, std::size_t offset) { + float largest = std::fabs(*std::max_element(begin, end)); + float smallest = std::fabs(*std::min_element(begin, end)); + largest = std::max(largest, smallest); + CHECK_MESSAGE(largest == test, "Error: " << largest << " versus " << test << " in length " << (end - begin) << " offset " << offset); +} + +template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute() { + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-8.0, 8.0); + const std::size_t kLengthMax = 65; + AlignedVector<float> test(kLengthMax); + for (std::size_t len = 1; len < kLengthMax; ++len) { + for (std::size_t t = 0; t < len; ++t) { + // Fill with [-8, 8). + for (auto& it : test) { + it = dist(gen); + } + CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t); + test[t] = -32.0; + CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t); + test[t] = 32.0; + CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t); + } + } +} + +TEST_CASE("MaxAbsolute SSE2", "[max]") { + if (kCPU < CPUType::SSE2) return; + TestMaxAbsolute<SSE2::MaxAbsolute>(); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE("MaxAbsolute AVX2", "[max]") { + if (kCPU < CPUType::AVX2) return; + TestMaxAbsolute<AVX2::MaxAbsolute>(); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("MaxAbsolute AVX512BW", "[max]") { + if (kCPU < CPUType::AVX512BW) return; + TestMaxAbsolute<AVX512BW::MaxAbsolute>(); +} +#endif + +// Based on https://arxiv.org/abs/1705.01991 + +// Copyright (c) 2017 Microsoft Corporation + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// Compute A*B slowly in floats. + +template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_cols, + float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) { + using Integer = typename Routine::Integer; + std::ostringstream info; + info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n'; + + // Initialize A and B. + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + + float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<Integer> A_prep(A.size()); + AlignedVector<Integer> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + OMPParallelWrap<callbacks::UnquantizeAndWrite, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin())); + // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence( + // callbacks::Unquantize(unquant_mult), + // callbacks::Write<float>(test_C.begin()) + // )); + + AlignedVector<Integer> B_quant(B.size()); + Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size())); + AlignedVector<float> slowint_C(test_C.size()); + // Assuming A is just quantization here. + references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) { + return sum * unquant_mult; + }); + + AlignedVector<float> float_C(test_C.size()); + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) { + return static_cast<float>(sum); + }); + + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); +} + +template <class Routine> void TestMultiplyRelu(Index A_rows, Index width, Index B_cols, + float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) { + using Integer = typename Routine::Integer; + std::ostringstream info; + info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n'; + + // Initialize A and B. + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + + float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<Integer> A_prep(A.size()); + AlignedVector<Integer> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + OMPParallelWrap<callbacks::UnquantizeAndWriteRelu, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWriteRelu(unquant_mult, test_C.begin())); + // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence( + // callbacks::Unquantize(unquant_mult), + // callbacks::Write<float>(test_C.begin()) + // )); + + AlignedVector<Integer> B_quant(B.size()); + Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size())); + AlignedVector<float> slowint_C(test_C.size()); + // Assuming A is just quantization here. + references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) { + float ret = std::max(0.0f, sum * unquant_mult); + return ret; + }); + + AlignedVector<float> float_C(test_C.size()); + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) { + return static_cast<float>(std::max(0.0,sum)); + }); + + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); +} + +//Code duplication may be avoided through some use of variadic templates, as the different WriteC symbols +//Require different number of arguments. I don't think the refactoring is worth it. +template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index B_cols, + float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) { + using Integer = typename Routine::Integer; + std::ostringstream info; + info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n'; + + // Initialize A and B. + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = dist(gen); + } + + float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<Integer> A_prep(A.size()); + AlignedVector<Integer> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); + + AlignedVector<Integer> B_quant(B.size()); + Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size())); + AlignedVector<float> slowint_C(test_C.size()); + // Assuming A is just quantization here. + references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult + bias[info.col_idx]; + }); + + AlignedVector<float> float_C(test_C.size()); + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) { + return static_cast<float>(sum) + bias[info.col_idx]; + }); + + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); +} + +template <class Routine> void TestMultiplyBiasRelu(Index A_rows, Index width, Index B_cols, + float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) { + using Integer = typename Routine::Integer; + std::ostringstream info; + info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n'; + + // Initialize A and B. + AlignedVector<float> A(A_rows * width); + AlignedVector<float> B(width * B_cols); + AlignedVector<float> bias(B_cols); + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + for (auto& it : A) { + it = dist(gen); + } + for (auto& it : B) { + it = dist(gen); + } + for (auto& it : bias) { + it = dist(gen); + } + + float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64; + float unquant_mult = 1.0f / (quant_mult*quant_mult); + + AlignedVector<Integer> A_prep(A.size()); + AlignedVector<Integer> B_prep(B.size()); + Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width); + Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); + + AlignedVector<float> test_C(A_rows * B_cols); + + Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias.begin(), test_C.begin())); + + AlignedVector<Integer> B_quant(B.size()); + Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size())); + AlignedVector<float> slowint_C(test_C.size()); + // Assuming A is just quantization here. + references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return std::max(0.0f, sum * unquant_mult + bias[info.col_idx]); + }); + + AlignedVector<float> float_C(test_C.size()); + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) { + return std::max(0.0f, static_cast<float>(sum) + bias[info.col_idx]); + }); + + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); +} + +TEST_CASE ("Multiply SSE2 16bit", "[multiply]") { + if (kCPU < CPUType::SSE2) return; + TestMultiply<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiply<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiply<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiply<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiply<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiply<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply SSE2 16bit with relu", "[multiply_relu]") { + if (kCPU < CPUType::SSE2) return; + TestMultiplyRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiplyRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::SSE2) return; + TestMultiplyBias<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiplyBias<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply SSE2 16bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::SSE2) return; + TestMultiplyBiasRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiplyBiasRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") { + if (kCPU < CPUType::SSSE3) return; + TestMultiply<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f); + TestMultiply<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f); + TestMultiply<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f); + TestMultiply<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f); + TestMultiply<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f); + TestMultiply<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f); +} + +TEST_CASE ("Multiply SSSE3 8bit with relu", "[multiply_relu]") { + if (kCPU < CPUType::SSSE3) return; + TestMultiplyRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f); + TestMultiplyRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f); + TestMultiplyRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f); + TestMultiplyRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f); + TestMultiplyRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f); + TestMultiplyRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f); +} + +TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::SSSE3) return; + TestMultiplyBias<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f); + TestMultiplyBias<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f); + TestMultiplyBias<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f); + TestMultiplyBias<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f); + TestMultiplyBias<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f); + TestMultiplyBias<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f); +} + +TEST_CASE ("Multiply SSSE3 8bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::SSSE3) return; + TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f); + TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f); + TestMultiplyBiasRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f); + TestMultiplyBiasRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f); + TestMultiplyBiasRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f); + TestMultiplyBiasRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f); +} + + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE ("Multiply AVX2 8bit", "[multiply]") { + if (kCPU < CPUType::AVX2) return; + TestMultiply<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f); + TestMultiply<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f); + TestMultiply<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f); + TestMultiply<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f); + TestMultiply<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f); + TestMultiply<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f); +} + +TEST_CASE ("Multiply AVX2 8bit with relu", "[multiply_relu]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f); + TestMultiplyRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f); + TestMultiplyRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f); + TestMultiplyRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f); + TestMultiplyRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f); + TestMultiplyRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f); +} + +TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBias<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f); + TestMultiplyBias<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f); + TestMultiplyBias<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f); + TestMultiplyBias<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f); + TestMultiplyBias<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f); + TestMultiplyBias<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f); +} + +TEST_CASE ("Multiply AVX2 8bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBiasRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f); + TestMultiplyBiasRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f); + TestMultiplyBiasRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f); + TestMultiplyBiasRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f); + TestMultiplyBiasRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f); + TestMultiplyBiasRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f); +} + +TEST_CASE ("Multiply AVX2 16bit", "[multiply]") { + if (kCPU < CPUType::AVX2) return; + TestMultiply<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiply<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply AVX2 16bit with relu", "[multiply_relu]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiplyRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBias<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiplyBias<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply AVX2 16bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBiasRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiplyBiasRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + TEST_CASE ("Multiply AVX512 8bit", "[multiply]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiply<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiply<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f); + TestMultiply<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiply<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiply<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f); + TestMultiply<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f); + } + + TEST_CASE ("Multiply AVX512 8bit with relu", "[multiply_relu]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiplyRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f); + TestMultiplyRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiplyRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f); + } + + TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyBias<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiplyBias<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f); + TestMultiplyBias<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiplyBias<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBias<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBias<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f); + } + + TEST_CASE ("Multiply AVX512 8bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f); + TestMultiplyBiasRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiplyBiasRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBiasRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBiasRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f); + } + + #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI + TEST_CASE ("Multiply AVX512VNNI 8bit", "[multiply]") { + if (kCPU < CPUType::AVX512VNNI) return; + TestMultiply<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiply<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f); + TestMultiply<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiply<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiply<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f); + TestMultiply<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f); + } + + TEST_CASE ("Multiply AVX512VNNI 8bit with relu", "[multiply_relu]") { + if (kCPU < CPUType::AVX512VNNI) return; + TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f); + TestMultiplyRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiplyRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f); + } + + TEST_CASE ("Multiply AVX512VNNI 8bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::AVX512VNNI) return; + TestMultiplyBias<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiplyBias<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f); + TestMultiplyBias<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiplyBias<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBias<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBias<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f); + } + + TEST_CASE ("Multiply AVX512VNNI 8bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::AVX512VNNI) return; + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f); + } + #endif + + TEST_CASE ("Multiply AVX512 16bit", "[multiply]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiply<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f); + TestMultiply<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f); + } + + TEST_CASE ("Multiply AVX512 16bit with relu", "[multiply_relu]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f); + TestMultiplyRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f); + } + + + TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyBias<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f); + TestMultiplyBias<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f); + } + + TEST_CASE ("Multiply AVX512 16bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f); + } +#endif + +} // namespace intgemm diff --git a/third_party/intgemm/test/prepare_b_quantized_transposed.cc b/third_party/intgemm/test/prepare_b_quantized_transposed.cc new file mode 100644 index 0000000000..defe9a0096 --- /dev/null +++ b/third_party/intgemm/test/prepare_b_quantized_transposed.cc @@ -0,0 +1,96 @@ +#include "test.h" +#include "../intgemm/aligned.h" +#include "../intgemm/avx2_gemm.h" +#include "../intgemm/avx512_gemm.h" +#include "../intgemm/sse2_gemm.h" +#include "../intgemm/ssse3_gemm.h" + +#include <cmath> +#include <cstring> +#include <iostream> + +namespace intgemm { +namespace { + +template <typename Backend> +void PrepareBQuantizedTransposedRef(const typename Backend::Integer* input, typename Backend::Integer* output, Index B_transposed_cols, Index B_transposed_rows) { + using vec_t = intgemm::vector_t<Backend::kUses, typename Backend::Integer>; + constexpr Index vec_len = sizeof(vec_t) / sizeof(typename Backend::Integer); + + auto output_it = output; + for (Index r = 0; r < B_transposed_rows; r += 8) + for (Index c = 0; c < B_transposed_cols; c += vec_len) + for (Index ri = 0; ri < 8; ++ri) + for (Index ci = 0; ci < vec_len; ++ci) + *output_it++ = input[(r + ri) * B_transposed_cols + c + ci]; +} + +template <typename Backend> +bool Test(const AlignedVector<typename Backend::Integer>& input, Index B_rows, Index B_cols) { + bool success = true; + + AlignedVector<typename Backend::Integer> output(input.size()); + Backend::PrepareBQuantizedTransposed(input.begin(), output.begin(), B_rows, B_cols); + + AlignedVector<typename Backend::Integer> reference(input.size()); + PrepareBQuantizedTransposedRef<Backend>(input.begin(), reference.begin(), B_rows, B_cols); + + for (std::size_t i = 0; i < output.size(); ++i) { + if (output[i] != reference[i]) { + UNSCOPED_INFO("Error at " << i << ", output = " << int(output[i]) << ", reference = " << int(reference[i])); + success = false; + break; + } + } + return success; +} + +template <typename Backend> +bool TestMany(Index B_rows, Index B_cols) { + AlignedVector<typename Backend::Integer> input(B_rows * B_cols); + + std::generate(input.begin(), input.end(), []() { + static constexpr int divider = sizeof(intgemm::vector_t<Backend::kUses, typename Backend::Integer>) / sizeof(typename Backend::Integer); + static int value = 0; + return static_cast<typename Backend::Integer>((value++) % divider); + }); + + return Test<Backend>(input, B_rows, B_cols); +} + +TEST_CASE("PrepareBQuantizedTransposed SSE2", "") { + if (kCPU < CPUType::SSE2) + return; + + CHECK(TestMany<SSE2::Kernels16>(32, 128)); +} + +TEST_CASE("PrepareBQuantizedTransposed SSSE3", "") { + if (kCPU < CPUType::SSSE3) + return; + + CHECK(TestMany<SSSE3::Kernels8>(32, 128)); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE("PrepareBQuantizedTransposed AVX2", "") { + if (kCPU < CPUType::AVX2) + return; + + CHECK(TestMany<AVX2::Kernels8>(32, 128)); + CHECK(TestMany<AVX2::Kernels16>(32, 128)); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + TEST_CASE("PrepareBQuantizedTransposed AVX512", "") { + if (kCPU < CPUType::AVX512BW) + return; + + CHECK(TestMany<AVX512BW::Kernels8>(64, 128)); + CHECK(TestMany<AVX512BW::Kernels16>(64, 128)); + } +#endif + +} +} diff --git a/third_party/intgemm/test/prepare_b_transposed.cc b/third_party/intgemm/test/prepare_b_transposed.cc new file mode 100644 index 0000000000..1c11fbe112 --- /dev/null +++ b/third_party/intgemm/test/prepare_b_transposed.cc @@ -0,0 +1,97 @@ +#include "test.h" +#include "../intgemm/aligned.h" +#include "../intgemm/avx2_gemm.h" +#include "../intgemm/avx512_gemm.h" +#include "../intgemm/sse2_gemm.h" +#include "../intgemm/ssse3_gemm.h" + +#include <cmath> +#include <cstring> +#include <iostream> + +namespace intgemm { +namespace { + +template <typename Backend> +void PrepareBTransposedRef(const float* input, typename Backend::Integer* output, float quant_mult, Index B_transposed_cols, Index B_transposed_rows) { + using vec_t = intgemm::vector_t<Backend::kUses, typename Backend::Integer>; + constexpr Index vec_len = sizeof(vec_t) / sizeof(typename Backend::Integer); + + for (Index i = 0; i < B_transposed_rows * B_transposed_cols / 8; i += vec_len) + for (Index j = 0; j < 8; ++j) + for (Index k = 0; k < vec_len; ++k) { + Index col = (i + k) % B_transposed_cols; + Index row = 8 * ((i + k) / B_transposed_cols) + j; + *output++ = static_cast<typename Backend::Integer>(input[row * B_transposed_cols + col] * quant_mult); + } +} + +template <typename Backend> +bool Test(const AlignedVector<float>& input, Index B_rows, Index B_cols, float quant_mult) { + bool success = true; + + AlignedVector<typename Backend::Integer> output(input.size()); + Backend::PrepareBTransposed(input.begin(), output.begin(), quant_mult, B_rows, B_cols); + + AlignedVector<typename Backend::Integer> reference(input.size()); + PrepareBTransposedRef<Backend>(input.begin(), reference.begin(), quant_mult, B_rows, B_cols); + + for (std::size_t i = 0; i < output.size(); ++i) { + if (output[i] != reference[i]) { + UNSCOPED_INFO("Error at " << i << ", output = " << int(output[i]) << ", reference = " << int(reference[i])); + success = false; + break; + } + } + return success; +} + +template <typename Backend> +bool TestMany(Index B_rows, Index B_cols, float quant_mult) { + AlignedVector<float> input(B_rows * B_cols); + + std::generate(input.begin(), input.end(), []() { + static constexpr int divider = sizeof(intgemm::vector_t<Backend::kUses, typename Backend::Integer>) / sizeof(typename Backend::Integer); + static int value = 0; + return static_cast<float>((value++) % divider); + }); + + return Test<Backend>(input, B_rows, B_cols, quant_mult); +} + +TEST_CASE("PrepareBTransposed SSE2", "") { + if (kCPU < CPUType::SSE2) + return; + + CHECK(TestMany<SSE2::Kernels16>(4, 128, 2.0f)); +} + +TEST_CASE("PrepareBTransposed SSSE3", "") { + if (kCPU < CPUType::SSSE3) + return; + + CHECK(TestMany<SSSE3::Kernels8>(4, 128, 2.0f)); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE("PrepareBTransposed AVX2", "") { + if (kCPU < CPUType::AVX2) + return; + + CHECK(TestMany<AVX2::Kernels8>(8, 128, 2.0f)); + CHECK(TestMany<AVX2::Kernels16>(8, 128, 2.0f)); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("PrepareBTransposed AVX512", "") { + if (kCPU < CPUType::AVX512BW) + return; + + CHECK(TestMany<AVX512BW::Kernels8>(16, 128, 2.0f)); + CHECK(TestMany<AVX512BW::Kernels16>(16, 128, 2.0f)); +} +#endif + +} +} diff --git a/third_party/intgemm/test/quantize_test.cc b/third_party/intgemm/test/quantize_test.cc new file mode 100644 index 0000000000..622ff7149f --- /dev/null +++ b/third_party/intgemm/test/quantize_test.cc @@ -0,0 +1,199 @@ +#include "test.h" +#include "../intgemm/aligned.h" +#include "../intgemm/avx2_gemm.h" +#include "../intgemm/avx512_gemm.h" +#include "../intgemm/sse2_gemm.h" +#include "../intgemm/ssse3_gemm.h" +#include "../intgemm/stats.h" + +#include <cmath> +#include <cstring> +#include <iostream> + +namespace intgemm { +namespace { + +void QuantizeRef(const float *input, int16_t *output, float quant_mult, std::size_t size) { + for (std::size_t i = 0; i < size; ++i) { + float value = roundf(input[i] * quant_mult); + value = std::max(-32768.0f, value); + value = std::min(32767.0f, value); + // float should be exact in this range. + output[i] = static_cast<int16_t>(value); + } +} + +void QuantizeRef(const float *input, int8_t *output, float quant_mult, std::size_t size) { + for (std::size_t i = 0; i < size; ++i) { + float value = roundf(input[i] * quant_mult); + value = std::max(-127.0f, value); + value = std::min(127.0f, value); + output[i] = static_cast<int8_t>(value); + } +} + +MeanStd VectorMeanStd(AlignedVector<float>& vals, int num_items, bool absolute) { + float normal_sums = 0; + float squares_sum = 0; + if (absolute) { + std::for_each(vals.begin(), vals.end(), [&] (float n) {normal_sums+=std::abs(n);}); + } else { + std::for_each(vals.begin(), vals.end(), [&] (float n) {normal_sums+=n;}); + } + std::for_each(vals.begin(), vals.end(), [&] (float n) {squares_sum+=n*n;}); + + MeanStd ret; + ret.mean = normal_sums/num_items; + ret.stddev = std::sqrt((squares_sum/num_items) - (ret.mean*ret.mean)); + return ret; +} + +template <MeanStd (*Backend) (const float *, const float *, bool)> +void testVectorMeanStd(int num_items, bool absolute=false) { + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + AlignedVector<float> inputVec(num_items); + + for (auto&& it : inputVec) { + it = dist(gen); + } + + MeanStd reference = VectorMeanStd(inputVec, num_items, absolute); + MeanStd fast = Backend(inputVec.begin(), inputVec.end(), absolute); + + float meanDifference = std::fabs(reference.mean - fast.mean); + float stdDifference = std::fabs(reference.stddev - fast.stddev); + float eps = 0.00002f; //Accumulating horizontal sums can lead to errors. + + CHECK_MESSAGE(meanDifference <= eps, "Items: " << num_items << " Absolute: " << absolute << " Reference mean: " << reference.mean << " actual: " << fast.mean); + CHECK_MESSAGE(stdDifference <= eps, "Items: " << num_items << " Absolute: " << absolute << " Reference mean: " << reference.stddev << " actual: " << fast.stddev); + +} + +template <class I> bool IsOff(float from, I ref, I test) { + if (ref == test) return false; + if (ref - test > 1 && test - ref > 1) return true; + float off_test = std::fabs(static_cast<float>(test) - from); + float off_ref = std::fabs(static_cast<float>(ref) - from); + // Allow 0.5 to round either way. + if (off_test > 0.49 && off_test < 0.51 && off_ref > 0.49 && off_ref < 0.51) return false; + return true; +} + +template <class Backend> bool Test(const float *input_unaligned, float quant_mult, std::size_t size) { + using Integer = typename Backend::Integer; + bool success = true; + AlignedVector<float> input(size); + std::memcpy(input.begin(), input_unaligned, sizeof(float) * size); + + AlignedVector<Integer> ref(size); + AlignedVector<Integer> test(size); + QuantizeRef(input.begin(), ref.begin(), quant_mult, static_cast<Index>(size)); + Backend::Quantize(input.begin(), test.begin(), quant_mult, static_cast<Index>(size)); + for (std::size_t i = 0; i < size; ++i) { + if (IsOff(input[i] * quant_mult, ref[i], test[i])) { + UNSCOPED_INFO("Error at " << i << " from " << input[i] << '*' << quant_mult << '=' << (input[i]*quant_mult) << " ref = " << static_cast<int>(ref[i]) << " test = " << static_cast<int>(test[i])); + success = false; + } + } + return success; +} + +template <class Backend> void TestMany(std::size_t grow) { + float input[33] = { + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, + 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, + 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}; + float corners[33] = { + -32769.f, -32768.f, -32767.f, -129.f, -128.f, -127.f, -1.f, 0.f, 1.f, + 126.f, 127.f, 128.f, 129.f, 32766.f, 32768.f, 32769.f, -1.9f, -1.5f, -1.1f, + -1.f, -0.9f, -0.5f, -0.1f, 0.0f, 0.1f, 0.5f, 0.9f, 1.0f, 1.1f, 1.5f, 1.9f, + 16056.8f, 2.5f}; + for (std::size_t len = 0; len <= 33; len += grow) { + CHECK(Test<Backend>(input, 1.0f, len)); + CHECK(Test<Backend>(input, 32.0f, len)); + CHECK(Test<Backend>(corners, 1.0f, len)); + CHECK(Test<Backend>(corners, -1.0f, len)); + CHECK(Test<Backend>(corners, -0.49f, len)); + } +} + +TEST_CASE ("Quantize SSE2", "[quantize]") { + if (kCPU < CPUType::SSE2) return; + TestMany<SSE2::Kernels16>(8); +} + +TEST_CASE ("Quantize SSSE3", "[quantize]") { + if (kCPU < CPUType::SSSE3) return; + TestMany<SSSE3::Kernels8>(1); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE ("Quantize AVX2", "[quantize]") { + if (kCPU < CPUType::AVX2) return; + TestMany<AVX2::Kernels8>(1); + TestMany<AVX2::Kernels16>(16); +} +#endif +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE ("Quantize AVX512", "[quantize]") { + if (kCPU < CPUType::AVX512BW) return; + TestMany<AVX512BW::Kernels8>(1); + TestMany<AVX512BW::Kernels16>(16); +} +#endif + +TEST_CASE("QuantizeStd SSSE3", "[VectorMeanStd]") { + if (kCPU < CPUType::SSSE3) return; + testVectorMeanStd<SSE2::VectorMeanStd>(64); + testVectorMeanStd<SSE2::VectorMeanStd>(64, true); + testVectorMeanStd<SSE2::VectorMeanStd>(256); + testVectorMeanStd<SSE2::VectorMeanStd>(256, true); + testVectorMeanStd<SSE2::VectorMeanStd>(2048); + testVectorMeanStd<SSE2::VectorMeanStd>(2048, true); + testVectorMeanStd<SSE2::VectorMeanStd>(65536); + testVectorMeanStd<SSE2::VectorMeanStd>(65536, true); + testVectorMeanStd<SSE2::VectorMeanStd>(81920); + testVectorMeanStd<SSE2::VectorMeanStd>(81920, true); + testVectorMeanStd<SSE2::VectorMeanStd>(120832); + testVectorMeanStd<SSE2::VectorMeanStd>(120832, true); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE("QuantizeStd AVX2", "[VectorMeanStd]") { + if (kCPU < CPUType::AVX2) return; + testVectorMeanStd<AVX2::VectorMeanStd>(64); + testVectorMeanStd<AVX2::VectorMeanStd>(64, true); + testVectorMeanStd<AVX2::VectorMeanStd>(256); + testVectorMeanStd<AVX2::VectorMeanStd>(256, true); + testVectorMeanStd<AVX2::VectorMeanStd>(2048); + testVectorMeanStd<AVX2::VectorMeanStd>(2048, true); + testVectorMeanStd<AVX2::VectorMeanStd>(65536); + testVectorMeanStd<AVX2::VectorMeanStd>(65536, true); + testVectorMeanStd<AVX2::VectorMeanStd>(81920); + testVectorMeanStd<AVX2::VectorMeanStd>(81920, true); + testVectorMeanStd<AVX2::VectorMeanStd>(120832); + testVectorMeanStd<AVX2::VectorMeanStd>(120832, true); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("QuantizeStd AVX512BW", "[VectorMeanStd]") { + if (kCPU < CPUType::AVX512BW) return; + testVectorMeanStd<AVX512BW::VectorMeanStd>(64); + testVectorMeanStd<AVX512BW::VectorMeanStd>(64, true); + testVectorMeanStd<AVX512BW::VectorMeanStd>(256); + testVectorMeanStd<AVX512BW::VectorMeanStd>(256, true); + testVectorMeanStd<AVX512BW::VectorMeanStd>(2048); + testVectorMeanStd<AVX512BW::VectorMeanStd>(2048, true); + testVectorMeanStd<AVX512BW::VectorMeanStd>(65536); + testVectorMeanStd<AVX512BW::VectorMeanStd>(65536, true); + testVectorMeanStd<AVX512BW::VectorMeanStd>(81920); + testVectorMeanStd<AVX512BW::VectorMeanStd>(81920, true); + testVectorMeanStd<AVX512BW::VectorMeanStd>(120832); + testVectorMeanStd<AVX512BW::VectorMeanStd>(120832, true); +} +#endif + +} // namespace +} // namespace intgemm diff --git a/third_party/intgemm/test/test.cc b/third_party/intgemm/test/test.cc new file mode 100644 index 0000000000..45c27ad047 --- /dev/null +++ b/third_party/intgemm/test/test.cc @@ -0,0 +1,27 @@ +#define CATCH_CONFIG_RUNNER +#include "test.h" + +#include <cmath> + +int main(int argc, char ** argv) { + return Catch::Session().run(argc, argv); +} + +namespace intgemm { + +void CompareMSE(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info, + float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) { + float int_sum = 0.0, float_sum = 0.0; + for (std::size_t i = 0; i < size; ++i) { + float int_diff = int_ref[i] - int_test[i]; + float float_diff = float_ref[i] - int_test[i]; + CHECK_MESSAGE(std::fabs(int_diff) <= int_tolerance, test_info << "Inaccurate compared to int reference at " << i << ' ' << int_ref[i] << ' ' << int_test[i]); + CHECK_MESSAGE(std::fabs(float_diff) <= float_tolerance, test_info << "Inaccurate compared to float reference at " << i << ' ' << float_ref[i] << ' ' << int_test[i]); + int_sum += int_diff * int_diff; + float_sum += float_diff * float_diff; + } + CHECK_MESSAGE(std::fabs(sqrt(float_sum / size)) <= MSE_float_tolerance, test_info << "Float MSE = " << sqrt(float_sum / size)); + CHECK_MESSAGE(std::fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size)); +} + +} // namespace intgemm diff --git a/third_party/intgemm/test/test.h b/third_party/intgemm/test/test.h new file mode 100644 index 0000000000..1f884c512f --- /dev/null +++ b/third_party/intgemm/test/test.h @@ -0,0 +1,132 @@ +#pragma once + +#include "intgemm/intgemm_config.h" + +#include "3rd_party/catch.hpp" +#include "../intgemm/intgemm.h" +#include "../intgemm/aligned.h" + +#include <cmath> +#include <sstream> +#include <iostream> +#include <iomanip> + +#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while(0) +#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while(0) +#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while(0) +#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while(0) + +#define CHECK_EPS(actual, expected, epsilon) \ + do { \ + if (std::fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ + else { CHECK((actual) == (expected)); } \ + } while(0) + +#define KERNEL_TEST_CASE(name) TEST_CASE("Kernel: " name, "[kernel_test]") + +namespace intgemm { + +template <typename Type> +void Compare(const Type* reference, const Type* actual, Index size) { + for (Index i = 0; i < size; ++i) { + INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]); + CHECK(reference[i] == actual[i]); + } +} + +template <typename Type> +void CompareEps(const Type* reference, const Type* actual, Index size, Type epsilon) { + for (Index i = 0; i < size; ++i) { + INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]); + // Ratio to maximum value. + float threshold = epsilon * std::max<float>(0.01f, std::fabs(reference[i])); + CHECK(std::fabs(reference[i] - actual[i]) < threshold); + } +} + +void CompareMSE(const float *float_ref, const float *int_ref, const float *int_test, + std::size_t size, std::string test_info, float int_tolerance, + float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance); + +template <typename Type> +std::string PrintMatrix(const Type *mem, Index rows, Index cols) { + std::ostringstream out; + for (Index r = 0; r < rows; ++r) { + for (Index c = 0; c < cols; ++c) { + out << std::setw(4) << (int64_t) mem[r * cols + c] << ' '; + } + out << '\n'; + } + return out.str(); +} + +/* + * References + */ +namespace references { + +// Quantize +template <typename Type> +void Quantize(const float* input, Type* output, float quant_mult, Index size) { + for (Index i = 0; i < size; ++i) { + float value = roundf(input[i] * quant_mult); + value = std::max<float>(std::numeric_limits<Type>::min(), value); + value = std::min<float>(std::numeric_limits<Type>::max(), value); + output[i] = value; + } +} + +/* + * Multiply C = A x B + * + * Notes: A and B has to be both integers or both floating points. + * + * Callback takes two arguments: + * - Intermediate value of multiplication 1 row times 1 column - it's int32_t or double based on types A and B. + * - Object containing information about position in output matrix - callbacks::OutputBufferInfo. + */ +template <typename TypeA, typename TypeB, typename TypeC, typename LambdaCallback, + typename std::enable_if< + (std::is_integral<TypeA>::value && std::is_integral<TypeB>::value) || + (std::is_floating_point<TypeA>::value && std::is_floating_point<TypeB>::value) + >::type* = nullptr> +void Multiply(const TypeA* A, const TypeB* B, TypeC* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) { + using IntermediateType = typename std::conditional<std::is_integral<TypeA>::value, int32_t, double>::type; + + for (Index r = 0; r < A_rows; ++r) { + for (Index c = 0; c < B_cols; ++c) { + IntermediateType sum = 0; + for (Index k = 0; k < width; ++k) { + sum += IntermediateType(A[r * width + k]) * IntermediateType(B[k * B_cols + c]); + } + C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols}); + } + } +} + +// Matrix rearragement +template <typename Type> +void Rearragement(const Type* input, Type* output, Index simd, Index unroll, Index rows, Index cols) { + for (Index c = 0; c < cols; c += unroll) { + for (Index r = 0; r < rows; r += simd) { + for (Index i = 0; i < unroll; ++i) + for (Index j = 0; j < simd; ++j) + output[simd * i + j] = input[cols * r + c + cols * j + i]; + + output += unroll * simd; + } + } +} + +// Transpose +template <typename Type> +void Transpose(const Type* input, Type* output, Index rows, Index cols) { + for (Index r = 0; r < rows; ++r) { + for (Index c = 0; c < cols; ++c) { + output[rows * c + r] = input[cols * r + c]; + } + } +} + +} // namespace references +} // namespace intgemm diff --git a/third_party/intgemm/test/utils_test.cc b/third_party/intgemm/test/utils_test.cc new file mode 100644 index 0000000000..e7d07e8483 --- /dev/null +++ b/third_party/intgemm/test/utils_test.cc @@ -0,0 +1,45 @@ +#include "test.h" +#include "../intgemm/utils.h" + +namespace intgemm { +namespace { + +TEST_CASE("Factorial",) { + CHECK(factorial(0) == 1); + CHECK(factorial(1) == 1); + CHECK(factorial(2) == 2); + CHECK(factorial(3) == 6); + CHECK(factorial(4) == 24); + + // Maximum result that fits in unsinged long long + CHECK(factorial(20) == 2432902008176640000); +} + +TEST_CASE("Expi (negative)",) { + const double eps = 0.0000001; + CHECK_EPS(expi(-1), 0.3678794411714423, eps); + CHECK_EPS(expi(-2), 0.1353352832366127, eps); + CHECK_EPS(expi(-10), 0.0000453999297625, eps); +} + +TEST_CASE("Expi (zero)",) { + const double eps = 0.0000001; + CHECK_EPS(expi(0), 1.0, eps); +} + +TEST_CASE("Expi (positive)",) { + const double eps = 0.0000001; + CHECK_EPS(expi(1), 2.7182818284590452, eps); + CHECK_EPS(expi(2), 7.3890560989306502, eps); + CHECK_EPS(expi(10), 22026.4657948067165170, eps); +} + +TEST_CASE("Round up",) { + CHECK(round_up(0, 5) == 0); + CHECK(round_up(1, 5) == 5); + CHECK(round_up(4, 5) == 5); + CHECK(round_up(6, 5) == 10); +} + +} +} |