flink-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] kl0u closed pull request #6520: [FLINK-10097][DataStream API] Additional tests for StreamingFileSink
Date Thu, 01 Nov 2018 08:03:11 GMT
kl0u closed pull request #6520: [FLINK-10097][DataStream API] Additional tests for StreamingFileSink
URL: https://github.com/apache/flink/pull/6520
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff once a PR is merged, the diff is
reproduced below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/Bucket.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/Bucket.java
index 6187e6853dd..65a7628578c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/Bucket.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/Bucket.java
@@ -19,6 +19,7 @@
 package org.apache.flink.streaming.api.functions.sink.filesystem;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.fs.RecoverableWriter;
 import org.apache.flink.util.Preconditions;
@@ -26,6 +27,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -59,10 +62,11 @@
 
 	private final RollingPolicy<IN, BucketID> rollingPolicy;
 
-	private final Map<Long, List<RecoverableWriter.CommitRecoverable>> pendingPartsPerCheckpoint = new HashMap<>();
+	private final Map<Long, List<RecoverableWriter.CommitRecoverable>> pendingPartsPerCheckpoint;
 
 	private long partCounter;
 
+	@Nullable
 	private PartFileWriter<IN, BucketID> inProgressPart;
 
 	private List<RecoverableWriter.CommitRecoverable> pendingPartsForCurrentCheckpoint;
