diff --git a/reverb/cc/ops/client.cc b/reverb/cc/ops/client.cc index c1040bc..0afc5c8 100644 --- a/reverb/cc/ops/client.cc +++ b/reverb/cc/ops/client.cc @@ -65,7 +65,7 @@ class ClientHandleOp : public tensorflow::ResourceOpKernel { } private: - tensorflow::Status CreateResource(ClientResource** ret) override { + absl::Status CreateResource(ClientResource** ret) override { *ret = new ClientResource(server_address_); return absl::OkStatus(); } diff --git a/reverb/cc/ops/pattern_dataset.cc b/reverb/cc/ops/pattern_dataset.cc index a741796..4b0af1e 100644 --- a/reverb/cc/ops/pattern_dataset.cc +++ b/reverb/cc/ops/pattern_dataset.cc @@ -155,22 +155,22 @@ class ReverbPatternDatasetOp : public tensorflow::data::UnaryDatasetOpKernel { return tensorflow::data::kUnknownCardinality; } - tensorflow::Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - tensorflow::Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } protected: - tensorflow::Status AsGraphDefInternal( - tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b, - tensorflow::Node** output) const override { + absl::Status AsGraphDefInternal(tensorflow::data::SerializationContext* ctx, + DatasetGraphDefBuilder* b, + tensorflow::Node** output) const override { tensorflow::AttrValue dtypes_attr; tensorflow::AttrValue shapes_attr; tensorflow::AttrValue configs_attr; @@ -239,8 +239,7 @@ class ReverbPatternDatasetOp : public tensorflow::data::UnaryDatasetOpKernel { configs_(configs), clear_after_episode_(clear_after_episode) {} - tensorflow::Status Initialize( - tensorflow::data::IteratorContext* ctx) override { + absl::Status Initialize(tensorflow::data::IteratorContext* ctx) override { structured_writer_ = std::make_unique( std::make_unique(required_keep_alive_, &data_), configs_); @@ -250,10 +249,9 @@ class ReverbPatternDatasetOp : public tensorflow::data::UnaryDatasetOpKernel { ctx, &instantiated_captured_func_); } - tensorflow::Status GetNextInternal( - tensorflow::data::IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(tensorflow::data::IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { // This needs to be thread-safe. // We lock the full method because otherwise we would have several // threads getting data from the input dataset and inserting into the @@ -302,14 +300,14 @@ class ReverbPatternDatasetOp : public tensorflow::data::UnaryDatasetOpKernel { } protected: - tensorflow::Status SaveInternal( + absl::Status SaveInternal( tensorflow::data::SerializationContext* ctx, tensorflow::data::IteratorStateWriter* writer) override { return tensorflow::errors::Unimplemented( "SaveInternal is currently not supported"); } - tensorflow::Status RestoreInternal( + absl::Status RestoreInternal( tensorflow::data::IteratorContext* ctx, tensorflow::data::IteratorStateReader* reader) override { return tensorflow::errors::Unimplemented( diff --git a/reverb/cc/ops/timestep_dataset.cc b/reverb/cc/ops/timestep_dataset.cc index 6eb7ed1..17226e5 100644 --- a/reverb/cc/ops/timestep_dataset.cc +++ b/reverb/cc/ops/timestep_dataset.cc @@ -113,20 +113,20 @@ class ReverbTimestepDatasetOp : public tensorflow::data::DatasetOpKernel { return "ReverbTimestepDatasetOp::Dataset"; } - tensorflow::Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return FailedPrecondition(DebugString(), " depends on external state."); } - tensorflow::Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->clear(); return absl::OkStatus(); } protected: - tensorflow::Status AsGraphDefInternal( - tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b, - tensorflow::Node** output) const override { + absl::Status AsGraphDefInternal(tensorflow::data::SerializationContext* ctx, + DatasetGraphDefBuilder* b, + tensorflow::Node** output) const override { tensorflow::AttrValue max_in_flight_samples_per_worker_attr; tensorflow::AttrValue num_workers_attr; tensorflow::AttrValue max_samples_per_stream_attr; @@ -189,8 +189,7 @@ class ReverbTimestepDatasetOp : public tensorflow::data::DatasetOpKernel { shapes_(shapes), rate_limited_(false) {} - tensorflow::Status Initialize( - tensorflow::data::IteratorContext* ctx) override { + absl::Status Initialize(tensorflow::data::IteratorContext* ctx) override { constexpr auto kValidationTimeout = absl::Seconds(30); // The shapes and dtypes contains metadata fields but the signature does @@ -220,10 +219,9 @@ class ReverbTimestepDatasetOp : public tensorflow::data::DatasetOpKernel { return ToTensorflowStatus(status); } - tensorflow::Status GetNextInternal( - tensorflow::data::IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(tensorflow::data::IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { REVERB_CHECK(sampler_.get() != nullptr) << "Initialize was not called?"; auto token = ctx->cancellation_manager()->get_cancellation_token(); @@ -265,13 +263,13 @@ class ReverbTimestepDatasetOp : public tensorflow::data::DatasetOpKernel { } protected: - tensorflow::Status SaveInternal( + absl::Status SaveInternal( tensorflow::data::SerializationContext* ctx, tensorflow::data::IteratorStateWriter* writer) override { return Unimplemented("SaveInternal is currently not supported"); } - tensorflow::Status RestoreInternal( + absl::Status RestoreInternal( tensorflow::data::IteratorContext* ctx, tensorflow::data::IteratorStateReader* reader) override { return Unimplemented("RestoreInternal is currently not supported"); diff --git a/reverb/cc/ops/trajectory_dataset.cc b/reverb/cc/ops/trajectory_dataset.cc index cd91b2c..191d0d8 100644 --- a/reverb/cc/ops/trajectory_dataset.cc +++ b/reverb/cc/ops/trajectory_dataset.cc @@ -113,20 +113,20 @@ class ReverbTrajectoryDatasetOp : public tensorflow::data::DatasetOpKernel { return "ReverbTrajectoryDatasetOp::Dataset"; } - tensorflow::Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return FailedPrecondition(DebugString(), " depends on external state."); } - tensorflow::Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->clear(); return absl::OkStatus(); } protected: - tensorflow::Status AsGraphDefInternal( - tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b, - tensorflow::Node** output) const override { + absl::Status AsGraphDefInternal(tensorflow::data::SerializationContext* ctx, + DatasetGraphDefBuilder* b, + tensorflow::Node** output) const override { tensorflow::AttrValue max_in_flight_samples_per_worker_attr; tensorflow::AttrValue num_workers_attr; tensorflow::AttrValue max_samples_per_stream_attr; @@ -189,8 +189,7 @@ class ReverbTrajectoryDatasetOp : public tensorflow::data::DatasetOpKernel { shapes_(shapes), rate_limited_(false) {} - tensorflow::Status Initialize( - tensorflow::data::IteratorContext* ctx) override { + absl::Status Initialize(tensorflow::data::IteratorContext* ctx) override { constexpr auto kValidationTimeout = absl::Seconds(30); // The shapes and dtypes contains metadata fields but the signature does @@ -220,10 +219,9 @@ class ReverbTrajectoryDatasetOp : public tensorflow::data::DatasetOpKernel { return ToTensorflowStatus(status); } - tensorflow::Status GetNextInternal( - tensorflow::data::IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(tensorflow::data::IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { REVERB_CHECK(sampler_.get() != nullptr) << "Initialize was not called?"; auto token = ctx->cancellation_manager()->get_cancellation_token(); @@ -264,13 +262,13 @@ class ReverbTrajectoryDatasetOp : public tensorflow::data::DatasetOpKernel { } protected: - tensorflow::Status SaveInternal( + absl::Status SaveInternal( tensorflow::data::SerializationContext* ctx, tensorflow::data::IteratorStateWriter* writer) override { return Unimplemented("SaveInternal is currently not supported"); } - tensorflow::Status RestoreInternal( + absl::Status RestoreInternal( tensorflow::data::IteratorContext* ctx, tensorflow::data::IteratorStateReader* reader) override { return Unimplemented("RestoreInternal is currently not supported");