From 1a9d9918048b5768dd289d34a0413b154db765a0 Mon Sep 17 00:00:00 2001 From: Joshua Ashton Date: Fri, 12 Aug 2022 09:05:58 +0000 Subject: [PATCH] add matrix code --- include/Orange/Math/Matrix.h | 248 +++++++++++++++++++++++++++++++++++ src/Apps/Tools/CubeTest.cpp | 1 + 2 files changed, 249 insertions(+) create mode 100644 include/Orange/Math/Matrix.h diff --git a/include/Orange/Math/Matrix.h b/include/Orange/Math/Matrix.h new file mode 100644 index 0000000..afd4d23 --- /dev/null +++ b/include/Orange/Math/Matrix.h @@ -0,0 +1,248 @@ +#pragma once + +#include + +namespace orange +{ + template + struct Matrix + { + using RowVector = Vec; + + constexpr Matrix(T scale = T{ 1 }) + { + for (size_t i = 0; i < Rows; i++) + { + RowVector vector{}; + vector[i] = scale; + data[i] = vector; + } + } + + template + constexpr Matrix(const Args&... args) + : data {{ args... }} + { + static_assert(sizeof...(Args) == Rows); + } + + constexpr Matrix(const T components[Columns][Rows]) + { + std::copy(&components[0], &components[Columns][Rows], data.begin()); + } + + constexpr Matrix(const Matrix& other) = default; + + + constexpr RowVector& operator[](size_t index) { return data[index]; } + constexpr const RowVector& operator[](size_t index) const { return data[index]; } + + + constexpr const RowVector* begin() const { return data.begin(); } + constexpr RowVector* begin() { return data.begin(); } + constexpr const RowVector* end() const { return data.end(); } + constexpr RowVector* end() { return data.end(); } + + + constexpr bool operator==(const Matrix& other) const + { + return Equal(begin(), end(), other.begin()); + } + + constexpr bool operator!=(const Matrix& other) const + { + return !operator==(other); + } + + + template + constexpr Matrix TransformResult(UnaryOperation op) const + { + return TransformResult(begin(), end(), op); + } + + template + constexpr Matrix TransformResult(const RowVector *other, BinaryOperation op) const + { + return TransformResult(begin(), end(), other, op); + } + + template + constexpr Matrix TransformResult(const Matrix& other, BinaryOperation op) const + { + return TransformResult(other.begin(), op); + } + + template + constexpr Matrix& TransformInPlace(UnaryOperation op) + { + Transform(begin(), end(), begin(), op); + return *this; + } + + template + constexpr Matrix& TransformInPlace(const RowVector *other, BinaryOperation op) + { + Transform(begin(), end(), other, begin(), op); + return *this; + } + + template + constexpr Matrix& TransformInPlace(const Matrix& other, BinaryOperation op) + { + return TransformInPlace(other.begin(), op); + } + + // Simple math operations + + constexpr Matrix operator+(const Matrix& other) const + { + return TransformResult(other, math::Add); + } + + constexpr Matrix operator-(const Matrix& other) const + { + return TransformResult(other, math::Subtract); + } + + constexpr Matrix operator*(const T& scalar) const + { + return TransformResult([scalar](const RowVector& value) { return value * scalar; }); + } + + constexpr Matrix operator/(const T& scalar) const + { + return TransformResult([scalar](const RowVector& value) { return value / scalar; }); + } + + constexpr Matrix operator%(const T& scalar) const + { + return TransformResult([scalar](const RowVector& value) { return value % scalar; }); + } + + + constexpr Matrix& operator+=(const Matrix& other) + { + return TransformInPlace(math::Add); + } + + constexpr Matrix& operator-=(const Matrix& other) + { + return TransformInPlace(math::Subtract); + } + + constexpr Matrix& operator*=(const T& scalar) + { + return TransformInPlace([scalar](const RowVector& value) { return value * scalar; }); + } + + constexpr Matrix& operator/=(const T& scalar) + { + return TransformInPlace([scalar](const RowVector& value) { return value / scalar; }); + } + + constexpr Matrix& operator%=(const T& scalar) + { + return TransformInPlace([scalar](const RowVector& value) { return value % scalar; }); + } + + // Real matrix operations + + constexpr Matrix operator*(const Matrix& other) const + { + Matrix out; + for (size_t r = 0; r < Rows; r++) + { + for (size_t c = 0; c < Columns; c++) + out[r][c] = (*this)[r][c] * other[c][r]; + } + return out; + } + + constexpr Matrix& operator*=(const Matrix& other) const + { + return (*this = *this * other); + } + + constexpr RowVector operator*(const RowVector& v) const + { + auto mul = TransformResult(v.begin(), math::Multiply); + return Accumulate(mul.begin(), mul.end(), RowVector{}); + } + + Array data; + }; + + template + constexpr Matrix operator*(T scalar, const Matrix& matrix) + { + using J = Matrix::RowVector; + return matrix.TransformResult([scalar](J value) { return scalar * value; }); + } + + template + constexpr Matrix operator/(T scalar, const Matrix& matrix) + { + using J = Matrix::RowVector; + return matrix.TransformResult([scalar](J value) { return scalar / value; }); + } + + template + constexpr Matrix operator%(T scalar, const Matrix& matrix) + { + using J = Matrix::RowVector; + return matrix.TransformResult([scalar](J value) { return scalar % value; }); + } + + + template + constexpr Matrix minor(const Matrix& a, size_t column, size_t row) + { + Matrix mtx; + for (size_t y = 0, my = 0; y < Rows; y++) { + if (y == row) continue; + for (size_t x = 0, mx = 0; x < Columns; x++) { + if (x == column) continue; + mtx[my][mx] = a[y][x]; + mx++; + } + my++; + } + return mtx; + } + + template + constexpr T determinant(const Matrix& a) + { + static_assert(Rows == Columns); + + if constexpr (Rows == 1) { + return a[0][0]; + } else { + T result = T{}; + for (size_t x = 0; x < Columns; x++) { + const T sign = (x % 2) ? T{ -1 } : T{ 1 }; + result += sign * a[0][x] * determinant(minor(a, x, 0)); + } + return result; + } + } + + template + constexpr Matrix transpose(const Matrix& a) + { + Matrix result; + for (size_t y = 0; y < Rows; y++) { + for (size_t x = 0; x < Columns; x++) + result[x][y] = a[y][x]; + } + return result; + } + + template + constexpr Matrix hadamard(const Matrix& x, const Matrix& y) + { + return x.TransformResult(y, math::Multiply); + } + +} diff --git a/src/Apps/Tools/CubeTest.cpp b/src/Apps/Tools/CubeTest.cpp index 02cc9f9..5b3fa0e 100644 --- a/src/Apps/Tools/CubeTest.cpp +++ b/src/Apps/Tools/CubeTest.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include