Skip to content

Commit 630f32a

Browse files
authored
[Drain] Expose drain to dofn processElement and onTimer (#37825)
1 parent d6bc507 commit 630f32a

8 files changed

Lines changed: 124 additions & 2 deletions

File tree

runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,11 @@ public Instant timestamp(DoFn<InputT, OutputT> doFn) {
555555
return timestamp();
556556
}
557557

558+
@Override
559+
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
560+
return elem.causedByDrain();
561+
}
562+
558563
@Override
559564
public String timerId(DoFn<InputT, OutputT> doFn) {
560565
throw new UnsupportedOperationException(
@@ -831,6 +836,11 @@ public Instant timestamp(DoFn<InputT, OutputT> doFn) {
831836
return timestamp();
832837
}
833838

839+
@Override
840+
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
841+
return causedByDrain;
842+
}
843+
834844
@Override
835845
public String timerId(DoFn<InputT, OutputT> doFn) {
836846
return timerId;
@@ -1119,6 +1129,11 @@ public Instant timestamp(DoFn<InputT, OutputT> doFn) {
11191129
return timestamp;
11201130
}
11211131

1132+
@Override
1133+
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
1134+
throw new UnsupportedOperationException("CausedByDrain parameters are not supported.");
1135+
}
1136+
11221137
@Override
11231138
public String timerId(DoFn<InputT, OutputT> doFn) {
11241139
throw new UnsupportedOperationException("Timer parameters are not supported.");

sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod;
7777
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter;
7878
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases;
79+
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.CausedByDrainParameter;
7980
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter;
8081
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter;
8182
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OnTimerContextParameter;
@@ -126,6 +127,7 @@ class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory {
126127
public static final String ELEMENT_PARAMETER_METHOD = "element";
127128
public static final String SCHEMA_ELEMENT_PARAMETER_METHOD = "schemaElement";
128129
public static final String TIMESTAMP_PARAMETER_METHOD = "timestamp";
130+
public static final String CAUSED_BY_DRAIN_PARAMETER_METHOD = "causedByDrain";
129131
public static final String BUNDLE_FINALIZER_PARAMETER_METHOD = "bundleFinalizer";
130132
public static final String OUTPUT_ROW_RECEIVER_METHOD = "outputRowReceiver";
131133
public static final String TIME_DOMAIN_PARAMETER_METHOD = "timeDomain";
@@ -1100,6 +1102,15 @@ public StackManipulation dispatch(TimestampParameter p) {
11001102
TIMESTAMP_PARAMETER_METHOD, DoFn.class)));
11011103
}
11021104

1105+
@Override
1106+
public StackManipulation dispatch(CausedByDrainParameter p) {
1107+
return new StackManipulation.Compound(
1108+
pushDelegate,
1109+
MethodInvocation.invoke(
1110+
getExtraContextFactoryMethodDescription(
1111+
CAUSED_BY_DRAIN_PARAMETER_METHOD, DoFn.class)));
1112+
}
1113+
11031114
@Override
11041115
public StackManipulation dispatch(BundleFinalizerParameter p) {
11051116
return simpleExtraContextParameter(BUNDLE_FINALIZER_PARAMETER_METHOD);

sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
4242
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
4343
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
44+
import org.apache.beam.sdk.values.CausedByDrain;
4445
import org.apache.beam.sdk.values.Row;
4546
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
4647
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -217,6 +218,9 @@ interface ArgumentProvider<InputT, OutputT> {
217218
/** Provide a reference to the input element timestamp. */
218219
Instant timestamp(DoFn<InputT, OutputT> doFn);
219220

221+
/** Provide a reference to the caused by drain. */
222+
CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn);
223+
220224
/** Provide a reference to the time domain for a timer firing. */
221225
TimeDomain timeDomain(DoFn<InputT, OutputT> doFn);
222226

@@ -325,6 +329,12 @@ public Instant timestamp(DoFn<InputT, OutputT> doFn) {
325329
String.format("Timestamp unsupported in %s", getErrorContext()));
326330
}
327331

