Skip to content

Commit

Permalink
[FLINK-36831][table] Introduce AppendOnlyTopNFunction in Rank with Async State API
Browse files Browse the repository at this point in the history

This closes #25723
  • Loading branch information
xuyangzhong authored Jan 14, 2025
1 parent b71dd76 commit 29edd60
Show file tree
Hide file tree
Showing 12 changed files with 754 additions and 189 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import org.apache.flink.table.runtime.operators.rank.RetractableTopNFunction;
import org.apache.flink.table.runtime.operators.rank.UpdatableTopNFunction;
import org.apache.flink.table.runtime.operators.rank.async.AbstractAsyncStateTopNFunction;
import org.apache.flink.table.runtime.operators.rank.async.AsyncStateAppendOnlyTopNFunction;
import org.apache.flink.table.runtime.operators.rank.async.AsyncStateFastTop1Function;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.runtime.typeutils.TypeCheckUtils;
Expand Down Expand Up @@ -283,17 +284,31 @@ protected Transformation<RowData> translateToPlanInternal(
cacheSize);
}
} else {
processFunction =
new AppendOnlyTopNFunction(
ttlConfig,
inputRowTypeInfo,
sortKeyComparator,
sortKeySelector,
rankType,
rankRange,
generateUpdateBefore,
outputRankNumber,
cacheSize);
if (isAsyncStateEnabled) {
processFunction =
new AsyncStateAppendOnlyTopNFunction(
ttlConfig,
inputRowTypeInfo,
sortKeyComparator,
sortKeySelector,
rankType,
rankRange,
generateUpdateBefore,
outputRankNumber,
cacheSize);
} else {
processFunction =
new AppendOnlyTopNFunction(
ttlConfig,
inputRowTypeInfo,
sortKeyComparator,
sortKeySelector,
rankType,
rankRange,
generateUpdateBefore,
outputRankNumber,
cacheSize);
}
}
} else if (rankStrategy instanceof RankProcessStrategy.UpdateFastStrategy) {
if (RankUtil.isTop1(rankRange)) {
Expand Down Expand Up @@ -344,7 +359,6 @@ protected Transformation<RowData> translateToPlanInternal(
outputRankNumber,
cacheSize);
}
// TODO Use UnaryUpdateTopNFunction after SortedMapState is merged
} else if (rankStrategy instanceof RankProcessStrategy.RetractStrategy) {
EqualiserCodeGenerator equaliserCodeGen =
new EqualiserCodeGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.flink.table.planner.runtime.utils.{JavaUserDefinedTableFunctio
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode}
import org.apache.flink.table.runtime.util.RowDataHarnessAssertor
import org.apache.flink.table.runtime.util.StreamRecordUtils.binaryRecord
import org.apache.flink.table.types.logical.LogicalType
import org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedTestExtension, Parameters}
import org.apache.flink.types.Row
import org.apache.flink.types.RowKind._
Expand Down Expand Up @@ -468,7 +469,10 @@ class RankHarnessTest(mode: StateBackendMode, enableAsyncState: Boolean)
testHarness.close()
}

