diff --git a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp index 149662cd266e..26ef57aab65d 100644 --- a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp +++ b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp @@ -32,15 +32,15 @@ struct Fixture {} void assertSetAndGetTensor(const TensorSpec &tensorSpec) { Value::UP expTensor = makeTensor(tensorSpec); - EntryRef ref = store.setTensor(*expTensor); - Value::UP actTensor = store.getTensor(ref); + EntryRef ref = store.store_tensor(*expTensor); + Value::UP actTensor = store.get_tensor(ref); EXPECT_EQUAL(*expTensor, *actTensor); assertTensorView(ref, *expTensor); } void assertEmptyTensor(const TensorSpec &tensorSpec) { Value::UP expTensor = makeTensor(tensorSpec); EntryRef ref; - Value::UP actTensor = store.getTensor(ref); + Value::UP actTensor = store.get_tensor(ref); EXPECT_TRUE(actTensor.get() == nullptr); assertTensorView(ref, *expTensor); } diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt index 75f453ddcbc7..0d87881711cb 100644 --- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt @@ -9,7 +9,6 @@ vespa_add_library(searchlib_tensor OBJECT dense_tensor_attribute_saver.cpp dense_tensor_store.cpp direct_tensor_attribute.cpp - direct_tensor_saver.cpp direct_tensor_store.cpp distance_calculator.cpp distance_function_factory.cpp @@ -29,13 +28,13 @@ vespa_add_library(searchlib_tensor OBJECT nearest_neighbor_index_saver.cpp serialized_fast_value_attribute.cpp small_subspaces_buffer_type.cpp - streamed_value_saver.cpp tensor_attribute.cpp tensor_buffer_operations.cpp tensor_buffer_store.cpp tensor_buffer_type_mapper.cpp tensor_deserialize.cpp tensor_store.cpp + tensor_store_saver.cpp reusable_set_visited_tracker.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp index d4353532309e..636c949be08e 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp @@ -112,7 +112,7 @@ void DenseTensorAttribute::internal_set_tensor(DocId docid, const vespalib::eval::Value& tensor) { consider_remove_from_index(docid); - EntryRef ref = _denseTensorStore.setTensor(tensor); + EntryRef ref = _denseTensorStore.store_tensor(tensor); setTensorRef(docid, ref); } @@ -229,10 +229,7 @@ DenseTensorAttribute::getTensor(DocId docId) const if (docId < getCommittedDocIdLimit()) { ref = acquire_entry_ref(docId); } - if (!ref.valid()) { - return {}; - } - return _denseTensorStore.getTensor(ref); + return _denseTensorStore.get_tensor(ref); } vespalib::eval::TypedCells diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp index 1bc84a7216db..ba7e85261461 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp @@ -11,6 +11,7 @@ using vespalib::datastore::CompactionContext; using vespalib::datastore::CompactionSpec; using vespalib::datastore::CompactionStrategy; +using vespalib::datastore::EntryRef; using vespalib::datastore::Handle; using vespalib::datastore::ICompactionContext; using vespalib::eval::CellType; @@ -147,19 +148,8 @@ DenseTensorStore::start_compact(const CompactionStrategy& compaction_strategy) return std::make_unique(*this, std::move(compacting_buffers)); } -std::unique_ptr -DenseTensorStore::getTensor(EntryRef ref) const -{ - if (!ref.valid()) { - return {}; - } - vespalib::eval::TypedCells cells_ref(getRawBuffer(ref), _type.cell_type(), getNumCells()); - return std::make_unique(_type, cells_ref); -} - -template -TensorStore::EntryRef -DenseTensorStore::setDenseTensor(const TensorType &tensor) +EntryRef +DenseTensorStore::store_tensor(const Value& tensor) { assert(tensor.type() == _type); auto cells = tensor.cells(); @@ -170,10 +160,29 @@ DenseTensorStore::setDenseTensor(const TensorType &tensor) return raw.ref; } -TensorStore::EntryRef -DenseTensorStore::setTensor(const vespalib::eval::Value &tensor) +EntryRef +DenseTensorStore::store_encoded_tensor(vespalib::nbostream& encoded) +{ + (void) encoded; + abort(); +} + +std::unique_ptr +DenseTensorStore::get_tensor(EntryRef ref) const +{ + if (!ref.valid()) { + return {}; + } + vespalib::eval::TypedCells cells_ref(getRawBuffer(ref), _type.cell_type(), getNumCells()); + return std::make_unique(_type, cells_ref); +} + +bool +DenseTensorStore::encode_stored_tensor(EntryRef ref, vespalib::nbostream& target) const { - return setDenseTensor(tensor); + (void) ref; + (void) target; + abort(); } } diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h index bd83772ee55a..1b25bdad4642 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h @@ -51,10 +51,6 @@ class DenseTensorStore : public TensorStore BufferType _bufferType; ValueType _type; // type of dense tensor std::vector _emptySpace; - - template - TensorStore::EntryRef - setDenseTensor(const TensorType &tensor); public: DenseTensorStore(const ValueType &type, std::shared_ptr allocator); ~DenseTensorStore() override; @@ -70,12 +66,15 @@ class DenseTensorStore : public TensorStore EntryRef move(EntryRef ref) override; vespalib::MemoryUsage update_stat(const vespalib::datastore::CompactionStrategy& compaction_strategy) override; std::unique_ptr start_compact(const vespalib::datastore::CompactionStrategy& compaction_strategy) override; - std::unique_ptr getTensor(EntryRef ref) const; + EntryRef store_tensor(const vespalib::eval::Value &tensor) override; + EntryRef store_encoded_tensor(vespalib::nbostream &encoded) override; + std::unique_ptr get_tensor(EntryRef ref) const override; + bool encode_stored_tensor(EntryRef ref, vespalib::nbostream &target) const override; + vespalib::eval::TypedCells get_typed_cells(EntryRef ref) const { return vespalib::eval::TypedCells(ref.valid() ? getRawBuffer(ref) : &_emptySpace[0], _type.cell_type(), getNumCells()); } - EntryRef setTensor(const vespalib::eval::Value &tensor); // The following method is meant to be used only for unit tests. uint32_t getArraySize() const { return _bufferType.getArraySize(); } }; diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp index 7730f340e016..d9fe025b4e5c 100644 --- a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp @@ -1,16 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "direct_tensor_attribute.h" -#include "direct_tensor_saver.h" - #include #include -#include -#include -#include - -#include "blob_sequence_reader.h" -#include "tensor_deserialize.h" using vespalib::eval::FastValueBuilderFactory; @@ -27,39 +19,6 @@ DirectTensorAttribute::~DirectTensorAttribute() _tensorStore.clearHoldLists(); } -bool -DirectTensorAttribute::onLoad(vespalib::Executor *) -{ - BlobSequenceReader tensorReader(*this); - if (!tensorReader.hasData()) { - return false; - } - setCreateSerialNum(tensorReader.getCreateSerialNum()); - assert(tensorReader.getVersion() == getVersion()); - uint32_t numDocs = tensorReader.getDocIdLimit(); - _refVector.reset(); - _refVector.unsafe_reserve(numDocs); - vespalib::Array buffer(1024); - for (uint32_t lid = 0; lid < numDocs; ++lid) { - uint32_t tensorSize = tensorReader.getNextSize(); - if (tensorSize != 0) { - if (tensorSize > buffer.size()) { - buffer.resize(tensorSize + 1024); - } - tensorReader.readBlob(&buffer[0], tensorSize); - auto tensor = deserialize_tensor(&buffer[0], tensorSize); - EntryRef ref = _direct_store.store_tensor(std::move(tensor)); - _refVector.push_back(AtomicEntryRef(ref)); - } else { - EntryRef invalid; - _refVector.push_back(AtomicEntryRef(invalid)); - } - } - setNumDocs(numDocs); - setCommittedDocIdLimit(numDocs); - return true; -} - void DirectTensorAttribute::set_tensor(DocId lid, std::unique_ptr tensor) { @@ -129,15 +88,4 @@ DirectTensorAttribute::get_tensor_ref(DocId docId) const return *ptr; } -std::unique_ptr -DirectTensorAttribute::onInitSave(vespalib::stringref fileName) -{ - vespalib::GenerationHandler::Guard guard(getGenerationHandler().takeGuard()); - return std::make_unique - (std::move(guard), - this->createAttributeHeader(fileName), - getRefCopy(), - _direct_store); -} - } // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h index 2dfb5c1efcd5..6466c6f75375 100644 --- a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h @@ -21,8 +21,6 @@ class DirectTensorAttribute final : public TensorAttribute const document::TensorUpdate &update, bool create_empty_if_non_existing) override; std::unique_ptr getTensor(DocId docId) const override; - bool onLoad(vespalib::Executor *executor) override; - std::unique_ptr onInitSave(vespalib::stringref fileName) override; void set_tensor(DocId docId, std::unique_ptr tensor); const vespalib::eval::Value &get_tensor_ref(DocId docId) const override; bool supports_get_tensor_ref() const override { return true; } diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_saver.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_saver.cpp deleted file mode 100644 index 0de4491cfcc4..000000000000 --- a/searchlib/src/vespa/searchlib/tensor/direct_tensor_saver.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include "direct_tensor_saver.h" -#include "direct_tensor_store.h" - -#include -#include -#include -#include - -using vespalib::GenerationHandler; - -namespace search::tensor { - -DirectTensorAttributeSaver:: -DirectTensorAttributeSaver(GenerationHandler::Guard &&guard, - const attribute::AttributeHeader &header, - RefCopyVector &&refs, - const DirectTensorStore &tensorStore) - : AttributeSaver(std::move(guard), header), - _refs(std::move(refs)), - _tensorStore(tensorStore) -{ -} - - -DirectTensorAttributeSaver::~DirectTensorAttributeSaver() -{ -} - -bool -DirectTensorAttributeSaver::onSave(IAttributeSaveTarget &saveTarget) -{ - auto datWriter = saveTarget.datWriter().allocBufferWriter(); - const uint32_t docIdLimit(_refs.size()); - vespalib::nbostream stream; - for (uint32_t lid = 0; lid < docIdLimit; ++lid) { - const vespalib::eval::Value *tensor = _tensorStore.get_tensor_ptr(_refs[lid]); - if (tensor) { - stream.clear(); - encode_value(*tensor, stream); - uint32_t sz = stream.size(); - datWriter->write(&sz, sizeof(sz)); - datWriter->write(stream.peek(), stream.size()); - } else { - uint32_t sz = 0; - datWriter->write(&sz, sizeof(sz)); - } - } - datWriter->flush(); - return true; -} - -} // namespace search::tensor diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_saver.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_saver.h deleted file mode 100644 index 132e1570f0ff..000000000000 --- a/searchlib/src/vespa/searchlib/tensor/direct_tensor_saver.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -#include -#include "tensor_attribute.h" - -namespace search::tensor { - -class DirectTensorStore; - -/* - * Class for saving a tensor attribute. - */ -class DirectTensorAttributeSaver : public AttributeSaver -{ -public: - using RefCopyVector = TensorAttribute::RefCopyVector; -private: - using GenerationHandler = vespalib::GenerationHandler; - - RefCopyVector _refs; - const DirectTensorStore &_tensorStore; - - bool onSave(IAttributeSaveTarget &saveTarget) override; -public: - DirectTensorAttributeSaver(GenerationHandler::Guard &&guard, - const attribute::AttributeHeader &header, - RefCopyVector &&refs, - const DirectTensorStore &tensorStore); - - virtual ~DirectTensorAttributeSaver(); -}; - -} // namespace search::tensor diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp index fba1d4946900..1184cca37e7c 100644 --- a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp @@ -1,7 +1,10 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "direct_tensor_store.h" +#include "tensor_deserialize.h" +#include #include +#include #include #include #include @@ -14,6 +17,8 @@ using vespalib::datastore::CompactionSpec; using vespalib::datastore::CompactionStrategy; using vespalib::datastore::EntryRef; using vespalib::datastore::ICompactionContext; +using vespalib::eval::FastValueBuilderFactory; +using vespalib::eval::Value; namespace search::tensor { @@ -54,13 +59,6 @@ DirectTensorStore::DirectTensorStore() DirectTensorStore::~DirectTensorStore() = default; -EntryRef -DirectTensorStore::store_tensor(std::unique_ptr tensor) -{ - assert(tensor); - return add_entry(TensorSP(std::move(tensor))); -} - void DirectTensorStore::holdTensor(EntryRef ref) { @@ -100,4 +98,42 @@ DirectTensorStore::start_compact(const CompactionStrategy& compaction_strategy) return std::make_unique(*this, std::move(compacting_buffers)); } +EntryRef +DirectTensorStore::store_tensor(std::unique_ptr tensor) +{ + assert(tensor); + return add_entry(std::move(tensor)); +} + +EntryRef +DirectTensorStore::store_tensor(const Value& tensor) +{ + return add_entry(FastValueBuilderFactory::get().copy(tensor)); +} + +EntryRef +DirectTensorStore::store_encoded_tensor(vespalib::nbostream& encoded) +{ + return add_entry(deserialize_tensor(encoded)); +} + +std::unique_ptr +DirectTensorStore::get_tensor(EntryRef ref) const +{ + if (!ref.valid()) { + return {}; + } + return FastValueBuilderFactory::get().copy(*_tensor_store.getEntry(ref)); +} + +bool +DirectTensorStore::encode_stored_tensor(EntryRef ref, vespalib::nbostream& target) const +{ + if (!ref.valid()) { + return false; + } + vespalib::eval::encode_value(*_tensor_store.getEntry(ref), target); + return true; +} + } diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h index 658d1ab0549a..ff9540a27b31 100644 --- a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h @@ -52,6 +52,10 @@ class DirectTensorStore : public TensorStore { EntryRef move(EntryRef ref) override; vespalib::MemoryUsage update_stat(const vespalib::datastore::CompactionStrategy& compaction_strategy) override; std::unique_ptr start_compact(const vespalib::datastore::CompactionStrategy& compaction_strategy) override; + EntryRef store_tensor(const vespalib::eval::Value& tensor) override; + EntryRef store_encoded_tensor(vespalib::nbostream& encoded) override; + std::unique_ptr get_tensor(EntryRef ref) const override; + bool encode_stored_tensor(EntryRef ref, vespalib::nbostream& target) const override; }; } diff --git a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp index 94bf3f1a37b1..a24059b3f7c3 100644 --- a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp @@ -1,17 +1,13 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "serialized_fast_value_attribute.h" -#include "streamed_value_saver.h" #include -#include #include #include LOG_SETUP(".searchlib.tensor.serialized_fast_value_attribute"); -#include "blob_sequence_reader.h" - using namespace vespalib; using namespace vespalib::eval; @@ -50,50 +46,4 @@ SerializedFastValueAttribute::getTensor(DocId docId) const return _tensorBufferStore.get_tensor(ref); } -bool -SerializedFastValueAttribute::onLoad(vespalib::Executor *) -{ - BlobSequenceReader tensorReader(*this); - if (!tensorReader.hasData()) { - return false; - } - setCreateSerialNum(tensorReader.getCreateSerialNum()); - assert(tensorReader.getVersion() == getVersion()); - uint32_t numDocs(tensorReader.getDocIdLimit()); - _refVector.reset(); - _refVector.unsafe_reserve(numDocs); - vespalib::Array buffer(1024); - for (uint32_t lid = 0; lid < numDocs; ++lid) { - uint32_t tensorSize = tensorReader.getNextSize(); - if (tensorSize != 0) { - if (tensorSize > buffer.size()) { - buffer.resize(tensorSize + 1024); - } - tensorReader.readBlob(&buffer[0], tensorSize); - vespalib::nbostream source(&buffer[0], tensorSize); - EntryRef ref = _tensorBufferStore.store_encoded_tensor(source); - _refVector.push_back(AtomicEntryRef(ref)); - } else { - EntryRef invalid; - _refVector.push_back(AtomicEntryRef(invalid)); - } - } - setNumDocs(numDocs); - setCommittedDocIdLimit(numDocs); - return true; -} - - -std::unique_ptr -SerializedFastValueAttribute::onInitSave(vespalib::stringref fileName) -{ - vespalib::GenerationHandler::Guard guard(getGenerationHandler(). - takeGuard()); - return std::make_unique - (std::move(guard), - this->createAttributeHeader(fileName), - getRefCopy(), - _tensorBufferStore); -} - } diff --git a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.h b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.h index 5dd49a2bbc40..2124ddeb70af 100644 --- a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.h @@ -24,8 +24,6 @@ class SerializedFastValueAttribute : public TensorAttribute { ~SerializedFastValueAttribute() override; void setTensor(DocId docId, const vespalib::eval::Value &tensor) override; std::unique_ptr getTensor(DocId docId) const override; - bool onLoad(vespalib::Executor *executor) override; - std::unique_ptr onInitSave(vespalib::stringref fileName) override; }; } diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp index cdaea07176a2..99a30b59bd1d 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "tensor_attribute.h" +#include "blob_sequence_reader.h" +#include "tensor_store_saver.h" #include #include #include @@ -277,6 +279,51 @@ TensorAttribute::getRefCopy() const return result; } +bool +TensorAttribute::onLoad(vespalib::Executor*) +{ + BlobSequenceReader tensorReader(*this); + if (!tensorReader.hasData()) { + return false; + } + setCreateSerialNum(tensorReader.getCreateSerialNum()); + assert(tensorReader.getVersion() == getVersion()); + uint32_t numDocs = tensorReader.getDocIdLimit(); + _refVector.reset(); + _refVector.unsafe_reserve(numDocs); + vespalib::Array buffer(1024); + for (uint32_t lid = 0; lid < numDocs; ++lid) { + uint32_t tensorSize = tensorReader.getNextSize(); + if (tensorSize != 0) { + if (tensorSize > buffer.size()) { + buffer.resize(tensorSize + 1024); + } + tensorReader.readBlob(&buffer[0], tensorSize); + vespalib::nbostream source(&buffer[0], tensorSize); + EntryRef ref = _tensorStore.store_encoded_tensor(source); + _refVector.push_back(AtomicEntryRef(ref)); + } else { + EntryRef invalid; + _refVector.push_back(AtomicEntryRef(invalid)); + } + } + setNumDocs(numDocs); + setCommittedDocIdLimit(numDocs); + return true; +} + +std::unique_ptr +TensorAttribute::onInitSave(vespalib::stringref fileName) +{ + vespalib::GenerationHandler::Guard guard(getGenerationHandler(). + takeGuard()); + return std::make_unique + (std::move(guard), + this->createAttributeHeader(fileName), + getRefCopy(), + _tensorStore); +} + void TensorAttribute::update_tensor(DocId docId, const document::TensorUpdate &update, diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h index 411efcd8feaa..7cfbb68eac7e 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h @@ -36,6 +36,8 @@ class TensorAttribute : public NotImplementedAttribute, public ITensorAttribute void populate_state(vespalib::slime::Cursor& object) const; void populate_address_space_usage(AddressSpaceUsage& usage) const override; EntryRef acquire_entry_ref(DocId doc_id) const noexcept { return _refVector.acquire_elem_ref(doc_id).load_acquire(); } + bool onLoad(vespalib::Executor *executor) override; + std::unique_ptr onInitSave(vespalib::stringref fileName) override; public: using RefCopyVector = vespalib::Array; diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp index eff6ac9f374e..800311adfd66 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "tensor_buffer_store.h" +#include #include #include #include @@ -10,6 +11,7 @@ #include #include +using document::DeserializeException; using vespalib::alloc::MemoryAllocator; using vespalib::datastore::CompactionContext; using vespalib::datastore::CompactionStrategy; @@ -90,6 +92,9 @@ TensorBufferStore::store_encoded_tensor(vespalib::nbostream &encoded) { const auto &factory = StreamedValueBuilderFactory::get(); auto val = vespalib::eval::decode_value(encoded, factory); + if (!encoded.empty()) { + throw DeserializeException("Leftover bytes deserializing tensor attribute value.", VESPA_STRLOC); + } return store_tensor(*val); } diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h index 585bbd7a0c3e..6611660b410f 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h @@ -30,10 +30,10 @@ class TensorBufferStore : public TensorStore EntryRef move(EntryRef ref) override; vespalib::MemoryUsage update_stat(const vespalib::datastore::CompactionStrategy& compaction_strategy) override; std::unique_ptr start_compact(const vespalib::datastore::CompactionStrategy& compaction_strategy) override; - EntryRef store_tensor(const vespalib::eval::Value &tensor); - EntryRef store_encoded_tensor(vespalib::nbostream &encoded); - std::unique_ptr get_tensor(EntryRef ref) const; - bool encode_stored_tensor(EntryRef ref, vespalib::nbostream &target) const; + EntryRef store_tensor(const vespalib::eval::Value& tensor) override; + EntryRef store_encoded_tensor(vespalib::nbostream& encoded) override; + std::unique_ptr get_tensor(EntryRef ref) const override; + bool encode_stored_tensor(EntryRef ref, vespalib::nbostream& target) const override; }; } diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp index a8399d9ddeb2..791662caed76 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp @@ -5,6 +5,7 @@ #include #include #include +#include using document::DeserializeException; using vespalib::eval::FastValueBuilderFactory; @@ -25,10 +26,4 @@ std::unique_ptr deserialize_tensor(vespalib::nbostream &buffer) } } -std::unique_ptr deserialize_tensor(const void *data, size_t size) -{ - vespalib::nbostream wrapStream(data, size); - return deserialize_tensor(wrapStream); -} - } // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h index 7d1ede291677..18b9731b30be 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h @@ -1,12 +1,13 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include -#include +#pragma once -namespace search::tensor { +#include -extern std::unique_ptr -deserialize_tensor(const void *data, size_t size); +namespace vespalib { class nbostream; } +namespace vespalib::eval { struct Value; } + +namespace search::tensor { extern std::unique_ptr deserialize_tensor(vespalib::nbostream &stream); diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_store.h b/searchlib/src/vespa/searchlib/tensor/tensor_store.h index 90bc82c4fde3..e2426d2e8995 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_store.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_store.h @@ -8,6 +8,7 @@ #include #include +namespace vespalib { class nbostream; } namespace vespalib::datastore { struct ICompactionContext; } namespace vespalib::eval { struct Value; } @@ -41,6 +42,11 @@ class TensorStore : public vespalib::datastore::ICompactable virtual std::unique_ptr start_compact(const vespalib::datastore::CompactionStrategy& compaction_strategy) = 0; + virtual EntryRef store_tensor(const vespalib::eval::Value& tensor) = 0; + virtual EntryRef store_encoded_tensor(vespalib::nbostream& encoded) = 0; + virtual std::unique_ptr get_tensor(EntryRef ref) const = 0; + virtual bool encode_stored_tensor(EntryRef ref, vespalib::nbostream& target) const = 0; + // Inherit doc from DataStoreBase void trimHoldLists(generation_t usedGen) { _store.trimHoldLists(usedGen); diff --git a/searchlib/src/vespa/searchlib/tensor/streamed_value_saver.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_store_saver.cpp similarity index 72% rename from searchlib/src/vespa/searchlib/tensor/streamed_value_saver.cpp rename to searchlib/src/vespa/searchlib/tensor/tensor_store_saver.cpp index 4c188bb33704..0963e79b0dd0 100644 --- a/searchlib/src/vespa/searchlib/tensor/streamed_value_saver.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_store_saver.cpp @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "streamed_value_saver.h" -#include "tensor_buffer_store.h" +#include "tensor_store_saver.h" +#include "tensor_store.h" #include #include @@ -11,21 +11,21 @@ using vespalib::GenerationHandler; namespace search::tensor { -StreamedValueSaver:: -StreamedValueSaver(GenerationHandler::Guard &&guard, - const attribute::AttributeHeader &header, - RefCopyVector &&refs, - const TensorBufferStore &tensorStore) +TensorStoreSaver:: +TensorStoreSaver(GenerationHandler::Guard &&guard, + const attribute::AttributeHeader &header, + RefCopyVector &&refs, + const TensorStore &tensorStore) : AttributeSaver(std::move(guard), header), _refs(std::move(refs)), _tensorStore(tensorStore) { } -StreamedValueSaver::~StreamedValueSaver() = default; +TensorStoreSaver::~TensorStoreSaver() = default; bool -StreamedValueSaver::onSave(IAttributeSaveTarget &saveTarget) +TensorStoreSaver::onSave(IAttributeSaveTarget &saveTarget) { auto datWriter = saveTarget.datWriter().allocBufferWriter(); const uint32_t docIdLimit(_refs.size()); diff --git a/searchlib/src/vespa/searchlib/tensor/streamed_value_saver.h b/searchlib/src/vespa/searchlib/tensor/tensor_store_saver.h similarity index 58% rename from searchlib/src/vespa/searchlib/tensor/streamed_value_saver.h rename to searchlib/src/vespa/searchlib/tensor/tensor_store_saver.h index 0ce864769f79..a4bf6e07519c 100644 --- a/searchlib/src/vespa/searchlib/tensor/streamed_value_saver.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_store_saver.h @@ -7,12 +7,10 @@ namespace search::tensor { -class TensorBufferStore; - /* * Class for saving a tensor attribute. */ -class StreamedValueSaver : public AttributeSaver +class TensorStoreSaver : public AttributeSaver { public: using RefCopyVector = TensorAttribute::RefCopyVector; @@ -20,16 +18,16 @@ class StreamedValueSaver : public AttributeSaver using GenerationHandler = vespalib::GenerationHandler; RefCopyVector _refs; - const TensorBufferStore &_tensorStore; + const TensorStore& _tensorStore; bool onSave(IAttributeSaveTarget &saveTarget) override; public: - StreamedValueSaver(GenerationHandler::Guard &&guard, - const attribute::AttributeHeader &header, - RefCopyVector &&refs, - const TensorBufferStore &tensorStore); + TensorStoreSaver(GenerationHandler::Guard &&guard, + const attribute::AttributeHeader &header, + RefCopyVector &&refs, + const TensorStore &tensorStore); - virtual ~StreamedValueSaver(); + virtual ~TensorStoreSaver(); }; } // namespace search::tensor