332+
@Override
333+
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
334+
throw new UnsupportedOperationException(
335+
String.format("CausedByDrain unsupported in %s", getErrorContext()));
336+
}
337+
328338
@Override
329339
public String timerId(DoFn<InputT, OutputT> doFn) {
330340
throw new UnsupportedOperationException(
@@ -514,6 +524,11 @@ public Instant timestamp(DoFn<InputT, OutputT> doFn) {
514524
return delegate.timestamp(doFn);
515525
}
516526

527+
@Override
528+
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
529+
return delegate.causedByDrain(doFn);
530+
}
531+
517532
@Override
518533
public TimeDomain timeDomain(DoFn<InputT, OutputT> doFn) {
519534
return delegate.timeDomain(doFn);

sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,8 @@ public <ResultT> ResultT match(Cases<ResultT> cases) {
342342
return cases.dispatch((TimerIdParameter) this);
343343
} else if (this instanceof BundleFinalizerParameter) {
344344
return cases.dispatch((BundleFinalizerParameter) this);
345+
} else if (this instanceof CausedByDrainParameter) {
346+
return cases.dispatch((CausedByDrainParameter) this);
345347
} else if (this instanceof KeyParameter) {
346348
return cases.dispatch((KeyParameter) this);
347349
} else {
@@ -400,6 +402,8 @@ public interface Cases<ResultT> {
400402

401403
ResultT dispatch(BundleFinalizerParameter p);
402404

405+
ResultT dispatch(CausedByDrainParameter p);
406+
403407
ResultT dispatch(KeyParameter p);
404408

405409
/** A base class for a visitor with a default method for cases it is not interested in. */
@@ -497,6 +501,11 @@ public ResultT dispatch(BundleFinalizerParameter p) {
497501
return dispatchDefault(p);
498502
}
499503

504+
@Override
505+
public ResultT dispatch(CausedByDrainParameter p) {
506+
return dispatchDefault(p);
507+
}
508+
500509
@Override
501510
public ResultT dispatch(StateParameter p) {
502511
return dispatchDefault(p);
@@ -552,6 +561,8 @@ public ResultT dispatch(KeyParameter p) {
552561
new AutoValue_DoFnSignature_Parameter_PipelineOptionsParameter();
553562
private static final BundleFinalizerParameter BUNDLE_FINALIZER_PARAMETER =
554563
new AutoValue_DoFnSignature_Parameter_BundleFinalizerParameter();
564+
private static final CausedByDrainParameter CAUSED_BY_DRAIN_PARAMETER =
565+
new AutoValue_DoFnSignature_Parameter_CausedByDrainParameter();
555566
private static final OnWindowExpirationContextParameter ON_WINDOW_EXPIRATION_CONTEXT_PARAMETER =
556567
new AutoValue_DoFnSignature_Parameter_OnWindowExpirationContextParameter();
557568

@@ -575,6 +586,11 @@ public static BundleFinalizerParameter bundleFinalizer() {
575586
return BUNDLE_FINALIZER_PARAMETER;
576587
}
577588

589+
/** Returns a {@link CausedByDrainParameter}. */
590+
public static CausedByDrainParameter causedByDrainParameter() {
591+
return CAUSED_BY_DRAIN_PARAMETER;
592+
}
593+
578594
public static ElementParameter elementParameter(TypeDescriptor<?> elementT) {
579595
return new AutoValue_DoFnSignature_Parameter_ElementParameter(elementT);
580596
}
@@ -727,6 +743,16 @@ public abstract static class BundleFinalizerParameter extends Parameter {
727743
BundleFinalizerParameter() {}
728744
}
729745

746+
/**
747+
* Descriptor for a {@link Parameter} of type {@link org.apache.beam.sdk.values.CausedByDrain}.
748+
*
749+
* <p>All such descriptors are equal.
750+
*/
751+
@AutoValue
752+
public abstract static class CausedByDrainParameter extends Parameter {
753+
CausedByDrainParameter() {}
754+
}
755+
730756
/**
731757
* Descriptor for a {@link Parameter} of type {@link DoFn.Element}.
732758
*

sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
9292
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
9393
import org.apache.beam.sdk.util.common.ReflectHelpers;
94+
import org.apache.beam.sdk.values.CausedByDrain;
9495
import org.apache.beam.sdk.values.KV;
9596
import org.apache.beam.sdk.values.PCollection;
9697
import org.apache.beam.sdk.values.Row;
@@ -139,6 +140,7 @@ private DoFnSignatures() {}
139140
Parameter.StateParameter.class,
140141
Parameter.SideInputParameter.class,
141142
Parameter.TimerFamilyParameter.class,
143+
Parameter.CausedByDrainParameter.class,
142144
Parameter.BundleFinalizerParameter.class);
143145

144146
private static final ImmutableList<Class<? extends Parameter>>
@@ -155,6 +157,7 @@ private DoFnSignatures() {}
155157
Parameter.RestrictionTrackerParameter.class,
156158
Parameter.WatermarkEstimatorParameter.class,
157159
Parameter.SideInputParameter.class,
160+
Parameter.CausedByDrainParameter.class,
158161
Parameter.BundleFinalizerParameter.class);
159162

160163
private static final ImmutableList<Class<? extends Parameter>> ALLOWED_SETUP_PARAMETERS =
@@ -185,6 +188,7 @@ private DoFnSignatures() {}
185188
Parameter.StateParameter.class,
186189
Parameter.TimerFamilyParameter.class,
187190
Parameter.TimerIdParameter.class,
191+
Parameter.CausedByDrainParameter.class,
188192
Parameter.KeyParameter.class);
189193

190194
private static final ImmutableList<Class<? extends Parameter>>
@@ -201,6 +205,7 @@ private DoFnSignatures() {}
201205
Parameter.StateParameter.class,
202206
Parameter.TimerFamilyParameter.class,
203207
Parameter.TimerIdParameter.class,
208+
Parameter.CausedByDrainParameter.class,
204209
Parameter.KeyParameter.class);
205210

206211
private static final Collection<Class<? extends Parameter>>
@@ -1357,6 +1362,11 @@ private static Parameter analyzeExtraParameter(
13571362
return Parameter.keyT(paramT);
13581363
} else if (rawType.equals(TimeDomain.class)) {
13591364
return Parameter.timeDomainParameter();
1365+
} else if (CausedByDrain.class.isAssignableFrom(rawType)) {
1366+
methodErrors.checkArgument(
1367+
rawType.equals(CausedByDrain.class),
1368+
"CausedByDrain argument must have type org.apache.beam.sdk.values.CausedByDrain.");
1369+
return Parameter.causedByDrainParameter();
13601370
} else if (hasAnnotation(DoFn.SideInput.class, param.getAnnotations())) {
13611371
String sideInputId = getSideInputId(param.getAnnotations());
13621372
paramErrors.checkArgument(

sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/SplittableParDoNaiveBounded.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,11 @@ public Instant timestamp(DoFn<InputT, OutputT> doFn) {
543543
return outerContext.timestamp();
544544
}
545545

546+
@Override
547+
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
548+
return outerContext.causedByDrain();
549+
}
550+
546551
@Override
547552
public String timerId(DoFn<InputT, OutputT> doFn) {
548553
throw new UnsupportedOperationException();

sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.apache.beam.sdk.transforms.Sum;
5757
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
5858
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter;
59+
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.CausedByDrainParameter;
5960
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter;
6061
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter;
6162
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OutputReceiverParameter;
@@ -78,6 +79,7 @@
7879
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
7980
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
8081
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
82+
import org.apache.beam.sdk.values.CausedByDrain;
8183
import org.apache.beam.sdk.values.KV;
8284
import org.apache.beam.sdk.values.Row;
8385
import org.apache.beam.sdk.values.TypeDescriptor;
@@ -130,10 +132,11 @@ public void process(
130132
PipelineOptions options,
131133
@SideInput("tag1") String input1,
132134
@SideInput("tag2") Integer input2,
133-
BundleFinalizer bundleFinalizer) {}
135+
BundleFinalizer bundleFinalizer,
136+
CausedByDrain causedByDrain) {}
134137
}.getClass());
135138

136-
assertThat(sig.processElement().extraParameters().size(), equalTo(9));
139+
assertThat(sig.processElement().extraParameters().size(), equalTo(10));
137140
assertThat(sig.processElement().extraParameters().get(0), instanceOf(ElementParameter.class));
138141
assertThat(sig.processElement().extraParameters().get(1), instanceOf(TimestampParameter.class));
139142
assertThat(sig.processElement().extraParameters().get(2), instanceOf(WindowParameter.class));
@@ -146,6 +149,8 @@ public void process(
146149
assertThat(sig.processElement().extraParameters().get(7), instanceOf(SideInputParameter.class));
147150
assertThat(
148151
sig.processElement().extraParameters().get(8), instanceOf(BundleFinalizerParameter.class));
152+
assertThat(
153+
sig.processElement().extraParameters().get(9), instanceOf(CausedByDrainParameter.class));
149154
}
150155

151156
@Test
@@ -585,6 +590,31 @@ public void onTimer(BoundedWindow w) {}
585590
instanceOf(WindowParameter.class));
586591
}
587592

593+
@Test
594+
public void testCausedByDrainOnTimer() throws Exception {
595+
final String timerId = "some-timer-id";
596+
final String timerDeclarationId = TimerDeclaration.PREFIX + timerId;
597+
598+
DoFnSignature sig =
599+
DoFnSignatures.getSignature(
600+
new DoFn<String, String>() {
601+
602+
@TimerId(timerId)
603+
private final TimerSpec myfield1 = TimerSpecs.timer(TimeDomain.EVENT_TIME);
604+
605+
@ProcessElement
606+
public void process(ProcessContext c) {}
607+
608+
@OnTimer(timerId)
609+
public void onTimer(CausedByDrain causedByDrain) {}
610+
}.getClass());
611+
612+
assertThat(sig.onTimerMethods().get(timerDeclarationId).extraParameters().size(), equalTo(1));
613+
assertThat(
614+
sig.onTimerMethods().get(timerDeclarationId).extraParameters().get(0),
615+
instanceOf(CausedByDrainParameter.class));
616+
}
617+
588618
@Test
589619
public void testAllParamsOnTimer() throws Exception {
590620
final String timerId = "some-timer-id";

sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,6 +1804,11 @@ public <T> void outputWindowedValue(
18041804
outputTo(consumer, WindowedValues.of(output, timestamp, windows, paneInfo));
18051805
}
18061806

1807+
@Override
1808+
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
1809+
return currentElement.causedByDrain();
1810+
}
1811+
18071812
@Override
18081813
public State state(String stateId, boolean alwaysFetched) {
18091814
StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId);
@@ -1946,6 +1951,11 @@ public <T> void outputWindowedValue(
19461951
public CausedByDrain causedByDrain() {
19471952
return currentElement.causedByDrain();
19481953
}
1954+
1955+
@Override
1956+
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
1957+
return currentElement.causedByDrain();
1958+
}
19491959
}
19501960

19511961
/** Provides base arguments for a {@link DoFnInvoker} for a non-window observing method. */

0 commit comments

Comments
 (0)