Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
Tested:
PiperOrigin-RevId: 687227512
Change-Id: Id63277dc209b10f1bfcc95981ce8e80c64569014
  • Loading branch information
Reverb Team authored and copybara-github committed Oct 18, 2024
1 parent 886f20a commit 312090f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 40 deletions.
2 changes: 1 addition & 1 deletion reverb/cc/ops/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ClientHandleOp : public tensorflow::ResourceOpKernel<ClientResource> {
}

private:
tensorflow::Status CreateResource(ClientResource** ret) override {
absl::Status CreateResource(ClientResource** ret) override {
*ret = new ClientResource(server_address_);
return absl::OkStatus();
}
Expand Down
24 changes: 11 additions & 13 deletions reverb/cc/ops/pattern_dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,22 +155,22 @@ class ReverbPatternDatasetOp : public tensorflow::data::UnaryDatasetOpKernel {
return tensorflow::data::kUnknownCardinality;
}

tensorflow::Status InputDatasets(
absl::Status InputDatasets(
std::vector<const tensorflow::data::DatasetBase*>* inputs)
const override {
inputs->push_back(input_);
return absl::OkStatus();
}

tensorflow::Status CheckExternalState() const override {
absl::Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}

protected:
tensorflow::Status AsGraphDefInternal(
tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b,
tensorflow::Node** output) const override {
absl::Status AsGraphDefInternal(tensorflow::data::SerializationContext* ctx,
DatasetGraphDefBuilder* b,
tensorflow::Node** output) const override {
tensorflow::AttrValue dtypes_attr;
tensorflow::AttrValue shapes_attr;
tensorflow::AttrValue configs_attr;
Expand Down Expand Up @@ -239,8 +239,7 @@ class ReverbPatternDatasetOp : public tensorflow::data::UnaryDatasetOpKernel {
configs_(configs),
clear_after_episode_(clear_after_episode) {}

tensorflow::Status Initialize(
tensorflow::data::IteratorContext* ctx) override {
absl::Status Initialize(tensorflow::data::IteratorContext* ctx) override {
structured_writer_ = std::make_unique<StructuredWriter>(
std::make_unique<QueueWriter>(required_keep_alive_, &data_),
configs_);
Expand All @@ -250,10 +249,9 @@ class ReverbPatternDatasetOp : public tensorflow::data::UnaryDatasetOpKernel {
ctx, &instantiated_captured_func_);
}

tensorflow::Status GetNextInternal(
tensorflow::data::IteratorContext* ctx,
std::vector<tensorflow::Tensor>* out_tensors,
bool* end_of_sequence) override {
absl::Status GetNextInternal(tensorflow::data::IteratorContext* ctx,
std::vector<tensorflow::Tensor>* out_tensors,
bool* end_of_sequence) override {
// This needs to be thread-safe.
// We lock the full method because otherwise we would have several
// threads getting data from the input dataset and inserting into the
Expand Down Expand Up @@ -302,14 +300,14 @@ class ReverbPatternDatasetOp : public tensorflow::data::UnaryDatasetOpKernel {
}

protected:
tensorflow::Status SaveInternal(
absl::Status SaveInternal(
tensorflow::data::SerializationContext* ctx,
tensorflow::data::IteratorStateWriter* writer) override {
return tensorflow::errors::Unimplemented(
"SaveInternal is currently not supported");
}

tensorflow::Status RestoreInternal(
absl::Status RestoreInternal(
tensorflow::data::IteratorContext* ctx,
tensorflow::data::IteratorStateReader* reader) override {
return tensorflow::errors::Unimplemented(
Expand Down
24 changes: 11 additions & 13 deletions reverb/cc/ops/timestep_dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,20 @@ class ReverbTimestepDatasetOp : public tensorflow::data::DatasetOpKernel {
return "ReverbTimestepDatasetOp::Dataset";
}

tensorflow::Status CheckExternalState() const override {
absl::Status CheckExternalState() const override {
return FailedPrecondition(DebugString(), " depends on external state.");
}

tensorflow::Status InputDatasets(
absl::Status InputDatasets(
std::vector<const DatasetBase*>* inputs) const override {
inputs->clear();
return absl::OkStatus();
}

protected:
tensorflow::Status AsGraphDefInternal(
tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b,
tensorflow::Node** output) const override {
absl::Status AsGraphDefInternal(tensorflow::data::SerializationContext* ctx,
DatasetGraphDefBuilder* b,
tensorflow::Node** output) const override {
tensorflow::AttrValue max_in_flight_samples_per_worker_attr;
tensorflow::AttrValue num_workers_attr;
tensorflow::AttrValue max_samples_per_stream_attr;
Expand Down Expand Up @@ -189,8 +189,7 @@ class ReverbTimestepDatasetOp : public tensorflow::data::DatasetOpKernel {
shapes_(shapes),
rate_limited_(false) {}

tensorflow::Status Initialize(
tensorflow::data::IteratorContext* ctx) override {
absl::Status Initialize(tensorflow::data::IteratorContext* ctx) override {
constexpr auto kValidationTimeout = absl::Seconds(30);

// The shapes and dtypes contains metadata fields but the signature does
Expand Down Expand Up @@ -220,10 +219,9 @@ class ReverbTimestepDatasetOp : public tensorflow::data::DatasetOpKernel {
return ToTensorflowStatus(status);
}

tensorflow::Status GetNextInternal(
tensorflow::data::IteratorContext* ctx,
std::vector<tensorflow::Tensor>* out_tensors,
bool* end_of_sequence) override {
absl::Status GetNextInternal(tensorflow::data::IteratorContext* ctx,
std::vector<tensorflow::Tensor>* out_tensors,
bool* end_of_sequence) override {
REVERB_CHECK(sampler_.get() != nullptr) << "Initialize was not called?";

auto token = ctx->cancellation_manager()->get_cancellation_token();
Expand Down Expand Up @@ -265,13 +263,13 @@ class ReverbTimestepDatasetOp : public tensorflow::data::DatasetOpKernel {
}

protected:
tensorflow::Status SaveInternal(
absl::Status SaveInternal(
tensorflow::data::SerializationContext* ctx,
tensorflow::data::IteratorStateWriter* writer) override {
return Unimplemented("SaveInternal is currently not supported");
}

tensorflow::Status RestoreInternal(
absl::Status RestoreInternal(
tensorflow::data::IteratorContext* ctx,
tensorflow::data::IteratorStateReader* reader) override {
return Unimplemented("RestoreInternal is currently not supported");
Expand Down
24 changes: 11 additions & 13 deletions reverb/cc/ops/trajectory_dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,20 @@ class ReverbTrajectoryDatasetOp : public tensorflow::data::DatasetOpKernel {
return "ReverbTrajectoryDatasetOp::Dataset";
}

tensorflow::Status CheckExternalState() const override {
absl::Status CheckExternalState() const override {
return FailedPrecondition(DebugString(), " depends on external state.");
}

tensorflow::Status InputDatasets(
absl::Status InputDatasets(
std::vector<const DatasetBase*>* inputs) const override {
inputs->clear();
return absl::OkStatus();
}

protected:
tensorflow::Status AsGraphDefInternal(
tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b,
tensorflow::Node** output) const override {
absl::Status AsGraphDefInternal(tensorflow::data::SerializationContext* ctx,
DatasetGraphDefBuilder* b,
tensorflow::Node** output) const override {
tensorflow::AttrValue max_in_flight_samples_per_worker_attr;
tensorflow::AttrValue num_workers_attr;
tensorflow::AttrValue max_samples_per_stream_attr;
Expand Down Expand Up @@ -189,8 +189,7 @@ class ReverbTrajectoryDatasetOp : public tensorflow::data::DatasetOpKernel {
shapes_(shapes),
rate_limited_(false) {}

tensorflow::Status Initialize(
tensorflow::data::IteratorContext* ctx) override {
absl::Status Initialize(tensorflow::data::IteratorContext* ctx) override {
constexpr auto kValidationTimeout = absl::Seconds(30);

// The shapes and dtypes contains metadata fields but the signature does
Expand Down Expand Up @@ -220,10 +219,9 @@ class ReverbTrajectoryDatasetOp : public tensorflow::data::DatasetOpKernel {
return ToTensorflowStatus(status);
}

tensorflow::Status GetNextInternal(
tensorflow::data::IteratorContext* ctx,
std::vector<tensorflow::Tensor>* out_tensors,
bool* end_of_sequence) override {
absl::Status GetNextInternal(tensorflow::data::IteratorContext* ctx,
std::vector<tensorflow::Tensor>* out_tensors,
bool* end_of_sequence) override {
REVERB_CHECK(sampler_.get() != nullptr) << "Initialize was not called?";

auto token = ctx->cancellation_manager()->get_cancellation_token();
Expand Down Expand Up @@ -264,13 +262,13 @@ class ReverbTrajectoryDatasetOp : public tensorflow::data::DatasetOpKernel {
}

protected:
tensorflow::Status SaveInternal(
absl::Status SaveInternal(
tensorflow::data::SerializationContext* ctx,
tensorflow::data::IteratorStateWriter* writer) override {
return Unimplemented("SaveInternal is currently not supported");
}

tensorflow::Status RestoreInternal(
absl::Status RestoreInternal(
tensorflow::data::IteratorContext* ctx,
tensorflow::data::IteratorStateReader* reader) override {
return Unimplemented("RestoreInternal is currently not supported");
Expand Down

0 comments on commit 312090f

Please sign in to comment.