Skip to content

Commit

Permalink
make workaround for hybridstream test
Browse files Browse the repository at this point in the history
  • Loading branch information
ashione committed Apr 13, 2024
1 parent 1d234f9 commit 65dec4a
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ public int[] partition(T value, int currentIndex, int numPartition) {

@Override
public int[] partition(T record, int numPartition) {
// TODO
return new int[0];
seq = (seq + 1) % numPartition;
partitions[0] = seq;
return partitions;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public void setFunction(F function) {

@Override
public void open(List<Collector> collectorList, RuntimeContext runtimeContext) {
LOG.info("Abstract {}, {} open : {}.", this.getId(), this.getName(), collectorList.size());
this.collectorList = collectorList;
this.runtimeContext = runtimeContext;
if (runtimeContext != null && runtimeContext.getOpConfig() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public ChainedOperator(
@Override
public void open(List<Collector> collectorList, RuntimeContext runtimeContext) {
// Don't call super.open() as we `open` every operator separately.
LOG.info("chainedOperator open.");
LOG.info("ChainedOperator open.");
for (int i = 0; i < operators.size(); i++) {
StreamOperator operator = operators.get(i);
List<Collector> succeedingCollectors = new ArrayList<>();
Expand All @@ -77,6 +77,14 @@ public void open(List<Collector> collectorList, RuntimeContext runtimeContext) {
(collector.getId() == operator.getId()
&& collector.getDownStreamOpId() == subOperator.getId()))
.collect(Collectors.toList()));
// FIXME(lingxuan.zlx): Workaround for edge mismatch; for more details, see
// https://github.com/ray-project/mobius/issues/67.
if (succeedingCollectors.isEmpty()) {
succeedingCollectors.addAll(
collectorList.stream()
.filter(x -> (x.getDownStreamOpId() == subOperator.getId()))
.collect(Collectors.toList()));
}
}
});
operator.open(succeedingCollectors, createRuntimeContext(runtimeContext, i));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ public OutputCollector(

@Override
public void collect(Record record) {
LOGGER.info("Collect in output {}.", record);
int[] partitions = this.partition.partition(record, outputQueues.length);
ByteBuffer javaBuffer = null;
ByteBuffer crossLangBuffer = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.common.tuple.Tuple2;
import io.ray.streaming.runtime.config.worker.WorkerInternalConfig;
import io.ray.streaming.runtime.context.ContextBackend;
import io.ray.streaming.runtime.core.checkpoint.OperatorCheckpointInfo;
Expand Down Expand Up @@ -190,8 +191,15 @@ private void openProcessor() {
Map<String, List<String>> opGroupedChannelId = new HashMap<>();
Map<String, List<BaseActorHandle>> opGroupedActor = new HashMap<>();
Map<String, Partition> opPartitionMap = new HashMap<>();
Map<String, Tuple2<Integer, Integer>> opIdAndDownStreamIdMap = new HashMap<>();
for (int i = 0; i < outputEdges.size(); ++i) {
ExecutionEdge edge = outputEdges.get(i);
LOG.info(
"Upstream {} {}, downstream {} {}.",
edge.getSource().getExecutionVertexName(),
edge.getSource().getOperator().getId(),
edge.getTargetExecutionJobVertexName(),
edge.getTarget().getOperator().getId());
String opName = edge.getTargetExecutionJobVertexName();
if (!opPartitionMap.containsKey(opName)) {
opGroupedChannelId.put(opName, new ArrayList<>());
Expand All @@ -202,13 +210,19 @@ private void openProcessor() {
.get(opName)
.add(new ArrayList<>(executionVertex.getChannelIdOutputActorMap().values()).get(i));
opPartitionMap.put(opName, edge.getPartition());
opIdAndDownStreamIdMap.put(
opName,
Tuple2.of(
edge.getSource().getOperator().getId(), edge.getTarget().getOperator().getId()));
}
opPartitionMap
.keySet()
.forEach(
opName -> {
collectors.add(
new OutputCollector(
opIdAndDownStreamIdMap.get(opName).f0,
opIdAndDownStreamIdMap.get(opName).f1,
writer,
opGroupedChannelId.get(opName),
opGroupedActor.get(opName),
Expand All @@ -217,7 +231,10 @@ private void openProcessor() {

RuntimeContext runtimeContext =
new StreamingTaskRuntimeContext(executionVertex, lastCheckpointId);

for (Collector collector : collectors) {
LOG.info(
"Collector id {}, downstream id {}.", collector.getId(), collector.getDownStreamOpId());
}
processor.open(collectors, runtimeContext);
}

Expand Down
11 changes: 6 additions & 5 deletions streaming/python/raystreaming/tests/simple/test_function.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from ray.streaming import function
from ray.streaming.runtime import gateway_client
from raystreaming import function
from raystreaming.runtime import gateway_client


def test_get_simple_function_class():
simple_map_func_class = function._get_simple_function_class(function.MapFunction)
simple_map_func_class = function._get_simple_function_class(
function.MapFunction)
assert simple_map_func_class is function.SimpleMapFunction


class MapFunc(function.MapFunction):

def map(self, value):
return str(value)

Expand All @@ -16,7 +18,6 @@ def test_load_function():
# function_bytes, module_name, function_name/class_name,
# function_interface
descriptor_func_bytes = gateway_client.serialize(
[None, __name__, MapFunc.__name__, "MapFunction"]
)
[None, __name__, MapFunc.__name__, "MapFunction"])
func = function.load_function(descriptor_func_bytes)
assert type(func) is MapFunc
19 changes: 10 additions & 9 deletions streaming/python/raystreaming/tests/simple/test_operator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ray.streaming import function
from ray.streaming import operator
from ray.streaming.operator import OperatorType
from ray.streaming.runtime import gateway_client
from raystreaming import function
from raystreaming import operator
from raystreaming.operator import OperatorType
from raystreaming.runtime import gateway_client


def test_create_operator_with_func():
Expand All @@ -11,11 +11,13 @@ def test_create_operator_with_func():


class MapFunc(function.MapFunction):

def map(self, value):
return str(value)


class EmptyOperator(operator.StreamOperator):

def __init__(self):
super().__init__(function.EmptyFunction())

Expand All @@ -26,13 +28,12 @@ def operator_type(self) -> OperatorType:
def test_load_operator():
# function_bytes, module_name, class_name,
descriptor_func_bytes = gateway_client.serialize(
[None, __name__, MapFunc.__name__, "MapFunction"]
)
descriptor_op_bytes = gateway_client.serialize([descriptor_func_bytes, "", ""])
[None, __name__, MapFunc.__name__, "MapFunction"])
descriptor_op_bytes = gateway_client.serialize(
[descriptor_func_bytes, "", ""])
map_operator = operator.load_operator(descriptor_op_bytes)
assert type(map_operator) is operator.MapOperator
descriptor_op_bytes = gateway_client.serialize(
[None, __name__, EmptyOperator.__name__]
)
[None, __name__, EmptyOperator.__name__])
test_operator = operator.load_operator(descriptor_op_bytes)
assert isinstance(test_operator, EmptyOperator)
8 changes: 7 additions & 1 deletion streaming/src/data_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ StreamingStatus DataReader::InitChannel(

for (const auto &input_channel : unready_queue_ids_) {
auto &channel_info = channel_info_map_[input_channel];
auto it = channel_map_.find(input_channel);
if (it != channel_map_.end()) {
STREAMING_LOG(INFO) << "Channel id " << input_channel << " has been initialized.";
continue;
}
std::shared_ptr<ConsumerChannel> channel;
if (runtime_context_->IsMockTest()) {
channel = std::make_shared<MockConsumer>(transfer_config_, channel_info);
Expand All @@ -86,7 +91,8 @@ StreamingStatus DataReader::InitChannel(
channel_map_.emplace(input_channel, channel);
TransferCreationStatus status = channel->CreateTransferChannel();
creation_status.push_back(status);
if (TransferCreationStatus::PullOk != status) {
if (TransferCreationStatus::DataLost == status ||
TransferCreationStatus::Timeout == status) {
STREAMING_LOG(ERROR) << "Initialize queue failed, id=" << input_channel
<< ", status=" << static_cast<uint32_t>(status);
}
Expand Down

0 comments on commit 65dec4a

Please sign in to comment.