Skip to content

Commit 1ab9316

Browse files
authored
Merge pull request #37682 from stankiewicz/cherrypick-kafka-eos
Cherrypick kafka eos optimization
2 parents 34dfa90 + 15ed9fb commit 1ab9316

4 files changed

Lines changed: 66 additions & 3 deletions

File tree

sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,12 @@
5151
import org.apache.beam.sdk.transforms.PTransform;
5252
import org.apache.beam.sdk.transforms.ParDo;
5353
import org.apache.beam.sdk.transforms.SerializableFunction;
54+
import org.apache.beam.sdk.transforms.windowing.AfterFirst;
5455
import org.apache.beam.sdk.transforms.windowing.AfterPane;
56+
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
5557
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
5658
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
59+
import org.apache.beam.sdk.transforms.windowing.Trigger;
5760
import org.apache.beam.sdk.transforms.windowing.Window;
5861
import org.apache.beam.sdk.util.Preconditions;
5962
import org.apache.beam.sdk.values.KV;
@@ -84,6 +87,7 @@
8487
import org.checkerframework.checker.nullness.qual.Nullable;
8588
import org.joda.time.DateTimeUtils;
8689
import org.joda.time.DateTimeZone;
90+
import org.joda.time.Duration;
8791
import org.joda.time.format.DateTimeFormat;
8892
import org.slf4j.Logger;
8993
import org.slf4j.LoggerFactory;
@@ -159,7 +163,8 @@ static void ensureEOSSupport() {
159163
@Override
160164
public PCollection<Void> expand(PCollection<ProducerRecord<K, V>> input) {
161165
String topic = Preconditions.checkStateNotNull(spec.getTopic());
162-
166+
int numElements = spec.getEosTriggerNumElements();
167+
Duration timeout = spec.getEosTriggerTimeout();
163168
int numShards = spec.getNumShards();
164169
if (numShards <= 0) {
165170
try (Consumer<?, ?> consumer = openConsumer(spec)) {
@@ -172,17 +177,34 @@ public PCollection<Void> expand(PCollection<ProducerRecord<K, V>> input) {
172177
}
173178
}
174179
checkState(numShards > 0, "Could not set number of shards");
175-
180+
Trigger.OnceTrigger trigger = null;
181+
if (timeout != null) {
182+
trigger =
183+
AfterFirst.of(
184+
AfterPane.elementCountAtLeast(numElements),
185+
AfterProcessingTime.pastFirstElementInPane().plusDelayOf(timeout));
186+
} else {
187+
// No timeout configured: fall back to the default count-only trigger.
188+
trigger = AfterPane.elementCountAtLeast(numElements);
189+
}
176190
return input
177191
.apply(
178192
Window.<ProducerRecord<K, V>>into(new GlobalWindows()) // Everything into global window.
179-
.triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
193+
.triggering(Repeatedly.forever(trigger))
180194
.discardingFiredPanes())
181195
.apply(
182196
String.format("Shuffle across %d shards", numShards),
183197
ParDo.of(new Reshard<>(numShards)))
184198
.apply("Persist sharding", GroupByKey.create())
185199
.apply("Assign sequential ids", ParDo.of(new Sequencer<>()))
200+
// Reapply the windowing configuration as the continuation trigger doesn't maintain the
201+
// desired batching.
202+
.apply(
203+
"Windowing",
204+
Window.<KV<Integer, KV<Long, TimestampedValue<ProducerRecord<K, V>>>>>into(
205+
new GlobalWindows()) // Everything into global window.
206+
.triggering(Repeatedly.forever(trigger))
207+
.discardingFiredPanes())
186208
.apply("Persist ids", GroupByKey.create())
187209
.apply(
188210
String.format("Write to Kafka topic '%s'", spec.getTopic()),

sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ public static <K, V> WriteRecords<K, V> writeRecords() {
648648
return new AutoValue_KafkaIO_WriteRecords.Builder<K, V>()
649649
.setProducerConfig(WriteRecords.DEFAULT_PRODUCER_PROPERTIES)
650650
.setEOS(false)
651+
.setEosTriggerNumElements(1) // keep default numElements
652+
.setEosTriggerTimeout(null) // keep default trigger (timeout)
651653
.setNumShards(0)
652654
.setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN)
653655
.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
@@ -3185,6 +3187,10 @@ public abstract static class WriteRecords<K, V>
31853187
@Pure
31863188
public abstract boolean isEOS();
31873189

3190+
public abstract int getEosTriggerNumElements();
3191+
3192+
public abstract @Nullable Duration getEosTriggerTimeout();
3193+
31883194
@Pure
31893195
public abstract @Nullable String getSinkGroupId();
31903196

@@ -3221,6 +3227,10 @@ abstract Builder<K, V> setPublishTimestampFunction(
32213227

32223228
abstract Builder<K, V> setEOS(boolean eosEnabled);
32233229

3230+
abstract Builder<K, V> setEosTriggerNumElements(int numElements);
3231+
3232+
abstract Builder<K, V> setEosTriggerTimeout(@Nullable Duration timeout);
3233+
32243234
abstract Builder<K, V> setSinkGroupId(String sinkGroupId);
32253235

32263236
abstract Builder<K, V> setNumShards(int numShards);
@@ -3368,6 +3378,15 @@ public WriteRecords<K, V> withEOS(int numShards, String sinkGroupId) {
33683378
return toBuilder().setEOS(true).setNumShards(numShards).setSinkGroupId(sinkGroupId).build();
33693379
}
33703380

3381+
public WriteRecords<K, V> withEOSTriggerConfig(int numElements, Duration timeout) {
3382+
checkArgument(numElements >= 1, "numElements should be >= 1");
3383+
checkArgument(timeout != null, "timeout is required for exactly-once sink");
3384+
return toBuilder()
3385+
.setEosTriggerNumElements(numElements)
3386+
.setEosTriggerTimeout(timeout)
3387+
.build();
3388+
}
3389+
33713390
/**
33723391
* When exactly-once semantics are enabled (see {@link #withEOS(int, String)}), the sink needs
33733392
* to fetch previously stored state with Kafka topic. Fetching the metadata requires a consumer.
@@ -3653,6 +3672,19 @@ public Write<K, V> withEOS(int numShards, String sinkGroupId) {
36533672
return withWriteRecordsTransform(getWriteRecordsTransform().withEOS(numShards, sinkGroupId));
36543673
}
36553674

3675+
/**
3676+
* Set the element-count threshold and the timeout at which buffered messages are triggered.
3677+
*
3678+
* <p>This is only applicable when the write method is set to EOS.
3679+
*
3680+
* <p>A batch is triggered once {@code numElements} elements have accumulated or {@code timeout}
3681+
* has elapsed since the first element of the batch, whichever comes first; this then repeats.
3682+
*/
3683+
public Write<K, V> withEOSTriggerConfig(int numElements, Duration timeout) {
3684+
return withWriteRecordsTransform(
3685+
getWriteRecordsTransform().withEOSTriggerConfig(numElements, timeout));
3686+
}
3687+
36563688
/**
36573689
* Wrapper method over {@link WriteRecords#withConsumerFactoryFn(SerializableFunction)}, used to
36583690
* keep the compatibility with old API based on KV type of element.

sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,8 @@ static class KafkaIOWriteTranslator implements TransformPayloadTranslator<Write<
479479
.addNullableByteArrayField("producer_factory_fn")
480480
.addNullableByteArrayField("publish_timestamp_fn")
481481
.addBooleanField("eos")
482+
.addInt32Field("eos_trigger_num_elements")
483+
.addNullableInt64Field("eos_trigger_timeout_ms")
482484
.addInt32Field("num_shards")
483485
.addNullableStringField("sink_group_id")
484486
.addNullableByteArrayField("consumer_factory_fn")
@@ -547,6 +549,11 @@ public Row toConfigRow(Write<?, ?> transform) {
547549
}
548550

549551
fieldValues.put("eos", writeRecordsTransform.isEOS());
552+
org.joda.time.Duration eosTriggerTimeout = writeRecordsTransform.getEosTriggerTimeout();
553+
if (eosTriggerTimeout != null) {
554+
fieldValues.put("eos_trigger_timeout_ms", eosTriggerTimeout.getMillis());
555+
}
556+
fieldValues.put("eos_trigger_num_elements", writeRecordsTransform.getEosTriggerNumElements());
550557
fieldValues.put("num_shards", writeRecordsTransform.getNumShards());
551558

552559
if (writeRecordsTransform.getSinkGroupId() != null) {

sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ public class KafkaIOTranslationTest {
9494
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getValueSerializer", "value_serializer");
9595
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getPublishTimestampFunction", "publish_timestamp_fn");
9696
WRITE_TRANSFORM_SCHEMA_MAPPING.put("isEOS", "eos");
97+
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getEosTriggerTimeout", "eos_trigger_timeout_ms");
98+
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getEosTriggerNumElements", "eos_trigger_num_elements");
9799
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getSinkGroupId", "sink_group_id");
98100
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getNumShards", "num_shards");
99101
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getConsumerFactoryFn", "consumer_factory_fn");

0 commit comments

Comments
 (0)