summaryrefslogtreecommitdiffstats
path: root/dom/webgpu/tests/cts/checkout/src/webgpu/util/conversion.ts
diff options
context:
space:
mode:
Diffstat (limited to 'dom/webgpu/tests/cts/checkout/src/webgpu/util/conversion.ts')
-rw-r--r--dom/webgpu/tests/cts/checkout/src/webgpu/util/conversion.ts1076
1 files changed, 1076 insertions, 0 deletions
diff --git a/dom/webgpu/tests/cts/checkout/src/webgpu/util/conversion.ts b/dom/webgpu/tests/cts/checkout/src/webgpu/util/conversion.ts
new file mode 100644
index 0000000000..d81f22defe
--- /dev/null
+++ b/dom/webgpu/tests/cts/checkout/src/webgpu/util/conversion.ts
@@ -0,0 +1,1076 @@
+import { Colors } from '../../common/util/colors.js';
+import { assert, TypedArrayBufferView, unreachable } from '../../common/util/util.js';
+import { Float16Array } from '../../external/petamoriken/float16/float16.js';
+
+import { kBit } from './constants.js';
+import {
+ cartesianProduct,
+ clamp,
+ correctlyRoundedF16,
+ isFiniteF16,
+ isSubnormalNumberF16,
+ isSubnormalNumberF32,
+} from './math.js';
+
+/**
+ * Encodes a JS `number` into a "normalized" (unorm/snorm) integer representation with `bits` bits.
+ * Input must be between -1 and 1 if signed, or 0 and 1 if unsigned.
+ *
+ * MAINTENANCE_TODO: See if performance of texel_data improves if this function is pre-specialized
+ * for a particular `bits`/`signed`.
+ */
+export function floatAsNormalizedInteger(float: number, bits: number, signed: boolean): number {
+ if (signed) {
+ assert(float >= -1 && float <= 1, () => `${float} out of bounds of snorm`);
+ const max = Math.pow(2, bits - 1) - 1;
+ return Math.round(float * max);
+ } else {
+ assert(float >= 0 && float <= 1, () => `${float} out of bounds of unorm`);
+ const max = Math.pow(2, bits) - 1;
+ return Math.round(float * max);
+ }
+}
+
+/**
+ * Decodes a JS `number` from a "normalized" (unorm/snorm) integer representation with `bits` bits.
+ * Input must be an integer in the range of the specified unorm/snorm type.
+ */
+export function normalizedIntegerAsFloat(integer: number, bits: number, signed: boolean): number {
+ assert(Number.isInteger(integer));
+ if (signed) {
+ const max = Math.pow(2, bits - 1) - 1;
+ assert(integer >= -max - 1 && integer <= max);
+ if (integer === -max - 1) {
+ integer = -max;
+ }
+ return integer / max;
+ } else {
+ const max = Math.pow(2, bits) - 1;
+ assert(integer >= 0 && integer <= max);
+ return integer / max;
+ }
+}
+
+/**
+ * Encodes a JS `number` into an IEEE754 floating point number with the specified number of
+ * sign, exponent, mantissa bits, and exponent bias.
+ * Returns the result as an integer-valued JS `number`.
+ *
+ * Does not handle clamping, overflow, or denormal inputs.
+ * On underflow (result is subnormal), rounds to (signed) zero.
+ *
+ * MAINTENANCE_TODO: Replace usages of this with numberToFloatBits.
+ */
+export function float32ToFloatBits(
+ n: number,
+ signBits: 0 | 1,
+ exponentBits: number,
+ mantissaBits: number,
+ bias: number
+): number {
+ assert(exponentBits <= 8);
+ assert(mantissaBits <= 23);
+ assert(Number.isFinite(n));
+
+ if (n === 0) {
+ return 0;
+ }
+
+ if (signBits === 0) {
+ assert(n >= 0);
+ }
+
+ const buf = new DataView(new ArrayBuffer(Float32Array.BYTES_PER_ELEMENT));
+ buf.setFloat32(0, n, true);
+ const bits = buf.getUint32(0, true);
+ // bits (32): seeeeeeeefffffffffffffffffffffff
+
+ const mantissaBitsToDiscard = 23 - mantissaBits;
+
+ // 0 or 1
+ const sign = (bits >> 31) & signBits;
+
+ // >> to remove mantissa, & to remove sign, - 127 to remove bias.
+ const exp = ((bits >> 23) & 0xff) - 127;
+
+ // Convert to the new biased exponent.
+ const newBiasedExp = bias + exp;
+ assert(newBiasedExp < 1 << exponentBits, () => `input number ${n} overflows target type`);
+
+ if (newBiasedExp <= 0) {
+ // Result is subnormal or zero. Round to (signed) zero.
+ return sign << (exponentBits + mantissaBits);
+ } else {
+ // Mask only the mantissa, and discard the lower bits.
+ const newMantissa = (bits & 0x7fffff) >> mantissaBitsToDiscard;
+ return (sign << (exponentBits + mantissaBits)) | (newBiasedExp << mantissaBits) | newMantissa;
+ }
+}
+
+/**
+ * Encodes a JS `number` into an IEEE754 16 bit floating point number.
+ * Returns the result as an integer-valued JS `number`.
+ *
+ * Does not handle clamping, overflow, or denormal inputs.
+ * On underflow (result is subnormal), rounds to (signed) zero.
+ */
+export function float32ToFloat16Bits(n: number) {
+ return float32ToFloatBits(n, 1, 5, 10, 15);
+}
+
+/**
+ * Decodes an IEEE754 16 bit floating point number into a JS `number` and returns.
+ */
+export function float16BitsToFloat32(float16Bits: number): number {
+ return floatBitsToNumber(float16Bits, kFloat16Format);
+}
+
+type FloatFormat = { signed: 0 | 1; exponentBits: number; mantissaBits: number; bias: number };
+
+/** FloatFormat defining IEEE754 32-bit float. */
+export const kFloat32Format = { signed: 1, exponentBits: 8, mantissaBits: 23, bias: 127 } as const;
+/** FloatFormat defining IEEE754 16-bit float. */
+export const kFloat16Format = { signed: 1, exponentBits: 5, mantissaBits: 10, bias: 15 } as const;
+
+/**
+ * Once-allocated ArrayBuffer/views to avoid overhead of allocation when converting between numeric formats
+ *
+ * workingData* is shared between multiple functions in this file, so to avoid re-entrancy problems, make sure in
+ * functions that use it that they don't call themselves or other functions that use workingData*.
+ */
+const workingData = new ArrayBuffer(4);
+const workingDataU32 = new Uint32Array(workingData);
+const workingDataU16 = new Uint16Array(workingData);
+const workingDataU8 = new Uint8Array(workingData);
+const workingDataF32 = new Float32Array(workingData);
+const workingDataF16 = new Float16Array(workingData);
+const workingDataI16 = new Int16Array(workingData);
+const workingDataI8 = new Int8Array(workingData);
+
+/** Bitcast u32 (represented as integer Number) to f32 (represented as floating-point Number). */
+export function float32BitsToNumber(bits: number): number {
+ workingDataU32[0] = bits;
+ return workingDataF32[0];
+}
+/** Bitcast f32 (represented as floating-point Number) to u32 (represented as integer Number). */
+export function numberToFloat32Bits(number: number): number {
+ workingDataF32[0] = number;
+ return workingDataU32[0];
+}
+
+/**
+ * Decodes an IEEE754 float with the supplied format specification into a JS number.
+ *
+ * The format MUST be no larger than a 32-bit float.
+ */
+export function floatBitsToNumber(bits: number, fmt: FloatFormat): number {
+ // Pad the provided bits out to f32, then convert to a `number` with the wrong bias.
+ // E.g. for f16 to f32:
+ // - f16: S EEEEE MMMMMMMMMM
+ // ^ 000^^^^^ ^^^^^^^^^^0000000000000
+ // - f32: S eeeEEEEE MMMMMMMMMMmmmmmmmmmmmmm
+
+ const kNonSignBits = fmt.exponentBits + fmt.mantissaBits;
+ const kNonSignBitsMask = (1 << kNonSignBits) - 1;
+ const expAndMantBits = bits & kNonSignBitsMask;
+ let f32BitsWithWrongBias = expAndMantBits << (kFloat32Format.mantissaBits - fmt.mantissaBits);
+ f32BitsWithWrongBias |= (bits << (31 - kNonSignBits)) & 0x8000_0000;
+ const numberWithWrongBias = float32BitsToNumber(f32BitsWithWrongBias);
+ return numberWithWrongBias * 2 ** (kFloat32Format.bias - fmt.bias);
+}
+
+/**
+ * Encodes a JS `number` into an IEEE754 floating point number with the specified format.
+ * Returns the result as an integer-valued JS `number`.
+ *
+ * Does not handle clamping, overflow, or denormal inputs.
+ * On underflow (result is subnormal), rounds to (signed) zero.
+ */
+export function numberToFloatBits(number: number, fmt: FloatFormat): number {
+ return float32ToFloatBits(number, fmt.signed, fmt.exponentBits, fmt.mantissaBits, fmt.bias);
+}
+
+/**
+ * Given a floating point number (as an integer representing its bits), computes how many ULPs it is
+ * from zero.
+ *
+ * Subnormal numbers are skipped, so that 0 is one ULP from the minimum normal number.
+ * Subnormal values are flushed to 0.
+ * Positive and negative 0 are both considered to be 0 ULPs from 0.
+ */
+export function floatBitsToNormalULPFromZero(bits: number, fmt: FloatFormat): number {
+ const mask_sign = fmt.signed << (fmt.exponentBits + fmt.mantissaBits);
+ const mask_expt = ((1 << fmt.exponentBits) - 1) << fmt.mantissaBits;
+ const mask_mant = (1 << fmt.mantissaBits) - 1;
+ const mask_rest = mask_expt | mask_mant;
+
+ assert(fmt.exponentBits + fmt.mantissaBits <= 31);
+
+ const sign = bits & mask_sign ? -1 : 1;
+ const rest = bits & mask_rest;
+ const subnormal_or_zero = (bits & mask_expt) === 0;
+ const infinity_or_nan = (bits & mask_expt) === mask_expt;
+ assert(!infinity_or_nan, 'no ulp representation for infinity/nan');
+
+ // The first normal number is mask_mant+1, so subtract mask_mant to make min_normal - zero = 1ULP.
+ const abs_ulp_from_zero = subnormal_or_zero ? 0 : rest - mask_mant;
+ return sign * abs_ulp_from_zero;
+}
+
+/**
+ * Encodes three JS `number` values into RGB9E5, returned as an integer-valued JS `number`.
+ *
+ * RGB9E5 represents three partial-precision floating-point numbers encoded into a single 32-bit
+ * value all sharing the same 5-bit exponent.
+ * There is no sign bit, and there is a shared 5-bit biased (15) exponent and a 9-bit
+ * mantissa for each channel. The mantissa does NOT have an implicit leading "1.",
+ * and instead has an implicit leading "0.".
+ */
+export function packRGB9E5UFloat(r: number, g: number, b: number): number {
+ for (const v of [r, g, b]) {
+ assert(v >= 0 && v < Math.pow(2, 16));
+ }
+
+ const buf = new DataView(new ArrayBuffer(Float32Array.BYTES_PER_ELEMENT));
+ const extractMantissaAndExponent = (n: number) => {
+ const mantissaBits = 9;
+ buf.setFloat32(0, n, true);
+ const bits = buf.getUint32(0, true);
+ // >> to remove mantissa, & to remove sign
+ let biasedExponent = (bits >> 23) & 0xff;
+ const mantissaBitsToDiscard = 23 - mantissaBits;
+ let mantissa = (bits & 0x7fffff) >> mantissaBitsToDiscard;
+
+ // RGB9E5UFloat has an implicit leading 0. instead of a leading 1.,
+ // so we need to move the 1. into the mantissa and bump the exponent.
+ // For float32 encoding, the leading 1 is only present if the biased
+ // exponent is non-zero.
+ if (biasedExponent !== 0) {
+ mantissa = (mantissa >> 1) | 0b100000000;
+ biasedExponent += 1;
+ }
+ return { biasedExponent, mantissa };
+ };
+
+ const { biasedExponent: rExp, mantissa: rOrigMantissa } = extractMantissaAndExponent(r);
+ const { biasedExponent: gExp, mantissa: gOrigMantissa } = extractMantissaAndExponent(g);
+ const { biasedExponent: bExp, mantissa: bOrigMantissa } = extractMantissaAndExponent(b);
+
+ // Use the largest exponent, and shift the mantissa accordingly
+ const exp = Math.max(rExp, gExp, bExp);
+ const rMantissa = rOrigMantissa >> (exp - rExp);
+ const gMantissa = gOrigMantissa >> (exp - gExp);
+ const bMantissa = bOrigMantissa >> (exp - bExp);
+
+ const bias = 15;
+ const biasedExp = exp === 0 ? 0 : exp - 127 + bias;
+ assert(biasedExp >= 0 && biasedExp <= 31);
+ return rMantissa | (gMantissa << 9) | (bMantissa << 18) | (biasedExp << 27);
+}
+
+/**
+ * Quantizes two f32s to f16 and then packs them in a u32
+ *
+ * This should implement the same behaviour as the builtin `pack2x16float` from
+ * WGSL.
+ *
+ * Caller is responsible to ensuring inputs are f32s
+ *
+ * @param x first f32 to be packed
+ * @param y second f32 to be packed
+ * @returns an array of possible results for pack2x16float. Elements are either
+ * a number or undefined.
+ * undefined indicates that any value is valid, since the input went
+ * out of bounds.
+ */
+export function pack2x16float(x: number, y: number): (number | undefined)[] {
+ // Generates all possible valid u16 bit fields for a given f32 to f16 conversion.
+ // Assumes FTZ for both the f32 and f16 value is allowed.
+ const generateU16s = (n: number): number[] => {
+ let contains_subnormals = isSubnormalNumberF32(n);
+ const n_f16s = correctlyRoundedF16(n);
+ contains_subnormals ||= n_f16s.some(isSubnormalNumberF16);
+
+ const n_u16s = n_f16s.map(f16 => {
+ workingDataF16[0] = f16;
+ return workingDataU16[0];
+ });
+
+ const contains_poszero = n_u16s.some(u => u === kBit.f16.positive.zero);
+ const contains_negzero = n_u16s.some(u => u === kBit.f16.negative.zero);
+ if (!contains_negzero && (contains_poszero || contains_subnormals)) {
+ n_u16s.push(kBit.f16.negative.zero);
+ }
+
+ if (!contains_poszero && (contains_negzero || contains_subnormals)) {
+ n_u16s.push(kBit.f16.positive.zero);
+ }
+
+ return n_u16s;
+ };
+
+ if (!isFiniteF16(x) || !isFiniteF16(y)) {
+ // This indicates any value is valid, so it isn't worth bothering
+ // calculating the more restrictive possibilities.
+ return [undefined];
+ }
+
+ const results = new Array<number>();
+ for (const p of cartesianProduct(generateU16s(x), generateU16s(y))) {
+ assert(p.length === 2, 'cartesianProduct of 2 arrays returned an entry with not 2 elements');
+ workingDataU16[0] = p[0];
+ workingDataU16[1] = p[1];
+ results.push(workingDataU32[0]);
+ }
+
+ return results;
+}
+
+/**
+ * Converts two normalized f32s to i16s and then packs them in a u32
+ *
+ * This should implement the same behaviour as the builtin `pack2x16snorm` from
+ * WGSL.
+ *
+ * Caller is responsible to ensuring inputs are normalized f32s
+ *
+ * @param x first f32 to be packed
+ * @param y second f32 to be packed
+ * @returns a number that is expected result of pack2x16snorm.
+ */
+export function pack2x16snorm(x: number, y: number): number {
+ // Converts f32 to i16 via the pack2x16snorm formula.
+ // FTZ is not explicitly handled, because all subnormals will produce a value
+ // between 0 and 1, but significantly away from the edges, so floor goes to 0.
+ const generateI16 = (n: number): number => {
+ return Math.floor(0.5 + 32767 * Math.min(1, Math.max(-1, n)));
+ };
+
+ workingDataI16[0] = generateI16(x);
+ workingDataI16[1] = generateI16(y);
+
+ return workingDataU32[0];
+}
+
+/**
+ * Converts two normalized f32s to u16s and then packs them in a u32
+ *
+ * This should implement the same behaviour as the builtin `pack2x16unorm` from
+ * WGSL.
+ *
+ * Caller is responsible to ensuring inputs are normalized f32s
+ *
+ * @param x first f32 to be packed
+ * @param y second f32 to be packed
+ * @returns an number that is expected result of pack2x16unorm.
+ */
+export function pack2x16unorm(x: number, y: number): number {
+ // Converts f32 to u16 via the pack2x16unorm formula.
+ // FTZ is not explicitly handled, because all subnormals will produce a value
+ // between 0.5 and much less than 1, so floor goes to 0.
+ const generateU16 = (n: number): number => {
+ return Math.floor(0.5 + 65535 * Math.min(1, Math.max(0, n)));
+ };
+
+ workingDataU16[0] = generateU16(x);
+ workingDataU16[1] = generateU16(y);
+
+ return workingDataU32[0];
+}
+
+/**
+ * Converts four normalized f32s to i8s and then packs them in a u32
+ *
+ * This should implement the same behaviour as the builtin `pack4x8snorm` from
+ * WGSL.
+ *
+ * Caller is responsible to ensuring inputs are normalized f32s
+ *
+ * @param vals four f32s to be packed
+ * @returns a number that is expected result of pack4x8usorm.
+ */
+export function pack4x8snorm(...vals: [number, number, number, number]): number {
+ // Converts f32 to u8 via the pack4x8snorm formula.
+ // FTZ is not explicitly handled, because all subnormals will produce a value
+ // between 0 and 1, so floor goes to 0.
+ const generateI8 = (n: number): number => {
+ return Math.floor(0.5 + 127 * Math.min(1, Math.max(-1, n)));
+ };
+
+ for (const idx in vals) {
+ workingDataI8[idx] = generateI8(vals[idx]);
+ }
+
+ return workingDataU32[0];
+}
+
+/**
+ * Converts four normalized f32s to u8s and then packs them in a u32
+ *
+ * This should implement the same behaviour as the builtin `pack4x8unorm` from
+ * WGSL.
+ *
+ * Caller is responsible to ensuring inputs are normalized f32s
+ *
+ * @param vals four f32s to be packed
+ * @returns a number that is expected result of pack4x8unorm.
+ */
+export function pack4x8unorm(...vals: [number, number, number, number]): number {
+ // Converts f32 to u8 via the pack4x8unorm formula.
+ // FTZ is not explicitly handled, because all subnormals will produce a value
+ // between 0.5 and much less than 1, so floor goes to 0.
+ const generateU8 = (n: number): number => {
+ return Math.floor(0.5 + 255 * Math.min(1, Math.max(0, n)));
+ };
+
+ for (const idx in vals) {
+ workingDataU8[idx] = generateU8(vals[idx]);
+ }
+
+ return workingDataU32[0];
+}
+
+/**
+ * Asserts that a number is within the representable (inclusive) of the integer type with the
+ * specified number of bits and signedness.
+ *
+ * MAINTENANCE_TODO: Assert isInteger? Then this function "asserts that a number is representable"
+ * by the type.
+ */
+export function assertInIntegerRange(n: number, bits: number, signed: boolean): void {
+ if (signed) {
+ const min = -Math.pow(2, bits - 1);
+ const max = Math.pow(2, bits - 1) - 1;
+ assert(n >= min && n <= max);
+ } else {
+ const max = Math.pow(2, bits) - 1;
+ assert(n >= 0 && n <= max);
+ }
+}
+
+/**
+ * Converts a linear value into a "gamma"-encoded value using the sRGB-clamped transfer function.
+ */
+export function gammaCompress(n: number): number {
+ n = n <= 0.0031308 ? (323 * n) / 25 : (211 * Math.pow(n, 5 / 12) - 11) / 200;
+ return clamp(n, { min: 0, max: 1 });
+}
+
+/**
+ * Converts a "gamma"-encoded value into a linear value using the sRGB-clamped transfer function.
+ */
+export function gammaDecompress(n: number): number {
+ n = n <= 0.04045 ? (n * 25) / 323 : Math.pow((200 * n + 11) / 211, 12 / 5);
+ return clamp(n, { min: 0, max: 1 });
+}
+
+/** Converts a 32-bit float value to a 32-bit unsigned integer value */
+export function float32ToUint32(f32: number): number {
+ const f32Arr = new Float32Array(1);
+ f32Arr[0] = f32;
+ const u32Arr = new Uint32Array(f32Arr.buffer);
+ return u32Arr[0];
+}
+
+/** Converts a 32-bit unsigned integer value to a 32-bit float value */
+export function uint32ToFloat32(u32: number): number {
+ const u32Arr = new Uint32Array(1);
+ u32Arr[0] = u32;
+ const f32Arr = new Float32Array(u32Arr.buffer);
+ return f32Arr[0];
+}
+
+/** Converts a 32-bit float value to a 32-bit signed integer value */
+export function float32ToInt32(f32: number): number {
+ const f32Arr = new Float32Array(1);
+ f32Arr[0] = f32;
+ const i32Arr = new Int32Array(f32Arr.buffer);
+ return i32Arr[0];
+}
+
+/** Converts a 32-bit unsigned integer value to a 32-bit signed integer value */
+export function uint32ToInt32(u32: number): number {
+ const u32Arr = new Uint32Array(1);
+ u32Arr[0] = u32;
+ const i32Arr = new Int32Array(u32Arr.buffer);
+ return i32Arr[0];
+}
+
+/** Converts a 16-bit float value to a 16-bit unsigned integer value */
+export function float16ToUint16(f16: number): number {
+ const f16Arr = new Float16Array(1);
+ f16Arr[0] = f16;
+ const u16Arr = new Uint16Array(f16Arr.buffer);
+ return u16Arr[0];
+}
+
+/** Converts a 16-bit unsigned integer value to a 16-bit float value */
+export function uint16ToFloat16(u16: number): number {
+ const u16Arr = new Uint16Array(1);
+ u16Arr[0] = u16;
+ const f16Arr = new Float16Array(u16Arr.buffer);
+ return f16Arr[0];
+}
+
+/** Converts a 16-bit float value to a 16-bit signed integer value */
+export function float16ToInt16(f16: number): number {
+ const f16Arr = new Float16Array(1);
+ f16Arr[0] = f16;
+ const i16Arr = new Int16Array(f16Arr.buffer);
+ return i16Arr[0];
+}
+
+/** A type of number representable by Scalar. */
+export type ScalarKind =
+ | 'f64'
+ | 'f32'
+ | 'f16'
+ | 'u32'
+ | 'u16'
+ | 'u8'
+ | 'i32'
+ | 'i16'
+ | 'i8'
+ | 'bool';
+
+/** ScalarType describes the type of WGSL Scalar. */
+export class ScalarType {
+ readonly kind: ScalarKind; // The named type
+ readonly _size: number; // In bytes
+ readonly read: (buf: Uint8Array, offset: number) => Scalar; // reads a scalar from a buffer
+
+ constructor(kind: ScalarKind, size: number, read: (buf: Uint8Array, offset: number) => Scalar) {
+ this.kind = kind;
+ this._size = size;
+ this.read = read;
+ }
+
+ public toString(): string {
+ return this.kind;
+ }
+
+ public get size(): number {
+ return this._size;
+ }
+}
+
+/** ScalarType describes the type of WGSL Vector. */
+export class VectorType {
+ readonly width: number; // Number of elements in the vector
+ readonly elementType: ScalarType; // Element type
+
+ constructor(width: number, elementType: ScalarType) {
+ this.width = width;
+ this.elementType = elementType;
+ }
+
+ /**
+ * @returns a vector constructed from the values read from the buffer at the
+ * given byte offset
+ */
+ public read(buf: Uint8Array, offset: number): Vector {
+ const elements: Array<Scalar> = [];
+ for (let i = 0; i < this.width; i++) {
+ elements[i] = this.elementType.read(buf, offset);
+ offset += this.elementType.size;
+ }
+ return new Vector(elements);
+ }
+
+ public toString(): string {
+ return `vec${this.width}<${this.elementType}>`;
+ }
+
+ public get size(): number {
+ return this.elementType.size * this.width;
+ }
+}
+
+// Maps a string representation of a vector type to vector type.
+const vectorTypes = new Map<string, VectorType>();
+
+export function TypeVec(width: number, elementType: ScalarType): VectorType {
+ const key = `${elementType.toString()} ${width}}`;
+ let ty = vectorTypes.get(key);
+ if (ty !== undefined) {
+ return ty;
+ }
+ ty = new VectorType(width, elementType);
+ vectorTypes.set(key, ty);
+ return ty;
+}
+
+/** Type is a ScalarType or VectorType. */
+export type Type = ScalarType | VectorType;
+
+export const TypeI32 = new ScalarType('i32', 4, (buf: Uint8Array, offset: number) =>
+ i32(new Int32Array(buf.buffer, offset)[0])
+);
+export const TypeU32 = new ScalarType('u32', 4, (buf: Uint8Array, offset: number) =>
+ u32(new Uint32Array(buf.buffer, offset)[0])
+);
+export const TypeF64 = new ScalarType('f64', 8, (buf: Uint8Array, offset: number) =>
+ f32(new Float64Array(buf.buffer, offset)[0])
+);
+export const TypeF32 = new ScalarType('f32', 4, (buf: Uint8Array, offset: number) =>
+ f32(new Float32Array(buf.buffer, offset)[0])
+);
+export const TypeI16 = new ScalarType('i16', 2, (buf: Uint8Array, offset: number) =>
+ i16(new Int16Array(buf.buffer, offset)[0])
+);
+export const TypeU16 = new ScalarType('u16', 2, (buf: Uint8Array, offset: number) =>
+ u16(new Uint16Array(buf.buffer, offset)[0])
+);
+export const TypeF16 = new ScalarType('f16', 2, (buf: Uint8Array, offset: number) =>
+ f16Bits(new Uint16Array(buf.buffer, offset)[0])
+);
+export const TypeI8 = new ScalarType('i8', 1, (buf: Uint8Array, offset: number) =>
+ i8(new Int8Array(buf.buffer, offset)[0])
+);
+export const TypeU8 = new ScalarType('u8', 1, (buf: Uint8Array, offset: number) =>
+ u8(new Uint8Array(buf.buffer, offset)[0])
+);
+export const TypeBool = new ScalarType('bool', 4, (buf: Uint8Array, offset: number) =>
+ bool(new Uint32Array(buf.buffer, offset)[0] !== 0)
+);
+
+/** @returns the ScalarType from the ScalarKind */
+export function scalarType(kind: ScalarKind): ScalarType {
+ switch (kind) {
+ case 'f64':
+ return TypeF64;
+ case 'f32':
+ return TypeF32;
+ case 'f16':
+ return TypeF16;
+ case 'u32':
+ return TypeU32;
+ case 'u16':
+ return TypeU16;
+ case 'u8':
+ return TypeU8;
+ case 'i32':
+ return TypeI32;
+ case 'i16':
+ return TypeI16;
+ case 'i8':
+ return TypeI8;
+ case 'bool':
+ return TypeBool;
+ }
+}
+
+/** @returns the number of scalar (element) types of the given Type */
+export function numElementsOf(ty: Type): number {
+ if (ty instanceof ScalarType) {
+ return 1;
+ }
+ if (ty instanceof VectorType) {
+ return ty.width;
+ }
+ throw new Error(`unhandled type ${ty}`);
+}
+
+/** @returns the scalar (element) type of the given Type */
+export function scalarTypeOf(ty: Type): ScalarType {
+ if (ty instanceof ScalarType) {
+ return ty;
+ }
+ if (ty instanceof VectorType) {
+ return ty.elementType;
+ }
+ throw new Error(`unhandled type ${ty}`);
+}
+
+/** ScalarValue is the JS type that can be held by a Scalar */
+type ScalarValue = boolean | number;
+
+/** Class that encapsulates a single scalar value of various types. */
+export class Scalar {
+ readonly value: ScalarValue; // The scalar value
+ readonly type: ScalarType; // The type of the scalar
+ readonly bits: Uint8Array; // The scalar value packed in a Uint8Array
+
+ public constructor(type: ScalarType, value: ScalarValue, bits: TypedArrayBufferView) {
+ this.value = value;
+ this.type = type;
+ this.bits = new Uint8Array(bits.buffer);
+ }
+
+ /**
+ * Copies the scalar value to the Uint8Array buffer at the provided byte offset.
+ * @param buffer the destination buffer
+ * @param offset the byte offset within buffer
+ */
+ public copyTo(buffer: Uint8Array, offset: number) {
+ for (let i = 0; i < this.bits.length; i++) {
+ buffer[offset + i] = this.bits[i];
+ }
+ }
+
+ /**
+ * @returns the WGSL representation of this scalar value
+ */
+ public wgsl(): string {
+ const withPoint = (x: number) => {
+ const str = `${x}`;
+ return str.indexOf('.') > 0 || str.indexOf('e') > 0 ? str : `${str}.0`;
+ };
+ if (isFinite(this.value as number)) {
+ switch (this.type.kind) {
+ case 'f32':
+ return `${withPoint(this.value as number)}f`;
+ case 'f16':
+ return `${withPoint(this.value as number)}h`;
+ case 'u32':
+ return `${this.value}u`;
+ case 'i32':
+ return `i32(${this.value})`;
+ case 'bool':
+ return `${this.value}`;
+ }
+ }
+ throw new Error(
+ `scalar of value ${this.value} and type ${this.type} has no WGSL representation`
+ );
+ }
+
+ public toString(): string {
+ if (this.type.kind === 'bool') {
+ return Colors.bold(this.value.toString());
+ }
+ switch (this.value) {
+ case Infinity:
+ case -Infinity:
+ return Colors.bold(this.value.toString());
+ default: {
+ // Uint8Array.map returns a Uint8Array, so cannot use .map directly
+ const hex = Array.from(this.bits)
+ .reverse()
+ .map(x => x.toString(16).padStart(2, '0'))
+ .join('');
+ const n = this.value as Number;
+ if (n !== null && isFloatValue(this)) {
+ let str = this.value.toString();
+ str = str.indexOf('.') > 0 || str.indexOf('e') > 0 ? str : `${str}.0`;
+ return isSubnormalNumberF32(n.valueOf())
+ ? `${Colors.bold(str)} (0x${hex} subnormal)`
+ : `${Colors.bold(str)} (0x${hex})`;
+ }
+ return `${Colors.bold(this.value.toString())} (0x${hex})`;
+ }
+ }
+ }
+}
+
+/** Create an f64 from a numeric value, a JS `number`. */
+export function f64(value: number): Scalar {
+ const arr = new Float64Array([value]);
+ return new Scalar(TypeF64, arr[0], arr);
+}
+/** Create an f32 from a numeric value, a JS `number`. */
+export function f32(value: number): Scalar {
+ const arr = new Float32Array([value]);
+ return new Scalar(TypeF32, arr[0], arr);
+}
+/** Create an f16 from a numeric value, a JS `number`. */
+export function f16(value: number): Scalar {
+ const arr = new Float16Array([value]);
+ return new Scalar(TypeF16, arr[0], arr);
+}
+/** Create an f32 from a bit representation, a uint32 represented as a JS `number`. */
+export function f32Bits(bits: number): Scalar {
+ const arr = new Uint32Array([bits]);
+ return new Scalar(TypeF32, new Float32Array(arr.buffer)[0], arr);
+}
+/** Create an f16 from a bit representation, a uint16 represented as a JS `number`. */
+export function f16Bits(bits: number): Scalar {
+ const arr = new Uint16Array([bits]);
+ return new Scalar(TypeF16, new Float16Array(arr.buffer)[0], arr);
+}
+
+/** Create an i32 from a numeric value, a JS `number`. */
+export function i32(value: number): Scalar {
+ const arr = new Int32Array([value]);
+ return new Scalar(TypeI32, arr[0], arr);
+}
+/** Create an i16 from a numeric value, a JS `number`. */
+export function i16(value: number): Scalar {
+ const arr = new Int16Array([value]);
+ return new Scalar(TypeI16, arr[0], arr);
+}
+/** Create an i8 from a numeric value, a JS `number`. */
+export function i8(value: number): Scalar {
+ const arr = new Int8Array([value]);
+ return new Scalar(TypeI8, arr[0], arr);
+}
+
+/** Create an i32 from a bit representation, a uint32 represented as a JS `number`. */
+export function i32Bits(bits: number): Scalar {
+ const arr = new Uint32Array([bits]);
+ return new Scalar(TypeI32, new Int32Array(arr.buffer)[0], arr);
+}
+/** Create an i16 from a bit representation, a uint16 represented as a JS `number`. */
+export function i16Bits(bits: number): Scalar {
+ const arr = new Uint16Array([bits]);
+ return new Scalar(TypeI16, new Int16Array(arr.buffer)[0], arr);
+}
+/** Create an i8 from a bit representation, a uint8 represented as a JS `number`. */
+export function i8Bits(bits: number): Scalar {
+ const arr = new Uint8Array([bits]);
+ return new Scalar(TypeI8, new Int8Array(arr.buffer)[0], arr);
+}
+
+/** Create a u32 from a numeric value, a JS `number`. */
+export function u32(value: number): Scalar {
+ const arr = new Uint32Array([value]);
+ return new Scalar(TypeU32, arr[0], arr);
+}
+/** Create a u16 from a numeric value, a JS `number`. */
+export function u16(value: number): Scalar {
+ const arr = new Uint16Array([value]);
+ return new Scalar(TypeU16, arr[0], arr);
+}
+/** Create a u8 from a numeric value, a JS `number`. */
+export function u8(value: number): Scalar {
+ const arr = new Uint8Array([value]);
+ return new Scalar(TypeU8, arr[0], arr);
+}
+
+/** Create an u32 from a bit representation, a uint32 represented as a JS `number`. */
+export function u32Bits(bits: number): Scalar {
+ const arr = new Uint32Array([bits]);
+ return new Scalar(TypeU32, bits, arr);
+}
+/** Create an u16 from a bit representation, a uint16 represented as a JS `number`. */
+export function u16Bits(bits: number): Scalar {
+ const arr = new Uint16Array([bits]);
+ return new Scalar(TypeU16, bits, arr);
+}
+/** Create an u8 from a bit representation, a uint8 represented as a JS `number`. */
+export function u8Bits(bits: number): Scalar {
+ const arr = new Uint8Array([bits]);
+ return new Scalar(TypeU8, bits, arr);
+}
+
+/** Create a boolean value. */
+export function bool(value: boolean): Scalar {
+ // WGSL does not support using 'bool' types directly in storage / uniform
+ // buffers, so instead we pack booleans in a u32, where 'false' is zero and
+ // 'true' is any non-zero value.
+ const arr = new Uint32Array([value ? 1 : 0]);
+ return new Scalar(TypeBool, value, arr);
+}
+
+/** A 'true' literal value */
+export const True = bool(true);
+
+/** A 'false' literal value */
+export const False = bool(false);
+
+export function reinterpretF32AsU32(f32: number): number {
+ const array = new Float32Array(1);
+ array[0] = f32;
+ return new Uint32Array(array.buffer)[0];
+}
+
+export function reinterpretU32AsF32(u32: number): number {
+ const array = new Uint32Array(1);
+ array[0] = u32;
+ return new Float32Array(array.buffer)[0];
+}
+
+/**
+ * Class that encapsulates a vector value.
+ */
+export class Vector {
+ readonly elements: Array<Scalar>;
+ readonly type: VectorType;
+
+ public constructor(elements: Array<Scalar>) {
+ if (elements.length < 2 || elements.length > 4) {
+ throw new Error(`vector element count must be between 2 and 4, got ${elements.length}`);
+ }
+ for (let i = 1; i < elements.length; i++) {
+ const a = elements[0].type;
+ const b = elements[i].type;
+ if (a !== b) {
+ throw new Error(
+ `cannot mix vector element types. Found elements with types '${a}' and '${b}'`
+ );
+ }
+ }
+ this.elements = elements;
+ this.type = TypeVec(elements.length, elements[0].type);
+ }
+
+ /**
+ * Copies the vector value to the Uint8Array buffer at the provided byte offset.
+ * @param buffer the destination buffer
+ * @param offset the byte offset within buffer
+ */
+ public copyTo(buffer: Uint8Array, offset: number) {
+ for (const element of this.elements) {
+ element.copyTo(buffer, offset);
+ offset += this.type.elementType.size;
+ }
+ }
+
+ /**
+ * @returns the WGSL representation of this vector value
+ */
+ public wgsl(): string {
+ const els = this.elements.map(v => v.wgsl()).join(', ');
+ return `vec${this.type.width}(${els})`;
+ }
+
+ public toString(): string {
+ return `${this.type}(${this.elements.map(e => e.toString()).join(', ')})`;
+ }
+
+ public get x() {
+ assert(0 < this.elements.length);
+ return this.elements[0];
+ }
+
+ public get y() {
+ assert(1 < this.elements.length);
+ return this.elements[1];
+ }
+
+ public get z() {
+ assert(2 < this.elements.length);
+ return this.elements[2];
+ }
+
+ public get w() {
+ assert(3 < this.elements.length);
+ return this.elements[3];
+ }
+}
+
+/** Helper for constructing a new two-element vector with the provided values */
+export function vec2(x: Scalar, y: Scalar) {
+ return new Vector([x, y]);
+}
+
+/** Helper for constructing a new three-element vector with the provided values */
+export function vec3(x: Scalar, y: Scalar, z: Scalar) {
+ return new Vector([x, y, z]);
+}
+
+/** Helper for constructing a new four-element vector with the provided values */
+export function vec4(x: Scalar, y: Scalar, z: Scalar, w: Scalar) {
+ return new Vector([x, y, z, w]);
+}
+
+/**
+ * Helper for constructing Vectors from arrays of numbers
+ *
+ * @param v array of numbers to be converted, must contain 2, 3 or 4 elements
+ * @param op function to convert from number to Scalar, e.g. 'f32`
+ */
+export function toVector(v: number[], op: (n: number) => Scalar): Vector {
+ switch (v.length) {
+ case 2:
+ return vec2(op(v[0]), op(v[1]));
+ case 3:
+ return vec3(op(v[0]), op(v[1]), op(v[2]));
+ case 4:
+ return vec4(op(v[0]), op(v[1]), op(v[2]), op(v[3]));
+ }
+ unreachable(`input to 'toVector' must contain 2, 3, or 4 elements`);
+}
+
+/** Value is a Scalar or Vector value. */
+export type Value = Scalar | Vector;
+
+export type SerializedValueScalar = {
+ kind: 'scalar';
+ type: ScalarKind;
+ value: boolean | number;
+};
+
+export type SerializedValueVector = {
+ kind: 'vector';
+ type: ScalarKind;
+ value: boolean[] | number[];
+};
+
+export type SerializedValue = SerializedValueScalar | SerializedValueVector;
+
+export function serializeValue(v: Value): SerializedValue {
+ const value = (kind: ScalarKind, s: Scalar) => {
+ switch (kind) {
+ case 'f32':
+ return new Uint32Array(s.bits.buffer)[0];
+ case 'f16':
+ return new Uint16Array(s.bits.buffer)[0];
+ default:
+ return s.value;
+ }
+ };
+ if (v instanceof Scalar) {
+ const kind = v.type.kind;
+ return {
+ kind: 'scalar',
+ type: kind,
+ value: value(kind, v),
+ };
+ }
+ if (v instanceof Vector) {
+ const kind = v.type.elementType.kind;
+ return {
+ kind: 'vector',
+ type: kind,
+ value: v.elements.map(e => value(kind, e)) as boolean[] | number[],
+ };
+ }
+ unreachable(`unhandled value type: ${v}`);
+}
+
+export function deserializeValue(data: SerializedValue): Value {
+ const buildScalar = (v: ScalarValue): Scalar => {
+ switch (data.type) {
+ case 'f64':
+ return f64(v as number);
+ case 'i32':
+ return i32(v as number);
+ case 'u32':
+ return u32(v as number);
+ case 'f32':
+ return f32Bits(v as number);
+ case 'i16':
+ return i16(v as number);
+ case 'u16':
+ return u16(v as number);
+ case 'f16':
+ return f16Bits(v as number);
+ case 'i8':
+ return i8(v as number);
+ case 'u8':
+ return u8(v as number);
+ case 'bool':
+ return bool(v as boolean);
+ default:
+ unreachable(`unhandled value type: ${data.type}`);
+ }
+ };
+ switch (data.kind) {
+ case 'scalar': {
+ return buildScalar(data.value);
+ }
+ case 'vector': {
+ return new Vector(data.value.map(v => buildScalar(v)));
+ }
+ }
+}
+
+/** @returns if the Value is a float scalar type */
+export function isFloatValue(v: Value): boolean {
+ if (v instanceof Scalar) {
+ const s = v;
+ return s.type.kind === 'f64' || s.type.kind === 'f32' || s.type.kind === 'f16';
+ }
+ return false;
+}