@@ -88,6 +92,7 @@ private Bucket(
 		this.rollingPolicy = Preconditions.checkNotNull(rollingPolicy);
 
 		this.pendingPartsForCurrentCheckpoint = new ArrayList<>();
+		this.pendingPartsPerCheckpoint = new HashMap<>();
 	}
 
 	/**
@@ -277,6 +282,24 @@ void onProcessingTime(long timestamp) throws IOException {
 		}
 	}
 
+	// --------------------------- Testing Methods -----------------------------
+
+	@VisibleForTesting
+	Map<Long, List<RecoverableWriter.CommitRecoverable>> getPendingPartsPerCheckpoint() {
+		return pendingPartsPerCheckpoint;
+	}
+
+	@Nullable
+	@VisibleForTesting
+	PartFileWriter<IN, BucketID> getInProgressPart() {
+		return inProgressPart;
+	}
+
+	@VisibleForTesting
+	List<RecoverableWriter.CommitRecoverable> getPendingPartsForCurrentCheckpoint() {
+		return pendingPartsForCurrentCheckpoint;
+	}
+
 	// --------------------------- Static Factory Methods -----------------------------
 
 	/**
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/Buckets.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/Buckets.java
index 2aca841f16d..d08bc2ac0c3 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/Buckets.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/Buckets.java
@@ -19,6 +19,7 @@
 package org.apache.flink.streaming.api.functions.sink.filesystem;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
@@ -186,6 +187,10 @@ private void handleRestoredBucketState(final BucketState<BucketID> recoveredStat
 	}
 
 	private void updateActiveBucketId(final BucketID bucketId, final Bucket<IN, BucketID> restoredBucket) throws IOException {
+		if (!restoredBucket.isActive()) {
+			return;
+		}
+
 		final Bucket<IN, BucketID> bucket = activeBuckets.get(bucketId);
 		if (bucket != null) {
 			bucket.merge(restoredBucket);
@@ -224,6 +229,9 @@ void snapshotState(
 		LOG.info("Subtask {} checkpointing for checkpoint with id={} (max part counter={}).",
 				subtaskIndex, checkpointId, maxPartCounter);
 
+		bucketStatesContainer.clear();
+		partCounterStateContainer.clear();
+
 		snapshotActiveBuckets(checkpointId, bucketStatesContainer);
 		partCounterStateContainer.add(maxPartCounter);
 	}
@@ -341,4 +349,16 @@ public Long timestamp() {
 			return elementTimestamp;
 		}
 	}
+
+	// --------------------------- Testing Methods -----------------------------
+
+	@VisibleForTesting
+	public long getMaxPartCounter() {
+		return maxPartCounter;
+	}
+
+	@VisibleForTesting
+	Map<BucketID, Bucket<IN, BucketID>> getActiveBuckets() {
+		return activeBuckets;
+	}
 }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/StreamingFileSink.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/StreamingFileSink.java
index 6f57fee81f3..dc0b1c6e8a3 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/StreamingFileSink.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/StreamingFileSink.java
@@ -344,9 +344,6 @@ public void notifyCheckpointComplete(long checkpointId) throws Exception {
 	public void snapshotState(FunctionSnapshotContext context) throws Exception {
 		Preconditions.checkState(bucketStates != null && maxPartCountersState != null, "sink has not been initialized");
 
-		bucketStates.clear();
-		maxPartCountersState.clear();
-
 		buckets.snapshotState(
 				context.getCheckpointId(),
 				bucketStates,
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/rollingpolicies/DefaultRollingPolicy.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/rollingpolicies/DefaultRollingPolicy.java
index 7c75f1c5e25..f890326d147 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/rollingpolicies/DefaultRollingPolicy.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/filesystem/rollingpolicies/DefaultRollingPolicy.java
@@ -78,8 +78,8 @@ public boolean shouldRollOnEvent(PartFileInfo<BucketID> partFileState, IN elemen
 
 	@Override
 	public boolean shouldRollOnProcessingTime(final PartFileInfo<BucketID> partFileState, final long currentTime) {
-		return currentTime - partFileState.getCreationTime() > rolloverInterval ||
-				currentTime - partFileState.getLastUpdateTime() > inactivityInterval;
+		return currentTime - partFileState.getCreationTime() >= rolloverInterval ||
+				currentTime - partFileState.getLastUpdateTime() >= inactivityInterval;
 	}
 
 	/**
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/BucketsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/BucketsTest.java
index 25622d14466..aee362178a7 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/BucketsTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/BucketsTest.java
@@ -19,17 +19,26 @@
 package org.apache.flink.streaming.api.functions.sink.filesystem;
 
 import org.apache.flink.api.common.serialization.SimpleStringEncoder;
+import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
-import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.functions.sink.filesystem.TestUtils.MockListState;
 import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.SimpleVersionedStringSerializer;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.DefaultRollingPolicy;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
 
+import org.hamcrest.Description;
+import org.hamcrest.TypeSafeMatcher;
 import org.junit.Assert;
 import org.junit.ClassRule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
 import java.io.File;
+import java.io.IOException;
+import java.util.Map;
+
+import static org.hamcrest.MatcherAssert.assertThat;
 
 /**
  * Tests for {@link Buckets}.
@@ -40,46 +49,287 @@
 	public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder();
 
 	@Test
-	public void testContextPassingNormalExecution() throws Exception {
-		testCorrectPassingOfContext(1L, 2L, 3L);
-	}
+	public void testSnapshotAndRestore() throws Exception {
+		final File outDir = TEMP_FOLDER.newFolder();
+		final Path path = new Path(outDir.toURI());
 
-	@Test
-	public void testContextPassingNullTimestamp() throws Exception {
-		testCorrectPassingOfContext(null, 2L, 3L);
-	}
+		final RollingPolicy<String, String> onCheckpointRollingPolicy = OnCheckpointRollingPolicy.build();
 
-	private void testCorrectPassingOfContext(Long timestamp, long watermark, long processingTime) throws Exception {
-		final File outDir = TEMP_FOLDER.newFolder();
+		final Buckets<String, String> buckets = createBuckets(path, onCheckpointRollingPolicy, 0);
+
+		final ListState<byte[]> bucketStateContainer = new MockListState<>();
+		final ListState<Long> partCounterContainer = new MockListState<>();
+
+		buckets.onElement("test1", new TestUtils.MockSinkContext(null, 1L, 2L));
+		buckets.snapshotState(0L, bucketStateContainer, partCounterContainer);
+
+		assertThat(buckets.getActiveBuckets().get("test1"), hasSinglePartFileToBeCommittedOnCheckpointAck(path, "test1"));
+
+		buckets.onElement("test2", new TestUtils.MockSinkContext(null, 1L, 2L));
+		buckets.snapshotState(1L, bucketStateContainer, partCounterContainer);
+
+		assertThat(buckets.getActiveBuckets().get("test1"), hasSinglePartFileToBeCommittedOnCheckpointAck(path, "test1"));
+		assertThat(buckets.getActiveBuckets().get("test2"), hasSinglePartFileToBeCommittedOnCheckpointAck(path, "test2"));
+
+		Buckets<String, String> restoredBuckets =
+				restoreBuckets(path, onCheckpointRollingPolicy, 0, bucketStateContainer, partCounterContainer);
 
-		final Long expectedTimestamp = timestamp;
-		final long expectedWatermark = watermark;
-		final long expectedProcessingTime = processingTime;
+		final Map<String, Bucket<String, String>> activeBuckets = restoredBuckets.getActiveBuckets();
 
-		final Buckets<String, String> buckets = StreamingFileSink
-				.<String>forRowFormat(new Path(outDir.toURI()), new SimpleStringEncoder<>())
-				.withBucketAssigner(new VarifyingBucketer(expectedTimestamp, expectedWatermark, expectedProcessingTime))
-				.createBuckets(2);
+		// because we commit pending files for previous checkpoints upon recovery
+		Assert.assertTrue(activeBuckets.isEmpty());
+	}
 
-		buckets.onElement("TEST", new SinkFunction.Context() {
+	private static TypeSafeMatcher<Bucket<String, String>> hasSinglePartFileToBeCommittedOnCheckpointAck(final Path testTmpPath, final String bucketId) {
+		return new TypeSafeMatcher<Bucket<String, String>>() {
 			@Override
-			public long currentProcessingTime() {
-				return expectedProcessingTime;
+			protected boolean matchesSafely(Bucket<String, String> bucket) {
+				return bucket.getBucketId().equals(bucketId) &&
+						bucket.getBucketPath().equals(new Path(testTmpPath, bucketId)) &&
+						bucket.getInProgressPart() == null &&
+						bucket.getPendingPartsForCurrentCheckpoint().isEmpty() &&
+						bucket.getPendingPartsPerCheckpoint().size() == 1;
 			}
 
 			@Override
-			public long currentWatermark() {
-				return expectedWatermark;
+			public void describeTo(Description description) {
+				description.appendText("a Bucket with a single pending part file @ ")
+						.appendValue(new Path(testTmpPath, bucketId))
+						.appendText("'");
 			}
+		};
+	}
 
-			@Override
-			public Long timestamp() {
-				return expectedTimestamp;
+	@Test
+	public void testMergeAtScaleInAndMaxCounterAtRecovery() throws Exception {
+		final File outDir = TEMP_FOLDER.newFolder();
+		final Path path = new Path(outDir.toURI());
+
+		final RollingPolicy<String, String> onCheckpointRP =
+				DefaultRollingPolicy
+						.create()
+						.withMaxPartSize(7L) // roll with 2 elements
+						.build();
+
+		final MockListState<byte[]> bucketStateContainerOne = new MockListState<>();
+		final MockListState<byte[]> bucketStateContainerTwo = new MockListState<>();
+
+		final MockListState<Long> partCounterContainerOne = new MockListState<>();
+		final MockListState<Long> partCounterContainerTwo = new MockListState<>();
+
+		final Buckets<String, String> bucketsOne = createBuckets(path, onCheckpointRP, 0);
+		final Buckets<String, String> bucketsTwo = createBuckets(path, onCheckpointRP, 1);
+
+		bucketsOne.onElement("test1", new TestUtils.MockSinkContext(null, 1L, 2L));
+		bucketsOne.snapshotState(0L, bucketStateContainerOne, partCounterContainerOne);
+
+		Assert.assertEquals(1L, bucketsOne.getMaxPartCounter());
+
+		// make sure we have one in-progress file here
+		Assert.assertNotNull(bucketsOne.getActiveBuckets().get("test1").getInProgressPart());
+
+		// add a couple of in-progress files so that the part counter increases.
+		bucketsTwo.onElement("test1", new TestUtils.MockSinkContext(null, 1L, 2L));
+		bucketsTwo.onElement("test1", new TestUtils.MockSinkContext(null, 1L, 2L));
+
+		bucketsTwo.onElement("test1", new TestUtils.MockSinkContext(null, 1L, 2L));
+
+		bucketsTwo.snapshotState(0L, bucketStateContainerTwo, partCounterContainerTwo);
+
+		Assert.assertEquals(2L, bucketsTwo.getMaxPartCounter());
+
+		// make sure we have one in-progress file here and a pending
+		Assert.assertEquals(1L, bucketsTwo.getActiveBuckets().get("test1").getPendingPartsPerCheckpoint().size());
+		Assert.assertNotNull(bucketsTwo.getActiveBuckets().get("test1").getInProgressPart());
+
+		final ListState<byte[]> mergedBucketStateContainer = new MockListState<>();
+		final ListState<Long> mergedPartCounterContainer = new MockListState<>();
+
+		mergedBucketStateContainer.addAll(bucketStateContainerOne.getBackingList());
+		mergedBucketStateContainer.addAll(bucketStateContainerTwo.getBackingList());
+
+		mergedPartCounterContainer.addAll(partCounterContainerOne.getBackingList());
+		mergedPartCounterContainer.addAll(partCounterContainerTwo.getBackingList());
+
+		final Buckets<String, String> restoredBuckets =
+				restoreBuckets(path, onCheckpointRP, 0, mergedBucketStateContainer, mergedPartCounterContainer);
+
+		// we get the maximum of the previous tasks
+		Assert.assertEquals(2L, restoredBuckets.getMaxPartCounter());
+
+		final Map<String, Bucket<String, String>> activeBuckets = restoredBuckets.getActiveBuckets();
+		Assert.assertEquals(1L, activeBuckets.size());
+		Assert.assertTrue(activeBuckets.keySet().contains("test1"));
+
+		final Bucket<String, String> bucket = activeBuckets.get("test1");
+		Assert.assertEquals("test1", bucket.getBucketId());
+		Assert.assertEquals(new Path(path, "test1"), bucket.getBucketPath());
+
+		Assert.assertNotNull(bucket.getInProgressPart()); // the restored part file
+
+		// this is due to the Bucket#merge(). The in progress file of one
+		// of the previous tasks is put in the list of pending files.
+		Assert.assertEquals(1L, bucket.getPendingPartsForCurrentCheckpoint().size());
+
+		// we commit the pending for previous checkpoints
+		Assert.assertTrue(bucket.getPendingPartsPerCheckpoint().isEmpty());
+	}
+
+	@Test
+	public void testOnProcessingTime() throws Exception {
+		final File outDir = TEMP_FOLDER.newFolder();
+		final Path path = new Path(outDir.toURI());
+
+		final OnProcessingTimePolicy<String, String> rollOnProcessingTimeCountingPolicy =
+				new OnProcessingTimePolicy<>(2L);
+
+		final Buckets<String, String> buckets =
+				createBuckets(path, rollOnProcessingTimeCountingPolicy, 0);
+
+		// it takes the current processing time of the context for the creation time,
+		// and for the last modification time.
+		buckets.onElement("test", new TestUtils.MockSinkContext(1L, 2L , 3L));
+
+		// now it should roll
+		buckets.onProcessingTime(7L);
+		Assert.assertEquals(1L, rollOnProcessingTimeCountingPolicy.getOnProcessingTimeRollCounter());
+
+		final Map<String, Bucket<String, String>> activeBuckets = buckets.getActiveBuckets();
+		Assert.assertEquals(1L, activeBuckets.size());
+		Assert.assertTrue(activeBuckets.keySet().contains("test"));
+
+		final Bucket<String, String> bucket = activeBuckets.get("test");
+		Assert.assertEquals("test", bucket.getBucketId());
+		Assert.assertEquals(new Path(path, "test"), bucket.getBucketPath());
+		Assert.assertEquals("test", bucket.getBucketId());
+
+		Assert.assertNull(bucket.getInProgressPart());
+		Assert.assertEquals(1L, bucket.getPendingPartsForCurrentCheckpoint().size());
+		Assert.assertTrue(bucket.getPendingPartsPerCheckpoint().isEmpty());
+	}
+
+	@Test
+	public void testBucketIsRemovedWhenNotActive() throws Exception {
+		final File outDir = TEMP_FOLDER.newFolder();
+		final Path path = new Path(outDir.toURI());
+
+		final OnProcessingTimePolicy<String, String> rollOnProcessingTimeCountingPolicy =
+				new OnProcessingTimePolicy<>(2L);
+
+		final Buckets<String, String> buckets =
+				createBuckets(path, rollOnProcessingTimeCountingPolicy, 0);
+
+		// it takes the current processing time of the context for the creation time, and for the last modification time.
+		buckets.onElement("test", new TestUtils.MockSinkContext(1L, 2L , 3L));
+
+		// now it should roll
+		buckets.onProcessingTime(7L);
+		Assert.assertEquals(1L, rollOnProcessingTimeCountingPolicy.getOnProcessingTimeRollCounter());
+
+		buckets.snapshotState(0L, new MockListState<>(), new MockListState<>());
+		buckets.commitUpToCheckpoint(0L);
+
+		Assert.assertTrue(buckets.getActiveBuckets().isEmpty());
+	}
+
+	@Test
+	public void testPartCounterAfterBucketResurrection() throws Exception {
+		final File outDir = TEMP_FOLDER.newFolder();
+		final Path path = new Path(outDir.toURI());
+
+		final OnProcessingTimePolicy<String, String> rollOnProcessingTimeCountingPolicy =
+				new OnProcessingTimePolicy<>(2L);
+
+		final Buckets<String, String> buckets =
+				createBuckets(path, rollOnProcessingTimeCountingPolicy, 0);
+
+		// it takes the current processing time of the context for the creation time, and for the last modification time.
+		buckets.onElement("test", new TestUtils.MockSinkContext(1L, 2L , 3L));
+		Assert.assertEquals(1L, buckets.getActiveBuckets().get("test").getPartCounter());
+
+		// now it should roll
+		buckets.onProcessingTime(7L);
+		Assert.assertEquals(1L, rollOnProcessingTimeCountingPolicy.getOnProcessingTimeRollCounter());
+		Assert.assertEquals(1L, buckets.getActiveBuckets().get("test").getPartCounter());
+
+		buckets.snapshotState(0L, new MockListState<>(), new MockListState<>());
+		buckets.commitUpToCheckpoint(0L);
+
+		Assert.assertTrue(buckets.getActiveBuckets().isEmpty());
+
+		buckets.onElement("test", new TestUtils.MockSinkContext(2L, 3L , 4L));
+		Assert.assertEquals(2L, buckets.getActiveBuckets().get("test").getPartCounter());
+	}
+
+	private static class OnProcessingTimePolicy<IN, BucketID> implements RollingPolicy<IN, BucketID> {
+
+		private static final long serialVersionUID = 1L;
+
+		private int onProcessingTimeRollCounter = 0;
+
+		private final long rolloverInterval;
+
+		OnProcessingTimePolicy(final long rolloverInterval) {
+			this.rolloverInterval = rolloverInterval;
+		}
+
+		public int getOnProcessingTimeRollCounter() {
+			return onProcessingTimeRollCounter;
+		}
+
+		@Override
+		public boolean shouldRollOnCheckpoint(PartFileInfo<BucketID> partFileState) {
+			return false;
+		}
+
+		@Override
+		public boolean shouldRollOnEvent(PartFileInfo<BucketID> partFileState, IN element) {
+			return false;
+		}
+
+		@Override
+		public boolean shouldRollOnProcessingTime(PartFileInfo<BucketID> partFileState, long currentTime) {
+			boolean result = currentTime - partFileState.getCreationTime() >= rolloverInterval;
+			if (result) {
+				onProcessingTimeRollCounter++;
 			}
-		});
+			return result;
+		}
 	}
 
-	private static class VarifyingBucketer implements BucketAssigner<String, String> {
+	@Test
+	public void testContextPassingNormalExecution() throws Exception {
+		testCorrectTimestampPassingInContext(1L, 2L, 3L);
+	}
+
+	@Test
+	public void testContextPassingNullTimestamp() throws Exception {
+		testCorrectTimestampPassingInContext(null, 2L, 3L);
+	}
+
+	private void testCorrectTimestampPassingInContext(Long timestamp, long watermark, long processingTime) throws Exception {
+		final File outDir = TEMP_FOLDER.newFolder();
+		final Path path = new Path(outDir.toURI());
+
+		final Buckets<String, String> buckets = new Buckets<>(
+				path,
+				new VerifyingBucketAssigner(timestamp, watermark, processingTime),
+				new DefaultBucketFactoryImpl<>(),
+				new RowWisePartWriter.Factory<>(new SimpleStringEncoder<>()),
+				DefaultRollingPolicy.create().build(),
+				2
+		);
+
+		buckets.onElement(
+				"test",
+				new TestUtils.MockSinkContext(
+						timestamp,
+						watermark,
+						processingTime)
+		);
+	}
+
+	private static class VerifyingBucketAssigner implements BucketAssigner<String, String> {
 
 		private static final long serialVersionUID = 7729086510972377578L;
 
@@ -87,7 +337,7 @@ public Long timestamp() {
 		private final long expectedWatermark;
 		private final long expectedProcessingTime;
 
-		VarifyingBucketer(
+		VerifyingBucketAssigner(
 				final Long expectedTimestamp,
 				final long expectedWatermark,
 				final long expectedProcessingTime
@@ -98,7 +348,7 @@ public Long timestamp() {
 		}
 
 		@Override
-		public String getBucketId(String element, Context context) {
+		public String getBucketId(String element, BucketAssigner.Context context) {
 			final Long elementTimestamp = context.timestamp();
 			final long watermark = context.currentWatermark();
 			final long processingTime = context.currentProcessingTime();
@@ -115,4 +365,35 @@ public String getBucketId(String element, Context context) {
 			return SimpleVersionedStringSerializer.INSTANCE;
 		}
 	}
+
+	// ------------------------------- Utility Methods --------------------------------
+
+	private static Buckets<String, String> createBuckets(
+			final Path basePath,
+			final RollingPolicy<String, String> rollingPolicy,
+			final int subtaskIdx
+	) throws IOException {
+
+		return new Buckets<>(
+				basePath,
+				new TestUtils.StringIdentityBucketAssigner(),
+				new DefaultBucketFactoryImpl<>(),
+				new RowWisePartWriter.Factory<>(new SimpleStringEncoder<>()),
+				rollingPolicy,
+				subtaskIdx
+		);
+	}
+
+	private static Buckets<String, String> restoreBuckets(
+			final Path basePath,
+			final RollingPolicy<String, String> rollingPolicy,
+			final int subtaskIdx,
+			final ListState<byte[]> bucketState,
+			final ListState<Long> partCounterState
+	) throws Exception {
+
+		final Buckets<String, String> restoredBuckets = createBuckets(basePath, rollingPolicy, subtaskIdx);
+		restoredBuckets.initializeState(bucketState, partCounterState);
+		return restoredBuckets;
+	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/LocalStreamingFileSinkTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/LocalStreamingFileSinkTest.java
index 8bb35ff244a..9c3f946829b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/LocalStreamingFileSinkTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/LocalStreamingFileSinkTest.java
@@ -18,12 +18,8 @@
 
 package org.apache.flink.streaming.api.functions.sink.filesystem;
 
-import org.apache.flink.api.common.serialization.SimpleStringEncoder;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.fs.RecoverableWriter;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
-import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.DefaultRollingPolicy;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
@@ -35,7 +31,6 @@
 import org.junit.rules.TemporaryFolder;
 
 import java.io.File;
-import java.io.IOException;
 import java.util.Map;
 
 /**
@@ -494,131 +489,4 @@ public void testScalingDownAndMergingOfStates() throws Exception {
 			Assert.assertEquals(3L, counter);
 		}
 	}
-
-	@Test
-	public void testMaxCounterUponRecovery() throws Exception {
-		final File outDir = TEMP_FOLDER.newFolder();
-
-		OperatorSubtaskState mergedSnapshot;
-
-		final TestBucketFactoryImpl first = new TestBucketFactoryImpl();
-		final TestBucketFactoryImpl second = new TestBucketFactoryImpl();
-
-		final RollingPolicy<Tuple2<String, Integer>, String> rollingPolicy = DefaultRollingPolicy
-				.create()
-				.withMaxPartSize(2L)
-				.withRolloverInterval(100L)
-				.withInactivityInterval(100L)
-				.build();
-
-		try (
-				OneInputStreamOperatorTestHarness<Tuple2<String, Integer>, Object> testHarness1 = TestUtils.createCustomRescalingTestSink(
-						outDir, 2, 0, 10L, new TestUtils.TupleToStringBucketer(), new SimpleStringEncoder<>(), rollingPolicy, first);
-				OneInputStreamOperatorTestHarness<Tuple2<String, Integer>, Object> testHarness2 = TestUtils.createCustomRescalingTestSink(
-						outDir, 2, 1, 10L, new TestUtils.TupleToStringBucketer(), new SimpleStringEncoder<>(), rollingPolicy, second)
-		) {
-			testHarness1.setup();
-			testHarness1.open();
-
-			testHarness2.setup();
-			testHarness2.open();
-
-			// we only put elements in one task.
-			testHarness1.processElement(new StreamRecord<>(Tuple2.of("test1", 0), 0L));
-			testHarness1.processElement(new StreamRecord<>(Tuple2.of("test1", 0), 0L));
-			testHarness1.processElement(new StreamRecord<>(Tuple2.of("test1", 0), 0L));
-			TestUtils.checkLocalFs(outDir, 3, 0);
-
-			// intentionally we snapshot them in the reverse order so that the states are shuffled
-			mergedSnapshot = AbstractStreamOperatorTestHarness.repackageState(
-					testHarness2.snapshot(0L, 0L),
-					testHarness1.snapshot(0L, 0L)
-			);
-		}
-
-		final TestBucketFactoryImpl firstRecovered = new TestBucketFactoryImpl();
-		final TestBucketFactoryImpl secondRecovered = new TestBucketFactoryImpl();
-
-		try (
-				OneInputStreamOperatorTestHarness<Tuple2<String, Integer>, Object> testHarness1 = TestUtils.createCustomRescalingTestSink(
-						outDir, 2, 0, 10L, new TestUtils.TupleToStringBucketer(), new SimpleStringEncoder<>(), rollingPolicy, firstRecovered);
-				OneInputStreamOperatorTestHarness<Tuple2<String, Integer>, Object> testHarness2 = TestUtils.createCustomRescalingTestSink(
-						outDir, 2, 1, 10L, new TestUtils.TupleToStringBucketer(), new SimpleStringEncoder<>(), rollingPolicy, secondRecovered)
-		) {
-			testHarness1.setup();
-			testHarness1.initializeState(mergedSnapshot);
-			testHarness1.open();
-
-			// we have to send an element so that the factory updates its counter.
-			testHarness1.processElement(new StreamRecord<>(Tuple2.of("test4", 0), 0L));
-
-			Assert.assertEquals(3L, firstRecovered.getInitialCounter());
-			TestUtils.checkLocalFs(outDir, 1, 3);
-
-			testHarness2.setup();
-			testHarness2.initializeState(mergedSnapshot);
-			testHarness2.open();
-
-			// we have to send an element so that the factory updates its counter.
-			testHarness2.processElement(new StreamRecord<>(Tuple2.of("test2", 0), 0L));
-
-			Assert.assertEquals(3L, secondRecovered.getInitialCounter());
-			TestUtils.checkLocalFs(outDir, 2, 3);
-		}
-	}
-
-	//////////////////////			Helper Methods			//////////////////////
-
-	static class TestBucketFactoryImpl extends DefaultBucketFactoryImpl<Tuple2<String, Integer>, String> {
-
-		private static final long serialVersionUID = 2794824980604027930L;
-
-		private long initialCounter = -1L;
-
-		@Override
-		public Bucket<Tuple2<String, Integer>, String> getNewBucket(
-				final RecoverableWriter fsWriter,
-				final int subtaskIndex,
-				final String bucketId,
-				final Path bucketPath,
-				final long initialPartCounter,
-				final PartFileWriter.PartFileFactory<Tuple2<String, Integer>, String> partFileWriterFactory,
-				final RollingPolicy<Tuple2<String, Integer>, String> rollingPolicy) {
-
-			this.initialCounter = initialPartCounter;
-
-			return super.getNewBucket(
-					fsWriter,
-					subtaskIndex,
-					bucketId,
-					bucketPath,
-					initialPartCounter,
-					partFileWriterFactory,
-					rollingPolicy);
-		}
-
-		@Override
-		public Bucket<Tuple2<String, Integer>, String> restoreBucket(
-				final RecoverableWriter fsWriter,
-				final int subtaskIndex,
-				final long initialPartCounter,
-				final PartFileWriter.PartFileFactory<Tuple2<String, Integer>, String> partFileWriterFactory,
-				final RollingPolicy<Tuple2<String, Integer>, String> rollingPolicy,
-				final BucketState<String> bucketState) throws IOException {
-
-			this.initialCounter = initialPartCounter;
-
-			return super.restoreBucket(
-					fsWriter,
-					subtaskIndex,
-					initialPartCounter,
-					partFileWriterFactory,
-					rollingPolicy,
-					bucketState);
-		}
-
-		public long getInitialCounter() {
-			return initialCounter;
-		}
-	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/RollingPolicyTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/RollingPolicyTest.java
index 851b6825d9a..1f4c7e5b20a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/RollingPolicyTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/RollingPolicyTest.java
@@ -19,13 +19,11 @@
 package org.apache.flink.streaming.api.functions.sink.filesystem;
 
 import org.apache.flink.api.common.serialization.SimpleStringEncoder;
-import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.Path;
 import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.DefaultRollingPolicy;
 import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.util.Preconditions;
 
-import org.apache.commons.io.FileUtils;
 import org.junit.Assert;
 import org.junit.ClassRule;
 import org.junit.Test;
@@ -33,11 +31,6 @@
 
 import java.io.File;
 import java.io.IOException;
-import java.util.Objects;
-
-import static org.hamcrest.CoreMatchers.containsString;
-import static org.hamcrest.CoreMatchers.either;
-import static org.hamcrest.CoreMatchers.equalTo;
 
 /**
  * Tests for different {@link RollingPolicy rolling policies}.
@@ -50,233 +43,233 @@
 	@Test
 	public void testDefaultRollingPolicy() throws Exception {
 		final File outDir = TEMP_FOLDER.newFolder();
+		final Path path = new Path(outDir.toURI());
 
-		final RollingPolicy<Tuple2<String, Integer>, String> rollingPolicy = DefaultRollingPolicy
-				.create()
-				.withMaxPartSize(10L)
-				.withInactivityInterval(4L)
-				.withRolloverInterval(11L)
-				.build();
-
-		try (
-				OneInputStreamOperatorTestHarness<Tuple2<String, Integer>, Object> testHarness = TestUtils.createCustomRescalingTestSink(
-						outDir,
-						1,
-						0,
-						1L,
-						new TestUtils.TupleToStringBucketer(),
-						new SimpleStringEncoder<>(),
-						rollingPolicy,
-						new DefaultBucketFactoryImpl<>())
-		) {
-			testHarness.setup();
-			testHarness.open();
-
-			testHarness.setProcessingTime(0L);
+		final RollingPolicy<String, String> originalRollingPolicy =
+				DefaultRollingPolicy
+						.create()
+						.withMaxPartSize(10L)
+						.withInactivityInterval(4L)
+						.withRolloverInterval(11L)
+						.build();
 
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 1), 1L));
-			TestUtils.checkLocalFs(outDir, 1, 0);
+		final MethodCallCountingPolicyWrapper<String, String> rollingPolicy =
+				new MethodCallCountingPolicyWrapper<>(originalRollingPolicy);
 
-			// roll due to size
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 2), 2L));
-			TestUtils.checkLocalFs(outDir, 1, 0);
+		final Buckets<String, String> buckets = createBuckets(path, rollingPolicy);
 
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 3), 3L));
-			TestUtils.checkLocalFs(outDir, 2, 0);
+		rollingPolicy.verifyCallCounters(0L, 0L, 0L, 0L, 0L, 0L);
 
-			// roll due to inactivity
-			testHarness.setProcessingTime(7L);
+		// these two will fill up the first in-progress file and at the third it will roll ...
+		buckets.onElement("test1", new TestUtils.MockSinkContext(1L, 1L, 1L));
+		buckets.onElement("test1", new TestUtils.MockSinkContext(2L, 1L, 2L));
+		rollingPolicy.verifyCallCounters(0L, 0L, 1L, 0L, 0L, 0L);
 
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 4), 4L));
-			TestUtils.checkLocalFs(outDir, 3, 0);
+		buckets.onElement("test1", new TestUtils.MockSinkContext(3L, 1L, 3L));
+		rollingPolicy.verifyCallCounters(0L, 0L, 2L, 1L, 0L, 0L);
 
-			// roll due to rollover interval
-			testHarness.setProcessingTime(20L);
+		// still no time to roll
+		buckets.onProcessingTime(5L);
+		rollingPolicy.verifyCallCounters(0L, 0L, 2L, 1L, 1L, 0L);
 
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 5), 5L));
-			TestUtils.checkLocalFs(outDir, 4, 0);
+		// roll due to inactivity
+		buckets.onProcessingTime(7L);
+		rollingPolicy.verifyCallCounters(0L, 0L, 2L, 1L, 2L, 1L);
 
-			// we take a checkpoint but we should not roll.
-			testHarness.snapshot(1L, 1L);
+		buckets.onElement("test1", new TestUtils.MockSinkContext(3L, 1L, 3L));
 
-			TestUtils.checkLocalFs(outDir, 4, 0);
+		// roll due to rollover interval
+		buckets.onProcessingTime(20L);
+		rollingPolicy.verifyCallCounters(0L, 0L, 2L, 1L, 3L, 2L);
 
-			// acknowledge the checkpoint, so publish the 3 closed files, but not the open one.
-			testHarness.notifyOfCompletedCheckpoint(1L);
-			TestUtils.checkLocalFs(outDir, 1, 3);
-		}
+		// we take a checkpoint but we should not roll.
+		buckets.snapshotState(1L, new TestUtils.MockListState<>(), new TestUtils.MockListState<>());
+		rollingPolicy.verifyCallCounters(0L, 0L, 2L, 1L, 3L, 2L);
 	}
 
 	@Test
 	public void testRollOnCheckpointPolicy() throws Exception {
 		final File outDir = TEMP_FOLDER.newFolder();
+		final Path path = new Path(outDir.toURI());
 
-		final RollingPolicy<Tuple2<String, Integer>, String> rollingPolicy = OnCheckpointRollingPolicy.build();
-
-		try (
-				OneInputStreamOperatorTestHarness<Tuple2<String, Integer>, Object> testHarness = TestUtils.createCustomRescalingTestSink(
-						outDir,
-						1,
-						0,
-						10L,
-						new TestUtils.TupleToStringBucketer(),
-						new SimpleStringEncoder<>(),
-						rollingPolicy,
-						new DefaultBucketFactoryImpl<>())
-		) {
-			testHarness.setup();
-			testHarness.open();
+		final MethodCallCountingPolicyWrapper<String, String> rollingPolicy =
+				new MethodCallCountingPolicyWrapper<>(OnCheckpointRollingPolicy.build());
 
-			testHarness.setProcessingTime(0L);
+		final Buckets<String, String> buckets = createBuckets(path, rollingPolicy);
 
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test2", 1), 1L));
+		rollingPolicy.verifyCallCounters(0L, 0L, 0L, 0L, 0L, 0L);
 
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 1), 1L));
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 2), 2L));
-			TestUtils.checkLocalFs(outDir, 2, 0);
+		buckets.onElement("test1", new TestUtils.MockSinkContext(1L, 1L, 2L));
+		buckets.onElement("test1", new TestUtils.MockSinkContext(2L, 1L, 2L));
+		buckets.onElement("test1", new TestUtils.MockSinkContext(3L, 1L, 3L));
 
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 3), 3L));
-			TestUtils.checkLocalFs(outDir, 2, 0);
+		// ... we have a checkpoint so we roll ...
+		buckets.snapshotState(1L, new TestUtils.MockListState<>(), new TestUtils.MockListState<>());
+		rollingPolicy.verifyCallCounters(1L, 1L, 2L, 0L, 0L, 0L);
 
-			// we take a checkpoint so we roll.
-			testHarness.snapshot(1L, 1L);
+		// ... create a new in-progress file (before we had closed the last one so it was null)...
+		buckets.onElement("test1", new TestUtils.MockSinkContext(5L, 1L, 5L));
 
-			for (File file: FileUtils.listFiles(outDir, null, true)) {
-				if (Objects.equals(file.getParentFile().getName(), "test1")) {
-					Assert.assertTrue(file.getName().contains(".part-0-1.inprogress."));
-				} else if (Objects.equals(file.getParentFile().getName(), "test2")) {
-					Assert.assertTrue(file.getName().contains(".part-0-0.inprogress."));
-				}
-			}
+		// ... we have a checkpoint so we roll ...
+		buckets.snapshotState(2L, new TestUtils.MockListState<>(), new TestUtils.MockListState<>());
+		rollingPolicy.verifyCallCounters(2L, 2L, 2L, 0L, 0L, 0L);
 
-			// this will create a new part file
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 4), 4L));
-			TestUtils.checkLocalFs(outDir, 3, 0);
-
-			testHarness.notifyOfCompletedCheckpoint(1L);
-			for (File file: FileUtils.listFiles(outDir, null, true)) {
-				if (Objects.equals(file.getParentFile().getName(), "test1")) {
-					Assert.assertTrue(
-							file.getName().contains(".part-0-2.inprogress.") || file.getName().equals("part-0-1")
-					);
-				} else if (Objects.equals(file.getParentFile().getName(), "test2")) {
-					Assert.assertEquals("part-0-0", file.getName());
-				}
-			}
-
-			// and open and fill .part-0-2.inprogress
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 5), 5L));
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 6), 6L));
-			TestUtils.checkLocalFs(outDir, 1, 2);
-
-			// we take a checkpoint so we roll.
-			testHarness.snapshot(2L, 2L);
-
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test2", 7), 7L));
-			TestUtils.checkLocalFs(outDir, 2, 2);
-
-			for (File file: FileUtils.listFiles(outDir, null, true)) {
-				if (Objects.equals(file.getParentFile().getName(), "test1")) {
-					Assert.assertThat(
-							file.getName(),
-							either(containsString(".part-0-2.inprogress."))
-									.or(equalTo("part-0-1"))
-					);
-				} else if (Objects.equals(file.getParentFile().getName(), "test2")) {
-					Assert.assertThat(
-							file.getName(),
-							either(containsString(".part-0-3.inprogress."))
-									.or(equalTo("part-0-0"))
-					);
-				}
-			}
-
-			// we acknowledge the last checkpoint so we should publish all but the latest in-progress file
-			testHarness.notifyOfCompletedCheckpoint(2L);
-
-			TestUtils.checkLocalFs(outDir, 1, 3);
-			for (File file: FileUtils.listFiles(outDir, null, true)) {
-				if (Objects.equals(file.getParentFile().getName(), "test1")) {
-					Assert.assertThat(
-							file.getName(),
-							either(equalTo("part-0-2")).or(equalTo("part-0-1"))
-					);
-				} else if (Objects.equals(file.getParentFile().getName(), "test2")) {
-					Assert.assertThat(
-							file.getName(),
-							either(containsString(".part-0-3.inprogress."))
-									.or(equalTo("part-0-0"))
-					);
-				}
-			}
-		}
+		buckets.close();
 	}
 
 	@Test
 	public void testCustomRollingPolicy() throws Exception {
 		final File outDir = TEMP_FOLDER.newFolder();
+		final Path path = new Path(outDir.toURI());
 
-		final RollingPolicy<Tuple2<String, Integer>, String> rollingPolicy = new RollingPolicy<Tuple2<String, Integer>, String>() {
+		final MethodCallCountingPolicyWrapper<String, String> rollingPolicy = new MethodCallCountingPolicyWrapper<>(
+				new RollingPolicy<String, String>() {
 
-			private static final long serialVersionUID = 1L;
+					private static final long serialVersionUID = 1L;
 
-			@Override
-			public boolean shouldRollOnCheckpoint(PartFileInfo<String> partFileState) {
-				return true;
-			}
+					@Override
+					public boolean shouldRollOnCheckpoint(PartFileInfo<String> partFileState) {
+						return true;
+					}
 
-			@Override
-			public boolean shouldRollOnEvent(PartFileInfo<String> partFileState, Tuple2<String, Integer> element) throws IOException {
-				// this means that 2 elements will close the part file.
-				return partFileState.getSize() > 12L;
-			}
+					@Override
+					public boolean shouldRollOnEvent(PartFileInfo<String> partFileState, String element) throws IOException {
+						// this means that 2 elements will close the part file.
+						return partFileState.getSize() > 9L;
+					}
 
-			@Override
-			public boolean shouldRollOnProcessingTime(PartFileInfo<String> partFileState, long currentTime) {
-				return false;
-			}
-		};
-
-		try (
-				OneInputStreamOperatorTestHarness<Tuple2<String, Integer>, Object> testHarness = TestUtils.createCustomRescalingTestSink(
-						outDir,
-						1,
-						0,
-						10L,
-						new TestUtils.TupleToStringBucketer(),
-						new SimpleStringEncoder<>(),
-						rollingPolicy,
-						new DefaultBucketFactoryImpl<>())
-		) {
-			testHarness.setup();
-			testHarness.open();
+					@Override
+					public boolean shouldRollOnProcessingTime(PartFileInfo<String> partFileState, long currentTime) {
+						return currentTime - partFileState.getLastUpdateTime() >= 10L;
+					}
+				});
+
+		final Buckets<String, String> buckets = createBuckets(path, rollingPolicy);
 
-			testHarness.setProcessingTime(0L);
+		rollingPolicy.verifyCallCounters(0L, 0L, 0L, 0L, 0L, 0L);
 
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test2", 1), 1L));
+		// the following 2 elements will close a part file because of size...
+		buckets.onElement("test1", new TestUtils.MockSinkContext(1L, 1L, 2L));
+		buckets.onElement("test1", new TestUtils.MockSinkContext(2L, 1L, 2L));
 
-			// the following 2 elements will close a part file ...
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 1), 1L));
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 2), 2L));
+		// only one call because we have no open part file in the other incoming elements, so currentPartFile == null so we roll without checking the policy.
+		rollingPolicy.verifyCallCounters(0L, 0L, 1L, 0L, 0L, 0L);
 
-			// ... and this one will open a new ...
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 3), 2L));
-			TestUtils.checkLocalFs(outDir, 3, 0);
+		// ... and this one will trigger the roll and open a new part file...
+		buckets.onElement("test1", new TestUtils.MockSinkContext(2L, 1L, 2L));
+		rollingPolicy.verifyCallCounters(0L, 0L, 2L, 1L, 0L, 0L);
 
-			// ... and all open part files should close here.
-			testHarness.snapshot(1L, 1L);
+		// ... we have a checkpoint so we roll ...
+		buckets.snapshotState(1L, new TestUtils.MockListState<>(), new TestUtils.MockListState<>());
+		rollingPolicy.verifyCallCounters(1L, 1L, 2L, 1L, 0L, 0L);
 
-			// this will create and fill out a new part file
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 4), 4L));
-			testHarness.processElement(new StreamRecord<>(Tuple2.of("test1", 5), 5L));
-			TestUtils.checkLocalFs(outDir, 4, 0);
+		// ... create a new in-progress file (before we had closed the last one so it was null)...
+		buckets.onElement("test1", new TestUtils.MockSinkContext(2L, 1L, 5L));
 
-			// we take a checkpoint so we roll.
-			testHarness.snapshot(2L, 2L);
+		// ... last modification time is 5L, so now we DON'T roll but we check ...
+		buckets.onProcessingTime(12L);
+		rollingPolicy.verifyCallCounters(1L, 1L, 2L, 1L, 1L, 0L);
+
+		// ... last modification time is 5L, so now we roll
+		buckets.onProcessingTime(16L);
+		rollingPolicy.verifyCallCounters(1L, 1L, 2L, 1L, 2L, 1L);
+
+		buckets.close();
+	}
+
+	// ------------------------------- Utility Methods --------------------------------
+
+	private static Buckets<String, String> createBuckets(
+			final Path basePath,
+			final MethodCallCountingPolicyWrapper<String, String> rollingPolicyToTest
+	) throws IOException {
+
+		return new Buckets<>(
+				basePath,
+				new TestUtils.StringIdentityBucketAssigner(),
+				new DefaultBucketFactoryImpl<>(),
+				new RowWisePartWriter.Factory<>(new SimpleStringEncoder<>()),
+				rollingPolicyToTest,
+				0
+		);
+	}
 
-			// we acknowledge the first checkpoint so we should publish all but the latest in-progress file
-			testHarness.notifyOfCompletedCheckpoint(1L);
-			TestUtils.checkLocalFs(outDir, 1, 3);
+	/**
+	 * A wrapper of a {@link RollingPolicy} which counts how many times each method of the policy was called
+	 * and in how many of them it decided to roll.
+	 */
+	private static class MethodCallCountingPolicyWrapper<IN, BucketID> implements RollingPolicy<IN, BucketID> {
+
+		private static final long serialVersionUID = 1L;
+
+		private final RollingPolicy<IN, BucketID> originalPolicy;
+
+		private long onCheckpointCallCounter;
+		private long onCheckpointRollCounter;
+
+		private long onEventCallCounter;
+		private long onEventRollCounter;
+
+		private long onProcessingTimeCallCounter;
+		private long onProcessingTimeRollCounter;
+
+		MethodCallCountingPolicyWrapper(final RollingPolicy<IN, BucketID> policy) {
+			this.originalPolicy = Preconditions.checkNotNull(policy);
+
+			this.onCheckpointCallCounter = 0L;
+			this.onCheckpointRollCounter = 0L;
+
+			this.onEventCallCounter = 0L;
+			this.onEventRollCounter = 0L;
+
+			this.onProcessingTimeCallCounter = 0L;
+			this.onProcessingTimeRollCounter = 0L;
+		}
+
+		@Override
+		public boolean shouldRollOnCheckpoint(PartFileInfo<BucketID> partFileState) throws IOException {
+			final boolean shouldRoll = originalPolicy.shouldRollOnCheckpoint(partFileState);
+			this.onCheckpointCallCounter++;
+			if (shouldRoll) {
+				this.onCheckpointRollCounter++;
+			}
+			return shouldRoll;
+		}
+
+		@Override
+		public boolean shouldRollOnEvent(PartFileInfo<BucketID> partFileState, IN element) throws IOException {
+			final boolean shouldRoll = originalPolicy.shouldRollOnEvent(partFileState, element);
+			this.onEventCallCounter++;
+			if (shouldRoll) {
+				this.onEventRollCounter++;
+			}
+			return shouldRoll;
+		}
+
+		@Override
+		public boolean shouldRollOnProcessingTime(PartFileInfo<BucketID> partFileState, long currentTime) throws IOException {
+			final boolean shouldRoll = originalPolicy.shouldRollOnProcessingTime(partFileState, currentTime);
+			this.onProcessingTimeCallCounter++;
+			if (shouldRoll) {
+				this.onProcessingTimeRollCounter++;
+			}
+			return shouldRoll;
+		}
+
+		void verifyCallCounters(
+				final long onCheckpointCalls,
+				final long onCheckpointRolls,
+				final long onEventCalls,
+				final long onEventRolls,
+				final long onProcessingTimeCalls,
+				final long onProcessingTimeRolls
+		) {
+			Assert.assertEquals(onCheckpointCalls, onCheckpointCallCounter);
+			Assert.assertEquals(onCheckpointRolls, onCheckpointRollCounter);
+			Assert.assertEquals(onEventCalls, onEventCallCounter);
+			Assert.assertEquals(onEventRolls, onEventRollCounter);
+			Assert.assertEquals(onProcessingTimeCalls, onProcessingTimeCallCounter);
+			Assert.assertEquals(onProcessingTimeRolls, onProcessingTimeRollCounter);
 		}
 	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/TestUtils.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/TestUtils.java
index bfbc12043ee..9e33064329e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/TestUtils.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/filesystem/TestUtils.java
@@ -20,9 +20,11 @@
 
 import org.apache.flink.api.common.serialization.BulkWriter;
 import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.SimpleVersionedStringSerializer;
 import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.DefaultRollingPolicy;
 import org.apache.flink.streaming.api.operators.StreamSink;
@@ -31,11 +33,17 @@
 import org.apache.commons.io.FileUtils;
 import org.junit.Assert;
 
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
 import java.io.File;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -160,4 +168,109 @@ public String getBucketId(Tuple2<String, Integer> element, Context context) {
 			return SimpleVersionedStringSerializer.INSTANCE;
 		}
 	}
