From e199dccb9690976daac8642c4941e4cf7a518443 Mon Sep 17 00:00:00 2001 From: Lorenz Dirry Date: Thu, 26 Sep 2024 01:23:29 +0200 Subject: [PATCH 1/2] [DAPHNE-#822] Added MatrixMarket write support - Write support for CSRMatrix in coordinate system - Write support for DenseMatrix in array format - Support for recognizing symmetric, skew-symmetric matrices, and storing them accordingly in .mtx - Added testcases, calling readMM, then writeMM and then comparing if the files are still the same --- src/runtime/local/io/WriteMM.h | 239 +++++++++++++++++++++++++ src/runtime/local/kernels/Write.h | 6 +- test/CMakeLists.txt | 1 + test/api/cli/io/WriteTest.cpp | 67 +++++++ test/api/cli/io/readAndWriteMtx.daphne | 3 + test/runtime/local/io/WriteMMTest.cpp | 209 +++++++++++++++++++++ test/runtime/local/io/aig.mtx.meta | 6 + test/runtime/local/io/aik.mtx.meta | 6 + test/runtime/local/io/ais.mtx.meta | 6 + test/runtime/local/io/cig.mtx | 3 - 10 files changed, 542 insertions(+), 4 deletions(-) create mode 100644 src/runtime/local/io/WriteMM.h create mode 100644 test/api/cli/io/readAndWriteMtx.daphne create mode 100644 test/runtime/local/io/WriteMMTest.cpp create mode 100644 test/runtime/local/io/aig.mtx.meta create mode 100644 test/runtime/local/io/aik.mtx.meta create mode 100644 test/runtime/local/io/ais.mtx.meta diff --git a/src/runtime/local/io/WriteMM.h b/src/runtime/local/io/WriteMM.h new file mode 100644 index 000000000..d9fa23148 --- /dev/null +++ b/src/runtime/local/io/WriteMM.h @@ -0,0 +1,239 @@ +#ifndef WRITE_MM_H +#define WRITE_MM_H + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +template struct WriteMM { + static void apply(const DTArg *arg, const char *filename) = delete; +}; + +// Convenience function +template void writeMM(const DTArg *arg, const char *filename) { WriteMM::apply(arg, filename); } + +// ---------------------------------------------------------------------------- +// DenseMatrix +// ---------------------------------------------------------------------------- + +template struct WriteMM> { + static void apply(const DenseMatrix *arg, const char *filename) { + const char *format = MM_DENSE_STR; + std::ofstream f(filename); + if (!f.is_open()) { + throw std::runtime_error("WriteMM::apply: Cannot open file"); + } + + const char *field; + if (std::is_integral::value) { + field = MM_INT_STR; + } else if (std::is_floating_point::value) { + field = MM_REAL_STR; + } else { + throw std::runtime_error("WriteMM::apply: Unsupported data type"); + } + + const char *symmetry = MM_GENERAL_STR; + if (isSymmetric(arg)) { + symmetry = MM_SYMM_STR; + } else if (isSkewSymmetric(arg)) { + symmetry = MM_SKEW_STR; + } + + f << MatrixMarketBanner << " matrix " << format << " " << field << " " << symmetry << std::endl; + + size_t rows = arg->getNumRows(); + size_t cols = arg->getNumCols(); + f << rows << " " << cols << std::endl; + + const VT *values = arg->getValues(); + if (!values) { + throw std::runtime_error("WriteMM::apply: Null pointer for 'values' in DenseMatrix"); + } + + if (strcmp(symmetry, MM_GENERAL_STR) == 0) { + for (size_t i = 0; i < cols; ++i) { + for (size_t j = 0; j < rows; ++j) { + size_t idx = j * cols + i; + if (strcmp(field, MM_REAL_STR) == 0) + f << std::scientific << std::setprecision(13) << values[idx] << std::endl; + else + f << values[idx] << std::endl; + } + } + } else if (strcmp(symmetry, MM_SYMM_STR) == 0) { + for (size_t i = 0; i < cols; ++i) { + for (size_t j = i; j < rows; ++j) { + size_t idx = j * cols + i; + if (strcmp(field, MM_REAL_STR) == 0) + f << std::scientific << std::setprecision(13) << values[idx] << std::endl; + else + f << values[idx] << std::endl; + } + } + } else if (strcmp(symmetry, MM_SKEW_STR) == 0) { + for (size_t i = 0; i < cols; ++i) { + for (size_t j = i + 1; j < rows; ++j) { + size_t idx = j * cols + i; + if (strcmp(field, MM_REAL_STR) == 0) + f << std::scientific << std::setprecision(13) << values[idx] << std::endl; + else + f << values[idx] << std::endl; + } + } + } else { + throw std::runtime_error("WriteMM::apply: Unsupported symmetry type"); + } + f.close(); + } + + private: + static bool isSymmetric(const DenseMatrix *arg) { + size_t rows = arg->getNumRows(); + size_t cols = arg->getNumCols(); + if (rows != cols) + return false; + const VT *values = arg->getValues(); + for (size_t i = 0; i < cols; ++i) { + for (size_t j = i + 1; j < rows; ++j) { + size_t idx1 = j + i * rows; + size_t idx2 = i + j * rows; + if (values[idx1] != values[idx2]) + return false; + } + } + return true; + } + + static bool isSkewSymmetric(const DenseMatrix *arg) { + size_t rows = arg->getNumRows(); + size_t cols = arg->getNumCols(); + if (rows != cols) + return false; + const VT *values = arg->getValues(); + for (size_t i = 0; i < rows; ++i) { + size_t idx_diag = i + i * rows; + if (values[idx_diag] != 0) + return false; + } + for (size_t i = 0; i < cols; ++i) { + for (size_t j = i + 1; j < rows; ++j) { + size_t idx1 = j + i * rows; + size_t idx2 = i + j * rows; + if (values[idx1] != -values[idx2]) + return false; + } + } + return true; + } +}; + +// ---------------------------------------------------------------------------- +// CSRMatrix +// ---------------------------------------------------------------------------- + +template struct WriteMM> { + static void apply(const CSRMatrix *arg, const char *filename) { + const char *format = MM_SPARSE_STR; + std::ofstream f(filename); + if (!f.is_open()) { + throw std::runtime_error("WriteMM::apply: Cannot open file"); + } + + const char *field; + if (std::is_integral::value) { + field = MM_INT_STR; + } else if (std::is_floating_point::value) { + field = MM_REAL_STR; + } else { + throw std::runtime_error("WriteMM::apply: Unsupported data type"); + } + + const char *symmetry = MM_GENERAL_STR; + + f << MatrixMarketBanner << " matrix " << format << " " << field << " " << symmetry << std::endl; + + size_t rows = arg->getNumRows(); + size_t cols = arg->getNumCols(); + size_t nnz = countNNZ(arg, symmetry); + + f << rows << " " << cols << " " << nnz << std::endl; + + const size_t *rowOffsets = arg->getRowOffsets(); + const size_t *colIdxs = arg->getColIdxs(); + const VT *values = arg->getValues(); + + std::vector>> colEntries(cols); + + for (size_t i = 0; i < rows; ++i) { + for (size_t idx = rowOffsets[i]; idx < rowOffsets[i + 1]; ++idx) { + size_t j = colIdxs[idx]; + VT val = values[idx]; + colEntries[j].emplace_back(i, val); + } + } + + if (strcmp(symmetry, MM_GENERAL_STR) == 0) { + for (size_t j = 0; j < cols; ++j) { + for (const auto &entry : colEntries[j]) { + size_t i = entry.first; + VT val = entry.second; + if (strcmp(field, MM_REAL_STR) == 0) { + if (val >= 0) { + f << i + 1 << " " << j + 1 << " " << std::scientific << std::setprecision(13) << val + << std::endl; + } else { + f << i + 1 << " " << j + 1 << " " << std::scientific << std::setprecision(13) << val + << std::endl; + } + } else { + f << i + 1 << " " << j + 1 << " " << val << std::endl; + } + } + } + } else { + throw std::runtime_error("WriteMM::apply: Unsupported symmetry type"); + } + + f.close(); + } + + private: + static size_t countNNZ(const CSRMatrix *arg, const char *symmetry) { + size_t nnz = 0; + size_t rows = arg->getNumRows(); + size_t cols = arg->getNumCols(); + + std::vector>> colEntries(cols); + + const size_t *rowOffsets = arg->getRowOffsets(); + const size_t *colIdxs = arg->getColIdxs(); + const VT *values = arg->getValues(); + + for (size_t i = 0; i < rows; ++i) { + for (size_t idx = rowOffsets[i]; idx < rowOffsets[i + 1]; ++idx) { + size_t j = colIdxs[idx]; + colEntries[j].emplace_back(i, values[idx]); + } + } + + if (strcmp(symmetry, MM_GENERAL_STR) == 0) { + nnz = arg->getNumNonZeros(); + } else { + throw std::runtime_error("WriteMM::apply: Unsupported symmetry type"); + } + + return nnz; + } +}; + +#endif // WRITE_MM_H diff --git a/src/runtime/local/kernels/Write.h b/src/runtime/local/kernels/Write.h index 40f6014a3..85a149be0 100644 --- a/src/runtime/local/kernels/Write.h +++ b/src/runtime/local/kernels/Write.h @@ -26,6 +26,7 @@ #include #include #include +#include #if USE_HDFS #include #endif @@ -82,10 +83,11 @@ template struct Write> { // call WriteHDFS writeHDFS(arg, filename, ctx); #endif + } else if (ext == "mtx") { + writeMM(arg, filename); } } }; - // ---------------------------------------------------------------------------- // Frame // ---------------------------------------------------------------------------- @@ -121,6 +123,8 @@ template struct Write> { MetaDataParser::writeMetaData(filename, metaData); writeCsv(arg, file); closeFile(file); + } else if (ext == "mtx") { + writeMM(arg, filename); // Write Matrix in MatrixMarket format } else { throw std::runtime_error("[Write.h] - generic Matrix type currently only supports csv " "file extension."); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index fad31d46d..e6103c19b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -82,6 +82,7 @@ set(TEST_SOURCES runtime/local/io/ReadCsvTest.cpp runtime/local/io/ReadParquetTest.cpp runtime/local/io/ReadMMTest.cpp + runtime/local/io/WriteMMTest.cpp runtime/local/io/WriteDaphneTest.cpp runtime/local/io/ReadDaphneTest.cpp runtime/local/io/DaphneSerializerTest.cpp diff --git a/test/api/cli/io/WriteTest.cpp b/test/api/cli/io/WriteTest.cpp index b58b8ec09..0800945db 100644 --- a/test/api/cli/io/WriteTest.cpp +++ b/test/api/cli/io/WriteTest.cpp @@ -21,9 +21,49 @@ #include #include +#include +#include #include const std::string dirPath = "test/api/cli/io/"; +const std::string dirPath2 = "test/runtime/local/io/"; + +bool compareFiles(const std::string &filePath1, const std::string &filePath2) { + + std::ifstream file1(filePath1, std::ios::binary); + std::ifstream file2(filePath2, std::ios::binary); + + if (!file1.is_open() || !file2.is_open()) { + std::cerr << "Cannot open one or both files." << std::endl; + return false; + } + + std::string line1, line2; + bool filesAreEqual = true; + + while (std::getline(file1, line1)) { + if (!std::getline(file2, line2)) { + filesAreEqual = false; + break; + } + + if (line1 != line2) { + filesAreEqual = false; + break; + } + } + + if (filesAreEqual && std::getline(file2, line2)) { + if (!line2.empty()) { + filesAreEqual = false; + } + } + + file1.close(); + file2.close(); + + return filesAreEqual; +} TEST_CASE("writeMatrixCSV_Full", TAG_IO) { std::string csvPath = dirPath + "matrix_full.csv"; @@ -41,4 +81,31 @@ TEST_CASE("writeMatrixCSV_View", TAG_IO) { std::string("outPath=\"" + csvPath + "\"").c_str()); compareDaphneToRef(dirPath + "matrix_view_ref.csv", dirPath + "readMatrix.daphne", "--args", std::string("inPath=\"" + csvPath + "\"").c_str()); +} + +TEST_CASE("writeMatrixMtxaig", TAG_IO) { + std::string expectedPath = dirPath2 + "aig.mtx"; + std::string resultPath = "out.mtx"; + checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args", + std::string("inPath=\"" + expectedPath + "\"").c_str()); + CHECK(compareFiles(expectedPath, resultPath)); + std::filesystem::remove(resultPath); // remove old file if it still exists +} + +TEST_CASE("writeMatrixMtxaik", TAG_IO) { + std::string expectedPath = dirPath2 + "aik.mtx"; + std::string resultPath = "out.mtx"; + checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args", + std::string("inPath=\"" + expectedPath + "\"").c_str()); + CHECK(compareFiles(expectedPath, resultPath)); + std::filesystem::remove(resultPath); // remove old file if it still exists +} + +TEST_CASE("writeMatrixMtxais", TAG_IO) { + std::string expectedPath = dirPath2 + "ais.mtx"; + std::string resultPath = "out.mtx"; + checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args", + std::string("inPath=\"" + expectedPath + "\"").c_str()); + CHECK(compareFiles(expectedPath, resultPath)); + std::filesystem::remove(resultPath); // remove old file if it still exists } \ No newline at end of file diff --git a/test/api/cli/io/readAndWriteMtx.daphne b/test/api/cli/io/readAndWriteMtx.daphne new file mode 100644 index 000000000..590559c4f --- /dev/null +++ b/test/api/cli/io/readAndWriteMtx.daphne @@ -0,0 +1,3 @@ +X = readMatrix($inPath); +print(X); +write(X, "out.mtx"); \ No newline at end of file diff --git a/test/runtime/local/io/WriteMMTest.cpp b/test/runtime/local/io/WriteMMTest.cpp new file mode 100644 index 000000000..4d02edfa3 --- /dev/null +++ b/test/runtime/local/io/WriteMMTest.cpp @@ -0,0 +1,209 @@ +/* + * Copyright 2022 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +bool compareContentsFromFile(const std::string &filePath1, const std::string &filePath2) { + std::ifstream file1(filePath1, std::ios::binary); + std::ifstream file2(filePath2, std::ios::binary); + + if (!file1.is_open() || !file2.is_open()) { + std::cerr << "Cannot open one or both files." << std::endl; + return false; + } + + std::string line1, line2; + bool filesAreEqual = true; + + while (std::getline(file1, line1)) { + if (!std::getline(file2, line2)) { + filesAreEqual = false; + break; + } + + if (line1 != line2) { + filesAreEqual = false; + break; + } + } + + if (filesAreEqual && std::getline(file2, line2)) { + if (!line2.empty()) { + filesAreEqual = false; + } + } + + file1.close(); + file2.close(); + + return filesAreEqual; +} + +TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIG", TAG_IO, (DenseMatrix), (int32_t)) { + using DT = TestType; + DT *m = nullptr; + + size_t numRows = 4; + size_t numCols = 3; + + char filename[] = "./test/runtime/local/io/aig.mtx"; + char resultPath[] = "out.mtx"; + readMM(m, filename); + + REQUIRE(m->getNumRows() == numRows); + REQUIRE(m->getNumCols() == numCols); + + writeMM(m, resultPath); + + CHECK(compareContentsFromFile(filename, resultPath)); + std::filesystem::remove(resultPath); + + CHECK(m->get(0, 0) == 1); + CHECK(m->get(1, 0) == 2); + CHECK(m->get(0, 1) == 5); + CHECK(m->get(3, 2) == 12); + CHECK(m->get(2, 1) == 7); + + DataObjectFactory::destroy(m); +} + +TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIK", TAG_IO, (DenseMatrix), (int32_t)) { + using DT = TestType; + DT *m = nullptr; + + size_t numRows = 4; + size_t numCols = 4; + + char filename[] = "./test/runtime/local/io/aik.mtx"; + char resultPath[] = "out.mtx"; + readMM(m, filename); + + REQUIRE(m->getNumRows() == numRows); + REQUIRE(m->getNumCols() == numCols); + + writeMM(m, resultPath); + + CHECK(compareContentsFromFile(filename, resultPath)); + std::filesystem::remove(resultPath); + + CHECK(m->get(1, 0) == 1); + + for (size_t r = 0; r < numRows; r++) { + CHECK(m->get(r, r) == 0); + for (size_t c = r + 1; c < numCols; c++) + CHECK(m->get(r, c) == -m->get(c, r)); + } + + DataObjectFactory::destroy(m); +} + +TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIS", TAG_IO, (DenseMatrix), (int32_t)) { + using DT = TestType; + DT *m = nullptr; + + size_t numRows = 3; + size_t numCols = 3; + + char filename[] = "./test/runtime/local/io/ais.mtx"; + char resultPath[] = "out.mtx"; + readMM(m, filename); + + REQUIRE(m->getNumRows() == numRows); + REQUIRE(m->getNumCols() == numCols); + + writeMM(m, resultPath); + + CHECK(compareContentsFromFile(filename, resultPath)); + std::filesystem::remove(resultPath); + + CHECK(m->get(1, 1) == 4); + + for (size_t r = 0; r < numRows; r++) + for (size_t c = r + 1; c < numCols; c++) + CHECK(m->get(r, c) == m->get(c, r)); + + DataObjectFactory::destroy(m); +} + +TEMPLATE_PRODUCT_TEST_CASE("WriteMM CIG (CSR)", TAG_IO, (CSRMatrix), (int32_t)) { + using DT = TestType; + DT *m = nullptr; + + size_t numRows = 9; + size_t numCols = 9; + + char filename[] = "./test/runtime/local/io/cig.mtx"; + char resultPath[] = "out.mtx"; + readMM(m, filename); + + REQUIRE(m->getNumRows() == numRows); + REQUIRE(m->getNumCols() == numCols); + + writeMM(m, resultPath); + + CHECK(compareContentsFromFile(filename, resultPath)); + std::filesystem::remove(resultPath); + + CHECK(m->get(0, 0) == 1); + CHECK(m->get(2, 0) == 0); + CHECK(m->get(3, 4) == 9); + CHECK(m->get(7, 4) == 4); + + DataObjectFactory::destroy(m); +} + +TEMPLATE_PRODUCT_TEST_CASE("WriteMM CRG (CSR)", TAG_IO, (CSRMatrix), (double)) { + using DT = TestType; + DT *m = nullptr; + + size_t numRows = 497; + size_t numCols = 507; + + char filename[] = "./test/runtime/local/io/crg.mtx"; + char resultPath[] = "out.mtx"; + readMM(m, filename); + + REQUIRE(m->getNumRows() == numRows); + REQUIRE(m->getNumCols() == numCols); + + writeMM(m, resultPath); + + CHECK(compareContentsFromFile(filename, resultPath)); + std::filesystem::remove(resultPath); + + CHECK(m->get(5, 0) == 0.25599762); + CHECK(m->get(6, 0) == 0.13827993); + CHECK(m->get(200, 4) == 0.20001954); + + DataObjectFactory::destroy(m); +} \ No newline at end of file diff --git a/test/runtime/local/io/aig.mtx.meta b/test/runtime/local/io/aig.mtx.meta new file mode 100644 index 000000000..f84a4ff4f --- /dev/null +++ b/test/runtime/local/io/aig.mtx.meta @@ -0,0 +1,6 @@ +{ + "numRows": 4, + "numCols": 3, + "valueType": "si64", + "numNonZeros": 12 +} diff --git a/test/runtime/local/io/aik.mtx.meta b/test/runtime/local/io/aik.mtx.meta new file mode 100644 index 000000000..b04ed1b53 --- /dev/null +++ b/test/runtime/local/io/aik.mtx.meta @@ -0,0 +1,6 @@ +{ + "numRows": 4, + "numCols": 4, + "valueType": "si64", + "numNonZeros": 16 +} diff --git a/test/runtime/local/io/ais.mtx.meta b/test/runtime/local/io/ais.mtx.meta new file mode 100644 index 000000000..43bfe32f6 --- /dev/null +++ b/test/runtime/local/io/ais.mtx.meta @@ -0,0 +1,6 @@ +{ + "numRows": 3, + "numCols": 3, + "valueType": "si64", + "numNonZeros": 9 +} diff --git a/test/runtime/local/io/cig.mtx b/test/runtime/local/io/cig.mtx index f12363fd9..f30c9d439 100644 --- a/test/runtime/local/io/cig.mtx +++ b/test/runtime/local/io/cig.mtx @@ -1,7 +1,4 @@ %%MatrixMarket matrix coordinate integer general -% -% -% 1 2 3 9 9 50 1 1 1 2 1 2 From c150a45d9f3e2bc219b2902167962c0886ce2390 Mon Sep 17 00:00:00 2001 From: Lorenz Dirry Date: Mon, 21 Oct 2024 13:17:15 +0200 Subject: [PATCH 2/2] comment fix --- src/runtime/local/io/WriteMM.cpp | 191 +++++++++++++++++++++++ src/runtime/local/io/WriteMM.h | 214 +------------------------- test/api/cli/Utils.cpp | 29 ++++ test/api/cli/Utils.h | 15 ++ test/api/cli/io/WriteTest.cpp | 45 +----- test/runtime/local/io/WriteMMTest.cpp | 48 +----- 6 files changed, 249 insertions(+), 293 deletions(-) create mode 100644 src/runtime/local/io/WriteMM.cpp diff --git a/src/runtime/local/io/WriteMM.cpp b/src/runtime/local/io/WriteMM.cpp new file mode 100644 index 000000000..5365106be --- /dev/null +++ b/src/runtime/local/io/WriteMM.cpp @@ -0,0 +1,191 @@ +#include "write_mm.h" + +template +void writeMM(const DTArg *arg, const char *filename) { + WriteMM::apply(arg, filename); +} + +template +void WriteMM>::apply(const DenseMatrix *arg, const char *filename) { + const char *format = MM_DENSE_STR; + std::ofstream f(filename); + if (!f.is_open()) { + throw std::runtime_error("WriteMM::apply: Cannot open file"); + } + + const char *field; + if (std::is_integral::value) { + field = MM_INT_STR; + } else if (std::is_floating_point::value) { + field = MM_REAL_STR; + } else { + throw std::runtime_error("WriteMM::apply: Unsupported data type"); + } + + const char *symmetry = MM_GENERAL_STR; + if (isSymmetric(arg)) { + symmetry = MM_SYMM_STR; + } else if (isSkewSymmetric(arg)) { + symmetry = MM_SKEW_STR; + } + + f << MatrixMarketBanner << " matrix " << format << " " << field << " " << symmetry << std::endl; + + size_t rows = arg->getNumRows(); + size_t cols = arg->getNumCols(); + f << rows << " " << cols << std::endl; + + const VT *values = arg->getValues(); + if (!values) { + throw std::runtime_error("WriteMM::apply: Null pointer for 'values' in DenseMatrix"); + } + + if (strcmp(symmetry, MM_GENERAL_STR) == 0) { + for (size_t i = 0; i < cols; ++i) { + for (size_t j = 0; j < rows; ++j) { + size_t idx = j * cols + i; + if (strcmp(field, MM_REAL_STR) == 0) + f << std::scientific << std::setprecision(13) << values[idx] << std::endl; + else + f << values[idx] << std::endl; + } + } + } else if (strcmp(symmetry, MM_SYMM_STR) == 0) { + for (size_t i = 0; i < cols; ++i) { + for (size_t j = i; j < rows; ++j) { + size_t idx = j * cols + i; + if (strcmp(field, MM_REAL_STR) == 0) + f << std::scientific << std::setprecision(13) << values[idx] << std::endl; + else + f << values[idx] << std::endl; + } + } + } else if (strcmp(symmetry, MM_SKEW_STR) == 0) { + for (size_t i = 0; i < cols; ++i) { + for (size_t j = i + 1; j < rows; ++j) { + size_t idx = j * cols + i; + if (strcmp(field, MM_REAL_STR) == 0) + f << std::scientific << std::setprecision(13) << values[idx] << std::endl; + else + f << values[idx] << std::endl; + } + } + } else { + throw std::runtime_error("WriteMM::apply: Unsupported symmetry type"); + } + f.close(); +} + +template +bool WriteMM>::isSymmetric(const DenseMatrix *arg) { + size_t rows = arg->getNumRows(); + size_t cols = arg->getNumCols(); + if (rows != cols) + return false; + const VT *values = arg->getValues(); + for (size_t i = 0; i < cols; ++i) { + for (size_t j = i + 1; j < rows; ++j) { + size_t idx1 = j + i * rows; + size_t idx2 = i + j * rows; + if (values[idx1] != values[idx2]) + return false; + } + } + return true; +} + +template +bool WriteMM>::isSkewSymmetric(const DenseMatrix *arg) { + size_t rows = arg->getNumRows(); + size_t cols = arg->getNumCols(); + if (rows != cols) + return false; + const VT *values = arg->getValues(); + for (size_t i = 0; i < rows; ++i) { + size_t idx_diag = i + i * rows; + if (values[idx_diag] != 0) + return false; + } + for (size_t i = 0; i < cols; ++i) { + for (size_t j = i + 1; j < rows; ++j) { + size_t idx1 = j + i * rows; + size_t idx2 = i + j * rows; + if (values[idx1] != -values[idx2]) + return false; + } + } + return true; +} + +template +void WriteMM>::apply(const CSRMatrix *arg, const char *filename) { + const char *format = MM_SPARSE_STR; + std::ofstream f(filename); + if (!f.is_open()) { + throw std::runtime_error("WriteMM::apply: Cannot open file"); + } + + const char *field; + if (std::is_integral::value) { + field = MM_INT_STR; + } else if (std::is_floating_point::value) { + field = MM_REAL_STR; + } else { + throw std::runtime_error("WriteMM::apply: Unsupported data type"); + } + + const char *symmetry = MM_GENERAL_STR; + + f << MatrixMarketBanner << " matrix " << format << " " << field << " " << symmetry << std::endl; + + size_t rows = arg->getNumRows(); + size_t cols = arg->getNumCols(); + size_t nnz = arg->getNumNonZeros(); + + f << rows << " " << cols << " " << nnz << std::endl; + + const size_t *rowOffsets = arg->getRowOffsets(); + const size_t *colIdxs = arg->getColIdxs(); + const VT *values = arg->getValues(); + + std::vector>> colEntries(cols); + + for (size_t i = 0; i < rows; ++i) { + for (size_t idx = rowOffsets[i]; idx < rowOffsets[i + 1]; ++idx) { + size_t j = colIdxs[idx]; + VT val = values[idx]; + colEntries[j].emplace_back(i, val); + } + } + + if (strcmp(symmetry, MM_GENERAL_STR) == 0) { + for (size_t j = 0; j < cols; ++j) { + for (const auto &entry : colEntries[j]) { + size_t i = entry.first; + VT val = entry.second; + if (strcmp(field, MM_REAL_STR) == 0) { + if (val >= 0) { + f << i + 1 << " " << j + 1 << " " << std::scientific << std::setprecision(13) << val << std::endl; + } else { + f << i + 1 << " " << j + 1 << " " << std::scientific << std::setprecision(13) << val << std::endl; + } + } else { + f << i + 1 << " " << j + 1 << " " << val << std::endl; + } + } + } + } else { + throw std::runtime_error("WriteMM::apply: Unsupported symmetry type"); + } + + f.close(); +} + +template struct WriteMM>; +template struct WriteMM>; +template struct WriteMM>; +template struct WriteMM>; +template struct WriteMM>; +template struct WriteMM>; +template struct WriteMM>; +template struct WriteMM>; \ No newline at end of file diff --git a/src/runtime/local/io/WriteMM.h b/src/runtime/local/io/WriteMM.h index d9fa23148..33fcf35ce 100644 --- a/src/runtime/local/io/WriteMM.h +++ b/src/runtime/local/io/WriteMM.h @@ -18,222 +18,18 @@ template struct WriteMM { static void apply(const DTArg *arg, const char *filename) = delete; }; -// Convenience function -template void writeMM(const DTArg *arg, const char *filename) { WriteMM::apply(arg, filename); } - -// ---------------------------------------------------------------------------- -// DenseMatrix -// ---------------------------------------------------------------------------- +template void writeMM(const DTArg *arg, const char *filename); template struct WriteMM> { - static void apply(const DenseMatrix *arg, const char *filename) { - const char *format = MM_DENSE_STR; - std::ofstream f(filename); - if (!f.is_open()) { - throw std::runtime_error("WriteMM::apply: Cannot open file"); - } - - const char *field; - if (std::is_integral::value) { - field = MM_INT_STR; - } else if (std::is_floating_point::value) { - field = MM_REAL_STR; - } else { - throw std::runtime_error("WriteMM::apply: Unsupported data type"); - } - - const char *symmetry = MM_GENERAL_STR; - if (isSymmetric(arg)) { - symmetry = MM_SYMM_STR; - } else if (isSkewSymmetric(arg)) { - symmetry = MM_SKEW_STR; - } - - f << MatrixMarketBanner << " matrix " << format << " " << field << " " << symmetry << std::endl; - - size_t rows = arg->getNumRows(); - size_t cols = arg->getNumCols(); - f << rows << " " << cols << std::endl; - - const VT *values = arg->getValues(); - if (!values) { - throw std::runtime_error("WriteMM::apply: Null pointer for 'values' in DenseMatrix"); - } - - if (strcmp(symmetry, MM_GENERAL_STR) == 0) { - for (size_t i = 0; i < cols; ++i) { - for (size_t j = 0; j < rows; ++j) { - size_t idx = j * cols + i; - if (strcmp(field, MM_REAL_STR) == 0) - f << std::scientific << std::setprecision(13) << values[idx] << std::endl; - else - f << values[idx] << std::endl; - } - } - } else if (strcmp(symmetry, MM_SYMM_STR) == 0) { - for (size_t i = 0; i < cols; ++i) { - for (size_t j = i; j < rows; ++j) { - size_t idx = j * cols + i; - if (strcmp(field, MM_REAL_STR) == 0) - f << std::scientific << std::setprecision(13) << values[idx] << std::endl; - else - f << values[idx] << std::endl; - } - } - } else if (strcmp(symmetry, MM_SKEW_STR) == 0) { - for (size_t i = 0; i < cols; ++i) { - for (size_t j = i + 1; j < rows; ++j) { - size_t idx = j * cols + i; - if (strcmp(field, MM_REAL_STR) == 0) - f << std::scientific << std::setprecision(13) << values[idx] << std::endl; - else - f << values[idx] << std::endl; - } - } - } else { - throw std::runtime_error("WriteMM::apply: Unsupported symmetry type"); - } - f.close(); - } + static void apply(const DenseMatrix *arg, const char *filename); private: - static bool isSymmetric(const DenseMatrix *arg) { - size_t rows = arg->getNumRows(); - size_t cols = arg->getNumCols(); - if (rows != cols) - return false; - const VT *values = arg->getValues(); - for (size_t i = 0; i < cols; ++i) { - for (size_t j = i + 1; j < rows; ++j) { - size_t idx1 = j + i * rows; - size_t idx2 = i + j * rows; - if (values[idx1] != values[idx2]) - return false; - } - } - return true; - } - - static bool isSkewSymmetric(const DenseMatrix *arg) { - size_t rows = arg->getNumRows(); - size_t cols = arg->getNumCols(); - if (rows != cols) - return false; - const VT *values = arg->getValues(); - for (size_t i = 0; i < rows; ++i) { - size_t idx_diag = i + i * rows; - if (values[idx_diag] != 0) - return false; - } - for (size_t i = 0; i < cols; ++i) { - for (size_t j = i + 1; j < rows; ++j) { - size_t idx1 = j + i * rows; - size_t idx2 = i + j * rows; - if (values[idx1] != -values[idx2]) - return false; - } - } - return true; - } + static bool isSymmetric(const DenseMatrix *arg); + static bool isSkewSymmetric(const DenseMatrix *arg); }; -// ---------------------------------------------------------------------------- -// CSRMatrix -// ---------------------------------------------------------------------------- - template struct WriteMM> { - static void apply(const CSRMatrix *arg, const char *filename) { - const char *format = MM_SPARSE_STR; - std::ofstream f(filename); - if (!f.is_open()) { - throw std::runtime_error("WriteMM::apply: Cannot open file"); - } - - const char *field; - if (std::is_integral::value) { - field = MM_INT_STR; - } else if (std::is_floating_point::value) { - field = MM_REAL_STR; - } else { - throw std::runtime_error("WriteMM::apply: Unsupported data type"); - } - - const char *symmetry = MM_GENERAL_STR; - - f << MatrixMarketBanner << " matrix " << format << " " << field << " " << symmetry << std::endl; - - size_t rows = arg->getNumRows(); - size_t cols = arg->getNumCols(); - size_t nnz = countNNZ(arg, symmetry); - - f << rows << " " << cols << " " << nnz << std::endl; - - const size_t *rowOffsets = arg->getRowOffsets(); - const size_t *colIdxs = arg->getColIdxs(); - const VT *values = arg->getValues(); - - std::vector>> colEntries(cols); - - for (size_t i = 0; i < rows; ++i) { - for (size_t idx = rowOffsets[i]; idx < rowOffsets[i + 1]; ++idx) { - size_t j = colIdxs[idx]; - VT val = values[idx]; - colEntries[j].emplace_back(i, val); - } - } - - if (strcmp(symmetry, MM_GENERAL_STR) == 0) { - for (size_t j = 0; j < cols; ++j) { - for (const auto &entry : colEntries[j]) { - size_t i = entry.first; - VT val = entry.second; - if (strcmp(field, MM_REAL_STR) == 0) { - if (val >= 0) { - f << i + 1 << " " << j + 1 << " " << std::scientific << std::setprecision(13) << val - << std::endl; - } else { - f << i + 1 << " " << j + 1 << " " << std::scientific << std::setprecision(13) << val - << std::endl; - } - } else { - f << i + 1 << " " << j + 1 << " " << val << std::endl; - } - } - } - } else { - throw std::runtime_error("WriteMM::apply: Unsupported symmetry type"); - } - - f.close(); - } - - private: - static size_t countNNZ(const CSRMatrix *arg, const char *symmetry) { - size_t nnz = 0; - size_t rows = arg->getNumRows(); - size_t cols = arg->getNumCols(); - - std::vector>> colEntries(cols); - - const size_t *rowOffsets = arg->getRowOffsets(); - const size_t *colIdxs = arg->getColIdxs(); - const VT *values = arg->getValues(); - - for (size_t i = 0; i < rows; ++i) { - for (size_t idx = rowOffsets[i]; idx < rowOffsets[i + 1]; ++idx) { - size_t j = colIdxs[idx]; - colEntries[j].emplace_back(i, values[idx]); - } - } - - if (strcmp(symmetry, MM_GENERAL_STR) == 0) { - nnz = arg->getNumNonZeros(); - } else { - throw std::runtime_error("WriteMM::apply: Unsupported symmetry type"); - } - - return nnz; - } + static void apply(const CSRMatrix *arg, const char *filename); }; #endif // WRITE_MM_H diff --git a/test/api/cli/Utils.cpp b/test/api/cli/Utils.cpp index 0ea3284fc..4e61a4197 100644 --- a/test/api/cli/Utils.cpp +++ b/test/api/cli/Utils.cpp @@ -40,4 +40,33 @@ std::string readTextFile(const std::string &filePath) { std::string generalizeDataTypes(const std::string &str) { std::regex re("(DenseMatrix|CSRMatrix)"); return std::regex_replace(str, re, ""); +} + +bool compareFileContents(const std::string &filePath1, const std::string &filePath2) { + std::ifstream file1(filePath1, std::ios::binary); + std::ifstream file2(filePath2, std::ios::binary); + if (!file1.is_open() || !file2.is_open()) { + std::cerr << "Cannot open one or both files." << std::endl; + return false; + } + std::string line1, line2; + bool filesAreEqual = true; + while (std::getline(file1, line1)) { + if (!std::getline(file2, line2)) { + filesAreEqual = false; + break; + } + if (line1 != line2) { + filesAreEqual = false; + break; + } + } + if (filesAreEqual && std::getline(file2, line2)) { + if (!line2.empty()) { + filesAreEqual = false; + } + } + file1.close(); + file2.close(); + return filesAreEqual; } \ No newline at end of file diff --git a/test/api/cli/Utils.h b/test/api/cli/Utils.h index d2dc7a113..8db557faf 100644 --- a/test/api/cli/Utils.h +++ b/test/api/cli/Utils.h @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -591,4 +592,18 @@ void compareDaphneToSomeRefSimple(const std::string &dirPath, const std::string */ std::string generalizeDataTypes(const std::string &str); +/** + * @brief Compares the contents of two files line by line. + * + * This function opens two files specified by their paths and compares their + * contents line by line. If the contents of both files match exactly, + * the function returns true. If the files differ at any point, or if either + * file cannot be opened, the function returns false. + * + * @param filePath1 The path to the first file to be compared. + * @param filePath2 The path to the second file to be compared. + * @return true if the files are identical in content; false otherwise. + */ +bool compareFileContents(const std::string &filePath1, const std::string &filePath2); + #endif // TEST_API_CLI_UTILS_H diff --git a/test/api/cli/io/WriteTest.cpp b/test/api/cli/io/WriteTest.cpp index 0800945db..8b630a1e6 100644 --- a/test/api/cli/io/WriteTest.cpp +++ b/test/api/cli/io/WriteTest.cpp @@ -20,51 +20,12 @@ #include -#include -#include #include #include const std::string dirPath = "test/api/cli/io/"; const std::string dirPath2 = "test/runtime/local/io/"; -bool compareFiles(const std::string &filePath1, const std::string &filePath2) { - - std::ifstream file1(filePath1, std::ios::binary); - std::ifstream file2(filePath2, std::ios::binary); - - if (!file1.is_open() || !file2.is_open()) { - std::cerr << "Cannot open one or both files." << std::endl; - return false; - } - - std::string line1, line2; - bool filesAreEqual = true; - - while (std::getline(file1, line1)) { - if (!std::getline(file2, line2)) { - filesAreEqual = false; - break; - } - - if (line1 != line2) { - filesAreEqual = false; - break; - } - } - - if (filesAreEqual && std::getline(file2, line2)) { - if (!line2.empty()) { - filesAreEqual = false; - } - } - - file1.close(); - file2.close(); - - return filesAreEqual; -} - TEST_CASE("writeMatrixCSV_Full", TAG_IO) { std::string csvPath = dirPath + "matrix_full.csv"; std::filesystem::remove(csvPath); // remove old file if it still exists @@ -88,7 +49,7 @@ TEST_CASE("writeMatrixMtxaig", TAG_IO) { std::string resultPath = "out.mtx"; checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args", std::string("inPath=\"" + expectedPath + "\"").c_str()); - CHECK(compareFiles(expectedPath, resultPath)); + CHECK(compareFileContents(expectedPath, resultPath)); std::filesystem::remove(resultPath); // remove old file if it still exists } @@ -97,7 +58,7 @@ TEST_CASE("writeMatrixMtxaik", TAG_IO) { std::string resultPath = "out.mtx"; checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args", std::string("inPath=\"" + expectedPath + "\"").c_str()); - CHECK(compareFiles(expectedPath, resultPath)); + CHECK(compareFileContents(expectedPath, resultPath)); std::filesystem::remove(resultPath); // remove old file if it still exists } @@ -106,6 +67,6 @@ TEST_CASE("writeMatrixMtxais", TAG_IO) { std::string resultPath = "out.mtx"; checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args", std::string("inPath=\"" + expectedPath + "\"").c_str()); - CHECK(compareFiles(expectedPath, resultPath)); + CHECK(compareFileContents(expectedPath, resultPath)); std::filesystem::remove(resultPath); // remove old file if it still exists } \ No newline at end of file diff --git a/test/runtime/local/io/WriteMMTest.cpp b/test/runtime/local/io/WriteMMTest.cpp index 4d02edfa3..233994f36 100644 --- a/test/runtime/local/io/WriteMMTest.cpp +++ b/test/runtime/local/io/WriteMMTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include @@ -28,47 +29,10 @@ #include #include #include -#include #include #include #include -bool compareContentsFromFile(const std::string &filePath1, const std::string &filePath2) { - std::ifstream file1(filePath1, std::ios::binary); - std::ifstream file2(filePath2, std::ios::binary); - - if (!file1.is_open() || !file2.is_open()) { - std::cerr << "Cannot open one or both files." << std::endl; - return false; - } - - std::string line1, line2; - bool filesAreEqual = true; - - while (std::getline(file1, line1)) { - if (!std::getline(file2, line2)) { - filesAreEqual = false; - break; - } - - if (line1 != line2) { - filesAreEqual = false; - break; - } - } - - if (filesAreEqual && std::getline(file2, line2)) { - if (!line2.empty()) { - filesAreEqual = false; - } - } - - file1.close(); - file2.close(); - - return filesAreEqual; -} - TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIG", TAG_IO, (DenseMatrix), (int32_t)) { using DT = TestType; DT *m = nullptr; @@ -85,7 +49,7 @@ TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIG", TAG_IO, (DenseMatrix), (int32_t)) { writeMM(m, resultPath); - CHECK(compareContentsFromFile(filename, resultPath)); + CHECK(compareFileContents(filename, resultPath)); std::filesystem::remove(resultPath); CHECK(m->get(0, 0) == 1); @@ -113,7 +77,7 @@ TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIK", TAG_IO, (DenseMatrix), (int32_t)) { writeMM(m, resultPath); - CHECK(compareContentsFromFile(filename, resultPath)); + CHECK(compareFileContents(filename, resultPath)); std::filesystem::remove(resultPath); CHECK(m->get(1, 0) == 1); @@ -143,7 +107,7 @@ TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIS", TAG_IO, (DenseMatrix), (int32_t)) { writeMM(m, resultPath); - CHECK(compareContentsFromFile(filename, resultPath)); + CHECK(compareFileContents(filename, resultPath)); std::filesystem::remove(resultPath); CHECK(m->get(1, 1) == 4); @@ -171,7 +135,7 @@ TEMPLATE_PRODUCT_TEST_CASE("WriteMM CIG (CSR)", TAG_IO, (CSRMatrix), (int32_t)) writeMM(m, resultPath); - CHECK(compareContentsFromFile(filename, resultPath)); + CHECK(compareFileContents(filename, resultPath)); std::filesystem::remove(resultPath); CHECK(m->get(0, 0) == 1); @@ -198,7 +162,7 @@ TEMPLATE_PRODUCT_TEST_CASE("WriteMM CRG (CSR)", TAG_IO, (CSRMatrix), (double)) { writeMM(m, resultPath); - CHECK(compareContentsFromFile(filename, resultPath)); + CHECK(compareFileContents(filename, resultPath)); std::filesystem::remove(resultPath); CHECK(m->get(5, 0) == 0.25599762);