def prepareTop1Tester(query: String, operatorNameIdentifier: String)
def prepareRankTester(
query: String,
operatorNameIdentifier: String,
operatorOutputLogicalTypes: Array[LogicalType])
: (KeyedOneInputStreamOperatorTestHarness[RowData, RowData, RowData], RowDataHarnessAssertor) = {
val sourceDDL =
s"""
Expand All @@ -486,11 +490,7 @@ class RankHarnessTest(mode: StateBackendMode, enableAsyncState: Boolean)

val testHarness =
createHarnessTester(t1.toRetractStream[Row], operatorNameIdentifier)
val assertor = new RowDataHarnessAssertor(
Array(
DataTypes.STRING().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType))
val assertor = new RowDataHarnessAssertor(operatorOutputLogicalTypes)

(testHarness, assertor)
}
Expand All @@ -500,7 +500,7 @@ class RankHarnessTest(mode: StateBackendMode, enableAsyncState: Boolean)
tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(1))
val query =
"""
|SELECT a, b, rn
|SELECT a, b
|FROM
|(
| SELECT a, b,
Expand All @@ -510,7 +510,11 @@ class RankHarnessTest(mode: StateBackendMode, enableAsyncState: Boolean)
|WHERE rn <= 1
""".stripMargin
val (testHarness, assertor) =
prepareTop1Tester(query, "Rank(strategy=[AppendFastStrategy")
prepareRankTester(
query,
"Rank(strategy=[AppendFastStrategy",
Array(DataTypes.STRING().getLogicalType, DataTypes.BIGINT().getLogicalType)
)

if (enableAsyncState) {
assertThat(isAsyncStateOperator(testHarness)).isTrue
Expand Down Expand Up @@ -541,7 +545,7 @@ class RankHarnessTest(mode: StateBackendMode, enableAsyncState: Boolean)
tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(1))
val query =
"""
|SELECT a, b, rn
|SELECT a, b
|FROM
|(
| SELECT a, b,
Expand All @@ -553,7 +557,11 @@ class RankHarnessTest(mode: StateBackendMode, enableAsyncState: Boolean)
|WHERE rn <= 1
""".stripMargin
val (testHarness, assertor) =
prepareTop1Tester(query, "Rank(strategy=[UpdateFastStrategy")
prepareRankTester(
query,
"Rank(strategy=[UpdateFastStrategy",
Array(DataTypes.STRING().getLogicalType, DataTypes.BIGINT().getLogicalType)
)

if (enableAsyncState) {
assertThat(isAsyncStateOperator(testHarness)).isTrue
Expand All @@ -580,6 +588,143 @@ class RankHarnessTest(mode: StateBackendMode, enableAsyncState: Boolean)

testHarness.close()
}

/**
 * Harness test for the operator whose name contains "Rank(strategy=[AppendFastStrategy", built
 * from a top-3-per-partition ROW_NUMBER query that also projects the rank column `rn`.
 *
 * Only INSERT records with key "a" are fed in. Every expected record carries three fields:
 * a, b and the rank number.
 */
@TestTemplate
def testAppendOnlyTopNWithRowNumber(): Unit = {
// Idle state retention is set to one second before the query is translated.
tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(1))
val query =
"""
|SELECT a, b, rn
|FROM
|(
| SELECT a, b,
| ROW_NUMBER() OVER (PARTITION BY a ORDER BY b DESC) AS rn
| FROM T
|) t1
|WHERE rn <= 3
""".stripMargin
// The logical types passed here describe the operator output: (a STRING, b BIGINT, rn BIGINT).
val (testHarness, assertor) =
prepareRankTester(
query,
"Rank(strategy=[AppendFastStrategy",
Array(
DataTypes.STRING().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType)
)

// The operator under test must be an async state operator exactly when the test is
// parameterized with enableAsyncState = true.
if (enableAsyncState) {
assertThat(isAsyncStateOperator(testHarness)).isTrue
} else {
assertThat(isAsyncStateOperator(testHarness)).isFalse
}

testHarness.open()

val expectedOutput = new ConcurrentLinkedQueue[Object]()

// a,2 - top1
testHarness.processElement(binaryRecord(INSERT, "a", 2L: JLong))
expectedOutput.add(binaryRecord(INSERT, "a", 2L: JLong, 1L: JLong))

// a,2 - top1
// a,1 - top2
testHarness.processElement(binaryRecord(INSERT, "a", 1L: JLong))
expectedOutput.add(binaryRecord(INSERT, "a", 1L: JLong, 2L: JLong))

// a,3 - top1
// a,2 - top2
// a,1 - top3
// The new record takes rank 1, so ranks 1 and 2 are each updated (UPDATE_BEFORE/UPDATE_AFTER)
// and the previous rank-2 row is inserted at rank 3.
testHarness.processElement(binaryRecord(INSERT, "a", 3L: JLong))
expectedOutput.add(binaryRecord(UPDATE_BEFORE, "a", 2L: JLong, 1L: JLong))
expectedOutput.add(binaryRecord(UPDATE_AFTER, "a", 3L: JLong, 1L: JLong))
expectedOutput.add(binaryRecord(UPDATE_BEFORE, "a", 1L: JLong, 2L: JLong))
expectedOutput.add(binaryRecord(UPDATE_AFTER, "a", 2L: JLong, 2L: JLong))
expectedOutput.add(binaryRecord(INSERT, "a", 1L: JLong, 3L: JLong))

// a,3 - top1
// a,2 - top2
// a,1 - top3
// a,0 sorts after the current top 3, so no output is expected for this record.
testHarness.processElement(binaryRecord(INSERT, "a", 0L: JLong))

// a,3 - top1
// a,3 - top2
// a,2 - top3
// The second a,3 takes rank 2: ranks 2 and 3 are updated and a,1 no longer appears in the
// expected result.
testHarness.processElement(binaryRecord(INSERT, "a", 3L: JLong))
expectedOutput.add(binaryRecord(UPDATE_BEFORE, "a", 2L: JLong, 2L: JLong))
expectedOutput.add(binaryRecord(UPDATE_AFTER, "a", 3L: JLong, 2L: JLong))
expectedOutput.add(binaryRecord(UPDATE_BEFORE, "a", 1L: JLong, 3L: JLong))
expectedOutput.add(binaryRecord(UPDATE_AFTER, "a", 2L: JLong, 3L: JLong))

// NOTE(review): dropWatermarks presumably filters watermark entries out of the raw harness
// output, and assertOutputEqualsSorted presumably compares without regard to emission order —
// confirm against the helpers.
val result = dropWatermarks(testHarness.getOutput.toArray)
assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, result)

testHarness.close()
}

/**
 * Harness test for the operator whose name contains "Rank(strategy=[AppendFastStrategy", built
 * from a top-3-per-partition ROW_NUMBER query that does NOT project the rank column.
 *
 * Only INSERT records with key "a" are fed in. Every expected record carries two fields (a, b),
 * and changes to the top 3 are expected as INSERT/DELETE records.
 */
@TestTemplate
def testAppendOnlyTopNWithoutRowNumber(): Unit = {
// Idle state retention is set to one second before the query is translated.
tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(1))
val query =
"""
|SELECT a, b
|FROM
|(
| SELECT a, b,
| ROW_NUMBER() OVER (PARTITION BY a ORDER BY b DESC) AS rn
| FROM T
|) t1
|WHERE rn <= 3
""".stripMargin
// The logical types passed here describe the operator output: (a STRING, b BIGINT).
val (testHarness, assertor) =
prepareRankTester(
query,
"Rank(strategy=[AppendFastStrategy",
Array(DataTypes.STRING().getLogicalType, DataTypes.BIGINT().getLogicalType)
)

// The operator under test must be an async state operator exactly when the test is
// parameterized with enableAsyncState = true.
if (enableAsyncState) {
assertThat(isAsyncStateOperator(testHarness)).isTrue
} else {
assertThat(isAsyncStateOperator(testHarness)).isFalse
}

testHarness.open()

val expectedOutput = new ConcurrentLinkedQueue[Object]()

// a,2 - top1
testHarness.processElement(binaryRecord(INSERT, "a", 2L: JLong))
expectedOutput.add(binaryRecord(INSERT, "a", 2L: JLong))

// a,2 - top1
// a,1 - top2
testHarness.processElement(binaryRecord(INSERT, "a", 1L: JLong))
expectedOutput.add(binaryRecord(INSERT, "a", 1L: JLong))

// a,3 - top1
// a,2 - top2
// a,1 - top3
// Without a rank column, only the newly admitted row is expected (a single INSERT).
testHarness.processElement(binaryRecord(INSERT, "a", 3L: JLong))
expectedOutput.add(binaryRecord(INSERT, "a", 3L: JLong))

// a,3 - top1
// a,2 - top2
// a,1 - top3
// a,0 sorts after the current top 3, so no output is expected for this record.
testHarness.processElement(binaryRecord(INSERT, "a", 0L: JLong))

// a,3 - top1
// a,3 - top2
// a,2 - top3
// The second a,3 displaces a,1: a DELETE for a,1 and an INSERT for the new a,3 are expected.
testHarness.processElement(binaryRecord(INSERT, "a", 3L: JLong))
expectedOutput.add(binaryRecord(DELETE, "a", 1L: JLong))
expectedOutput.add(binaryRecord(INSERT, "a", 3L: JLong))

// NOTE(review): dropWatermarks presumably filters watermark entries out of the raw harness
// output, and assertOutputEqualsSorted presumably compares without regard to emission order —
// confirm against the helpers.
val result = dropWatermarks(testHarness.getOutput.toArray)
assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, result)

testHarness.close()
}
}

object RankHarnessTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,17 @@
import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.util.Collector;

import java.util.Objects;

/** Base class for TopN functions that use the synchronous state API. */
public abstract class AbstractSyncStateTopNFunction extends AbstractTopNFunction {

private ValueState<Long> rankEndState;

protected long rankEnd;

public AbstractSyncStateTopNFunction(
StateTtlConfig ttlConfig,
InternalTypeInfo<RowData> inputRowType,
Expand Down Expand Up @@ -76,6 +81,7 @@ public void open(OpenContext openContext) throws Exception {
*/
protected long initRankEnd(RowData row) throws Exception {
if (isConstantRankEnd) {
rankEnd = Objects.requireNonNull(constantRankEnd);
return rankEnd;
} else {
Long rankEndValue = rankEndState.value();
Expand All @@ -95,4 +101,30 @@ protected long initRankEnd(RowData row) throws Exception {
}
}
}

// ====== utility methods that use the rankEnd field instead of an explicit rank end argument ======

/**
 * Checks whether the given rank does not exceed the current {@code rankEnd} field.
 *
 * @param rank the rank to test
 * @return {@code true} if {@code rank} is less than or equal to {@code rankEnd}
 */
protected boolean isInRankEnd(long rank) {
    final boolean beyondRankEnd = rank > rankEnd;
    return !beyondRankEnd;
}

/**
 * Checks whether the given rank lies within the closed interval bounded by {@code rankStart}
 * and the current {@code rankEnd} field.
 *
 * @param rank the rank to test
 * @return {@code true} if {@code rankStart <= rank <= rankEnd}
 */
protected boolean isInRankRange(long rank) {
    // Upper bound is checked first, as in the original expression.
    if (rank > rankEnd) {
        return false;
    }
    return rankStart <= rank;
}

/**
 * Convenience overload that delegates to {@code collectInsert(out, inputRow, rank, rankEnd)},
 * supplying the current {@code rankEnd} field as the rank end.
 *
 * <p>NOTE(review): presumably {@code rankEnd} must already have been initialized (e.g. via
 * {@code initRankEnd}) before this is called — confirm against callers.
 *
 * @param out collector passed through to the four-argument overload
 * @param inputRow row passed through to the four-argument overload
 * @param rank rank passed through to the four-argument overload
 */
protected void collectInsert(Collector<RowData> out, RowData inputRow, long rank) {
collectInsert(out, inputRow, rank, rankEnd);
}

/**
 * Convenience overload that delegates to {@code collectDelete(out, inputRow, rank, rankEnd)},
 * supplying the current {@code rankEnd} field as the rank end.
 *
 * <p>NOTE(review): presumably {@code rankEnd} must already have been initialized (e.g. via
 * {@code initRankEnd}) before this is called — confirm against callers.
 *
 * @param out collector passed through to the four-argument overload
 * @param inputRow row passed through to the four-argument overload
 * @param rank rank passed through to the four-argument overload
 */
protected void collectDelete(Collector<RowData> out, RowData inputRow, long rank) {
collectDelete(out, inputRow, rank, rankEnd);
}

/**
 * Convenience overload that delegates to
 * {@code collectUpdateAfter(out, inputRow, rank, rankEnd)}, supplying the current
 * {@code rankEnd} field as the rank end.
 *
 * <p>NOTE(review): presumably {@code rankEnd} must already have been initialized (e.g. via
 * {@code initRankEnd}) before this is called — confirm against callers.
 *
 * @param out collector passed through to the four-argument overload
 * @param inputRow row passed through to the four-argument overload
 * @param rank rank passed through to the four-argument overload
 */
protected void collectUpdateAfter(Collector<RowData> out, RowData inputRow, long rank) {
collectUpdateAfter(out, inputRow, rank, rankEnd);
}

/**
 * Convenience overload that delegates to
 * {@code collectUpdateBefore(out, inputRow, rank, rankEnd)}, supplying the current
 * {@code rankEnd} field as the rank end.
 *
 * <p>NOTE(review): presumably {@code rankEnd} must already have been initialized (e.g. via
 * {@code initRankEnd}) before this is called — confirm against callers.
 *
 * @param out collector passed through to the four-argument overload
 * @param inputRow row passed through to the four-argument overload
 * @param rank rank passed through to the four-argument overload
 */
protected void collectUpdateBefore(Collector<RowData> out, RowData inputRow, long rank) {
collectUpdateBefore(out, inputRow, rank, rankEnd);
}
}
Loading

0 comments on commit 29edd60

Please sign in to comment.