+
+	/**
+	 * A simple {@link BucketAssigner} that accepts {@code String}'s
+	 * and returns the element itself as the bucket id.
+	 */
+	static class StringIdentityBucketAssigner implements BucketAssigner<String, String> {
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public String getBucketId(String element, BucketAssigner.Context context) {
+			return element;
+		}
+
+		@Override
+		public SimpleVersionedSerializer<String> getSerializer() {
+			return SimpleVersionedStringSerializer.INSTANCE;
+		}
+	}
+
+	/**
+	 * A mock {@link SinkFunction.Context} to be used in the tests.
+	 */
+	static class MockSinkContext implements SinkFunction.Context {
+
+		@Nullable
+		private Long elementTimestamp;
+
+		private long watermark;
+
+		private long processingTime;
+
+		MockSinkContext(
+				@Nullable Long elementTimestamp,
+				long watermark,
+				long processingTime) {
+			this.elementTimestamp = elementTimestamp;
+			this.watermark = watermark;
+			this.processingTime = processingTime;
+		}
+
+		@Override
+		public long currentProcessingTime() {
+			return processingTime;
+		}
+
+		@Override
+		public long currentWatermark() {
+			return watermark;
+		}
+
+		@Nullable
+		@Override
+		public Long timestamp() {
+			return elementTimestamp;
+		}
+	}
+
+	/**
+	 * A mock {@link ListState} used for testing the snapshot/restore cycle of the sink.
+	 */
+	static class MockListState<T> implements ListState<T> {
+
+		private final List<T> backingList;
+
+		MockListState() {
+			this.backingList = new ArrayList<>();
+		}
+
+		public List<T> getBackingList() {
+			return backingList;
+		}
+
+		@Override
+		public void update(List<T> values) {
+			backingList.clear();
+			addAll(values);
+		}
+
+		@Override
+		public void addAll(List<T> values) {
+			backingList.addAll(values);
+		}
+
+		@Override
+		public Iterable<T> get() {
+			return new Iterable<T>() {
+
+				@Nonnull
+				@Override
+				public Iterator<T> iterator() {
+					return backingList.iterator();
+				}
+			};
+		}
+
+		@Override
+		public void add(T value) {
+			backingList.add(value);
+		}
+
+		@Override
+		public void clear() {
+			backingList.clear();
+		}
+	}
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message