From db86d8e0e7353770e5e14fc52af5b02ffb72df20 Mon Sep 17 00:00:00 2001 From: Romain Guy Date: Tue, 20 Jul 2021 17:05:35 -0700 Subject: [PATCH] Optimize multiplication operators to avoid allocations --- .../com/curiouscreature/kotlin/math/Matrix.kt | 105 ++++++++++++------ .../curiouscreature/kotlin/math/MatrixTest.kt | 46 ++++++++ 2 files changed, 115 insertions(+), 36 deletions(-) diff --git a/src/main/kotlin/com/curiouscreature/kotlin/math/Matrix.kt b/src/main/kotlin/com/curiouscreature/kotlin/math/Matrix.kt index f7c130e..af87fdf 100644 --- a/src/main/kotlin/com/curiouscreature/kotlin/math/Matrix.kt +++ b/src/main/kotlin/com/curiouscreature/kotlin/math/Matrix.kt @@ -82,18 +82,22 @@ data class Mat2( operator fun times(v: Float) = Mat2(x * v, y * v) operator fun div(v: Float) = Mat2(x / v, y / v) - operator fun times(m: Mat2): Mat2 { - val t = transpose(this) - return Mat2( - Float2(dot(t.x, m.x), dot(t.y, m.x)), - Float2(dot(t.x, m.y), dot(t.y, m.y)) - ) - } + operator fun times(m: Mat2) = Mat2( + Float2( + x.x * m.x.x + y.x * m.x.y, + x.y * m.x.x + y.y * m.x.y, + ), + Float2( + x.x * m.y.x + y.x * m.y.y, + x.y * m.y.x + y.y * m.y.y, + ) + ) + + operator fun times(v: Float2) = Float2( + x.x * v.x + y.x * v.y, + x.y * v.x + y.y * v.y, + ) - operator fun times(v: Float2): Float2 { - val t = transpose(this) - return Float2(dot(t.x, v), dot(t.y, v)) - } fun toFloatArray() = floatArrayOf( x.x, y.x, @@ -173,19 +177,29 @@ data class Mat3( operator fun times(v: Float) = Mat3(x * v, y * v, z * v) operator fun div(v: Float) = Mat3(x / v, y / v, z / v) - operator fun times(m: Mat3): Mat3 { - val t = transpose(this) - return Mat3( - Float3(dot(t.x, m.x), dot(t.y, m.x), dot(t.z, m.x)), - Float3(dot(t.x, m.y), dot(t.y, m.y), dot(t.z, m.y)), - Float3(dot(t.x, m.z), dot(t.y, m.z), dot(t.z, m.z)) - ) - } + operator fun times(m: Mat3) = Mat3( + Float3( + x.x * m.x.x + y.x * m.x.y + z.x * m.x.z, + x.y * m.x.x + y.y * m.x.y + z.y * m.x.z, + x.z * m.x.x + y.z * m.x.y + z.z * m.x.z, + ), + Float3( + x.x * m.y.x + y.x * m.y.y + z.x * m.y.z, + x.y * m.y.x + y.y * m.y.y + z.y * m.y.z, + x.z * m.y.x + y.z * m.y.y + z.z * m.y.z, + ), + Float3( + x.x * m.z.x + y.x * m.z.y + z.x * m.z.z, + x.y * m.z.x + y.y * m.z.y + z.y * m.z.z, + x.z * m.z.x + y.z * m.z.y + z.z * m.z.z, + ) + ) - operator fun times(v: Float3): Float3 { - val t = transpose(this) - return Float3(dot(t.x, v), dot(t.y, v), dot(t.z, v)) - } + operator fun times(v: Float3) = Float3( + x.x * v.x + y.x * v.y + z.x * v.z, + x.y * v.x + y.y * v.y + z.y * v.z, + x.z * v.x + y.z * v.y + z.z * v.z, + ) fun toFloatArray() = floatArrayOf( x.x, y.x, z.x, @@ -315,20 +329,39 @@ data class Mat4( operator fun times(v: Float) = Mat4(x * v, y * v, z * v, w * v) operator fun div(v: Float) = Mat4(x / v, y / v, z / v, w / v) - operator fun times(m: Mat4): Mat4 { - val t = transpose(this) - return Mat4( - Float4(dot(t.x, m.x), dot(t.y, m.x), dot(t.z, m.x), dot(t.w, m.x)), - Float4(dot(t.x, m.y), dot(t.y, m.y), dot(t.z, m.y), dot(t.w, m.y)), - Float4(dot(t.x, m.z), dot(t.y, m.z), dot(t.z, m.z), dot(t.w, m.z)), - Float4(dot(t.x, m.w), dot(t.y, m.w), dot(t.z, m.w), dot(t.w, m.w)) - ) - } + operator fun times(m: Mat4) = Mat4( + Float4( + x.x * m.x.x + y.x * m.x.y + z.x * m.x.z + w.x * m.x.w, + x.y * m.x.x + y.y * m.x.y + z.y * m.x.z + w.y * m.x.w, + x.z * m.x.x + y.z * m.x.y + z.z * m.x.z + w.z * m.x.w, + x.w * m.x.x + y.w * m.x.y + z.w * m.x.z + w.w * m.x.w, + ), + Float4( + x.x * m.y.x + y.x * m.y.y + z.x * m.y.z + w.x * m.y.w, + x.y * m.y.x + y.y * m.y.y + z.y * m.y.z + w.y * m.y.w, + x.z * m.y.x + y.z * m.y.y + z.z * m.y.z + w.z * m.y.w, + x.w * m.y.x + y.w * m.y.y + z.w * m.y.z + w.w * m.y.w, + ), + Float4( + x.x * m.z.x + y.x * m.z.y + z.x * m.z.z + w.x * m.z.w, + x.y * m.z.x + y.y * m.z.y + z.y * m.z.z + w.y * m.z.w, + x.z * m.z.x + y.z * m.z.y + z.z * m.z.z + w.z * m.z.w, + x.w * m.z.x + y.w * m.z.y + z.w * m.z.z + w.w * m.z.w, + ), + Float4( + x.x * m.w.x + y.x * m.w.y + z.x * m.w.z + w.x * m.w.w, + x.y * m.w.x + y.y * m.w.y + z.y * m.w.z + w.y * m.w.w, + x.z * m.w.x + y.z * m.w.y + z.z * m.w.z + w.z * m.w.w, + x.w * m.w.x + y.w * m.w.y + z.w * m.w.z + w.w * m.w.w, + ) + ) - operator fun times(v: Float4): Float4 { - val t = transpose(this) - return Float4(dot(t.x, v), dot(t.y, v), dot(t.z, v), dot(t.w, v)) - } + operator fun times(v: Float4) = Float4( + x.x * v.x + y.x * v.y + z.x * v.z+ w.x * v.w, + x.y * v.x + y.y * v.y + z.y * v.z+ w.y * v.w, + x.z * v.x + y.z * v.y + z.z * v.z+ w.z * v.w, + x.w * v.x + y.w * v.y + z.w * v.z+ w.w * v.w + ) fun toFloatArray() = floatArrayOf( x.x, y.x, z.x, w.x, diff --git a/src/test/kotlin/com/curiouscreature/kotlin/math/MatrixTest.kt b/src/test/kotlin/com/curiouscreature/kotlin/math/MatrixTest.kt index 7337b01..6acf24e 100644 --- a/src/test/kotlin/com/curiouscreature/kotlin/math/MatrixTest.kt +++ b/src/test/kotlin/com/curiouscreature/kotlin/math/MatrixTest.kt @@ -19,8 +19,54 @@ package com.curiouscreature.kotlin.math import org.junit.Assert import org.junit.Test import kotlin.test.assertEquals +import kotlin.test.assertNotEquals class MatrixTest { + @Test + fun `Mat2 multiplication`() { + val a = Mat2.of(1.0f, 2.0f, 3.0f, 4.0f) + val b = Mat2.of(2.0f, 0.0f, 1.0f, 2.0f) + assertEquals( + Mat2.of(4.0f, 4.0f, 10.0f, 8.0f), + a * b + ) + } + + @Test + fun `Mat3 multiplication`() { + val a = Mat3.of(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f) + val b = Mat3.of(2.0f, 0.0f, 1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 0.0f, 1.0f) + assertEquals( + Mat3.of(12.0f, 6.0f, 6.0f, 30.0f, 15.0f, 15.0f, 48.0f, 24.0f, 24.0f), + a * b + ) + } + + @Test + fun `Mat4 multiplication`() { + val a = Mat4.of( + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, + ) + val b = Mat4.of( + 2.0f, 0.0f, 1.0f, 2.0f, + 1.0f, 1.0f, 2.0f, 0.0f, + 2.0f, 1.0f, 2.0f, 2.0f, + 0.0f, 1.0f, 1.0f, 2.0f, + ) + assertEquals( + Mat4.of( + 10.0f, 9.0f, 15.0f, 16.0f, + 30.0f, 21.0f, 39.0f, 40.0f, + 50.0f, 33.0f, 63.0f, 64.0f, + 70.0f, 45.0f, 87.0f, 88.0f + ), + a * b + ) + } + @Test fun `Mat3 identity`() { assertEquals(