Skip to content

Commit dbf9731

Browse files
committed
allow very large batch imports
1 parent 0a2244c commit dbf9731

3 files changed

Lines changed: 147 additions & 70 deletions

File tree

sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AddFiles.java

Lines changed: 128 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import static org.apache.beam.sdk.io.iceberg.AddFiles.ConvertToDataFile.ERRORS;
2222
import static org.apache.beam.sdk.metrics.Metrics.counter;
2323
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
24-
import static org.apache.beam.sdk.values.PCollection.IsBounded.BOUNDED;
2524
import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
2625
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
2726

@@ -38,14 +37,16 @@
3837
import java.util.List;
3938
import java.util.Map;
4039
import java.util.Objects;
40+
import java.util.UUID;
4141
import java.util.concurrent.Callable;
4242
import java.util.concurrent.ExecutionException;
4343
import java.util.concurrent.ExecutorService;
4444
import java.util.concurrent.Executors;
4545
import java.util.concurrent.Future;
46-
import java.util.concurrent.Semaphore;
4746
import java.util.stream.Collectors;
4847
import java.util.stream.Stream;
48+
import org.apache.beam.sdk.coders.KvCoder;
49+
import org.apache.beam.sdk.coders.VarIntCoder;
4950
import org.apache.beam.sdk.coders.VarLongCoder;
5051
import org.apache.beam.sdk.io.Compression;
5152
import org.apache.beam.sdk.io.FileSystems;
@@ -59,13 +60,13 @@
5960
import org.apache.beam.sdk.state.StateSpecs;
6061
import org.apache.beam.sdk.state.ValueState;
6162
import org.apache.beam.sdk.transforms.DoFn;
62-
import org.apache.beam.sdk.transforms.GroupByKey;
6363
import org.apache.beam.sdk.transforms.GroupIntoBatches;
6464
import org.apache.beam.sdk.transforms.PTransform;
6565
import org.apache.beam.sdk.transforms.ParDo;
6666
import org.apache.beam.sdk.transforms.WithKeys;
6767
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
6868
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
69+
import org.apache.beam.sdk.util.ShardedKey;
6970
import org.apache.beam.sdk.values.KV;
7071
import org.apache.beam.sdk.values.PCollection;
7172
import org.apache.beam.sdk.values.PCollectionRowTuple;
@@ -83,6 +84,10 @@
8384
import org.apache.iceberg.DataFile;
8485
import org.apache.iceberg.DataFiles;
8586
import org.apache.iceberg.FileFormat;
87+
import org.apache.iceberg.GenericManifestFile;
88+
import org.apache.iceberg.ManifestFile;
89+
import org.apache.iceberg.ManifestFiles;
90+
import org.apache.iceberg.ManifestWriter;
8691
import org.apache.iceberg.Metrics;
8792
import org.apache.iceberg.MetricsConfig;
8893
import org.apache.iceberg.PartitionField;
@@ -97,6 +102,7 @@
97102
import org.apache.iceberg.exceptions.NoSuchTableException;
98103
import org.apache.iceberg.io.CloseableIterable;
99104
import org.apache.iceberg.io.InputFile;
105+
import org.apache.iceberg.io.OutputFile;
100106
import org.apache.iceberg.mapping.MappingUtil;
101107
import org.apache.iceberg.mapping.NameMapping;
102108
import org.apache.iceberg.orc.OrcMetrics;
@@ -126,16 +132,20 @@ public class AddFiles extends PTransform<PCollection<String>, PCollectionRowTupl
126132
static final String OUTPUT_TAG = "snapshots";
127133
static final String ERROR_TAG = "errors";
128134
private static final Duration DEFAULT_TRIGGER_INTERVAL = Duration.standardMinutes(10);
129-
private static final Counter numFilesAdded = counter(AddFiles.class, "numFilesAdded");
135+
private static final Counter numManifestFilesAdded =
136+
counter(AddFiles.class, "numManifestFilesAdded");
137+
private static final Counter numDataFilesAdded = counter(AddFiles.class, "numDataFilesAdded");
130138
private static final Counter numErrorFiles = counter(AddFiles.class, "numErrorFiles");
131139
private static final Logger LOG = LoggerFactory.getLogger(AddFiles.class);
132-
private static final int DEFAULT_FILES_TRIGGER = 10_000;
140+
private static final int DEFAULT_DATAFILES_PER_MANIFEST = 10_000;
141+
private static final int DEFAULT_MAX_MANIFESTS_PER_SNAPSHOT = 100;
133142
static final Schema ERROR_SCHEMA =
134143
Schema.builder().addStringField("file").addStringField("error").build();
144+
private static final long MANIFEST_PREFIX = UUID.randomUUID().getMostSignificantBits();
135145
private final IcebergCatalogConfig catalogConfig;
136146
private final String tableIdentifier;
137147
private @Nullable Duration intervalTrigger;
138-
private @Nullable Integer numFilesTrigger;
148+
private final int manifestFileSize;
139149
private final @Nullable String locationPrefix;
140150
private final @Nullable List<String> partitionFields;
141151
private final @Nullable Map<String, String> tableProps;
@@ -146,30 +156,30 @@ public AddFiles(
146156
@Nullable String locationPrefix,
147157
@Nullable List<String> partitionFields,
148158
@Nullable Map<String, String> tableProps,
149-
@Nullable Integer numFilesTrigger,
159+
@Nullable Integer manifestFileSize,
150160
@Nullable Duration intervalTrigger) {
151161
this.catalogConfig = catalogConfig;
152162
this.tableIdentifier = tableIdentifier;
153163
this.partitionFields = partitionFields;
154164
this.tableProps = tableProps;
155165
this.intervalTrigger = intervalTrigger;
156-
this.numFilesTrigger = numFilesTrigger;
166+
this.manifestFileSize =
167+
manifestFileSize != null ? manifestFileSize : DEFAULT_DATAFILES_PER_MANIFEST;
157168
this.locationPrefix = locationPrefix;
158169
}
159170

160171
@Override
161172
public PCollectionRowTuple expand(PCollection<String> input) {
162173
if (input.isBounded().equals(UNBOUNDED)) {
163174
intervalTrigger = intervalTrigger != null ? intervalTrigger : DEFAULT_TRIGGER_INTERVAL;
164-
numFilesTrigger = numFilesTrigger != null ? numFilesTrigger : DEFAULT_FILES_TRIGGER;
165175
LOG.info(
166-
"AddFiles configured to commit after accumulating {} files, or after {} seconds.",
167-
numFilesTrigger,
176+
"AddFiles configured to generate a new manifest after accumulating {} files, or after {} seconds.",
177+
manifestFileSize,
168178
intervalTrigger.getStandardSeconds());
169179
} else {
170180
checkState(
171-
intervalTrigger == null && numFilesTrigger == null,
172-
"Specifying an interval trigger or file trigger is only supported for streaming pipelines.");
181+
intervalTrigger == null,
182+
"Specifying an interval trigger is only supported for streaming pipelines.");
173183
}
174184

175185
if (!Strings.isNullOrEmpty(locationPrefix)) {
@@ -188,32 +198,44 @@ public PCollectionRowTuple expand(PCollection<String> input) {
188198
partitionFields,
189199
tableProps))
190200
.withOutputTags(DATA_FILES, TupleTagList.of(ERRORS)));
191-
SchemaCoder<SerializableDataFile> sdfSchema;
201+
SchemaCoder<SerializableDataFile> sdfCoder;
192202
try {
193-
sdfSchema = SchemaRegistry.createDefault().getSchemaCoder(SerializableDataFile.class);
203+
sdfCoder = SchemaRegistry.createDefault().getSchemaCoder(SerializableDataFile.class);
194204
} catch (Exception e) {
195205
throw new RuntimeException(e);
196206
}
197207

198-
PCollection<KV<Void, SerializableDataFile>> keyedFiles =
208+
PCollection<KV<Integer, SerializableDataFile>> keyedFiles =
199209
dataFiles
200210
.get(DATA_FILES)
201-
.setCoder(sdfSchema)
202-
.apply("AddStaticKey", WithKeys.of((Void) null));
211+
.setCoder(sdfCoder)
212+
.apply("AddSpecIdKey", WithKeys.of(SerializableDataFile::getPartitionSpecId))
213+
.setCoder(KvCoder.of(VarIntCoder.of(), sdfCoder));
214+
215+
GroupIntoBatches<Integer, SerializableDataFile> batchDataFiles =
216+
GroupIntoBatches.ofSize(manifestFileSize);
217+
GroupIntoBatches<String, byte[]> batchManifestFiles =
218+
GroupIntoBatches.ofSize(DEFAULT_MAX_MANIFESTS_PER_SNAPSHOT);
219+
220+
if (keyedFiles.isBounded().equals(UNBOUNDED)) {
221+
batchDataFiles = batchDataFiles.withMaxBufferingDuration(checkStateNotNull(intervalTrigger));
222+
batchManifestFiles =
223+
batchManifestFiles.withMaxBufferingDuration(checkStateNotNull(intervalTrigger));
224+
}
225+
226+
PCollection<KV<ShardedKey<Integer>, Iterable<SerializableDataFile>>> groupedFiles =
227+
keyedFiles.apply("GroupDataFilesIntoBatches", batchDataFiles.withShardedKey());
203228

204-
PCollection<KV<Void, Iterable<SerializableDataFile>>> groupedFiles =
205-
keyedFiles.isBounded().equals(BOUNDED)
206-
? keyedFiles.apply(GroupByKey.create())
207-
: keyedFiles.apply(
208-
GroupIntoBatches.<Void, SerializableDataFile>ofSize(
209-
checkStateNotNull(numFilesTrigger))
210-
.withMaxBufferingDuration(checkStateNotNull(intervalTrigger)));
229+
PCollection<KV<String, byte[]>> manifests =
230+
groupedFiles.apply(
231+
"CreateManifests", ParDo.of(new CreateManifests(catalogConfig, tableIdentifier)));
211232

212233
PCollection<Row> snapshots =
213-
groupedFiles
234+
manifests
235+
.apply("GatherManifests", batchManifestFiles)
214236
.apply(
215-
"CommitFilesToIceberg",
216-
ParDo.of(new CommitFilesDoFn(catalogConfig, tableIdentifier)))
237+
"CommitManifests",
238+
ParDo.of(new CommitManifestFilesDoFn(catalogConfig, tableIdentifier)))
217239
.setRowSchema(SnapshotInfo.getSchema());
218240

219241
return PCollectionRowTuple.of(
@@ -253,9 +275,6 @@ static class ConvertToDataFile extends DoFn<String, SerializableDataFile> {
253275
private transient @MonotonicNonNull LinkedList<Future<ProcessResult>> activeTasks;
254276
private transient volatile @MonotonicNonNull Table table;
255277

256-
// Limit open readers to avoid blowing up memory on one worker
257-
private static final int MAX_READERS = 10;
258-
private static final Semaphore ACTIVE_READERS = new Semaphore(MAX_READERS);
259278
// Number of parallel threads processing incoming files
260279
private static final int THREAD_POOL_SIZE = 10;
261280
private static final int MAX_IN_FLIGHT_TASKS = 100;
@@ -489,7 +508,6 @@ private Callable<ProcessResult> createProcessTask(
489508
.withFileSizeInBytes(inputFile.getLength())
490509
.withPartitionPath(partitionPath)
491510
.build();
492-
493511
return new ProcessResult(
494512
SerializableDataFile.from(df, partitionPath), null, timestamp, window, paneInfo);
495513
};
@@ -658,7 +676,6 @@ static String getPartitionFromMetrics(Metrics metrics, InputFile inputFile, Tabl
658676
Map<Integer, Object> transformedValues = new HashMap<>();
659677

660678
// Do a one-time read of the file and compare all bucket-transformed columns
661-
ACTIVE_READERS.acquire();
662679
try (CloseableIterable<Record> reader = ReadUtils.createReader(inputFile, bucketCols)) {
663680
for (Record record : reader) {
664681
for (Map.Entry<Integer, PartitionField> entry : bucketPartitions.entrySet()) {
@@ -683,8 +700,6 @@ static String getPartitionFromMetrics(Metrics metrics, InputFile inputFile, Tabl
683700
}
684701
}
685702
}
686-
} finally {
687-
ACTIVE_READERS.release();
688703
}
689704

690705
for (Map.Entry<Integer, Object> partitionCol : transformedValues.entrySet()) {
@@ -695,6 +710,64 @@ static String getPartitionFromMetrics(Metrics metrics, InputFile inputFile, Tabl
695710
}
696711
}
697712

713+
/**
714+
* Writes batches of {@link SerializableDataFile}s (grouped by Partition Spec ID) into {@link
715+
* ManifestFile}s.
716+
*
717+
* <p>Returns the byte-encoded {@link ManifestFile}, to be reconstructed and committed by
718+
* downstream {@link CommitManifestFilesDoFn}.
719+
*/
720+
static class CreateManifests
721+
extends DoFn<KV<ShardedKey<Integer>, Iterable<SerializableDataFile>>, KV<String, byte[]>> {
722+
private final IcebergCatalogConfig catalogConfig;
723+
private final String identifier;
724+
private transient @MonotonicNonNull Table table;
725+
726+
public CreateManifests(IcebergCatalogConfig catalogConfig, String identifier) {
727+
this.catalogConfig = catalogConfig;
728+
this.identifier = identifier;
729+
}
730+
731+
@ProcessElement
732+
public void process(
733+
@Element KV<ShardedKey<Integer>, Iterable<SerializableDataFile>> batch,
734+
OutputReceiver<KV<String, byte[]>> output)
735+
throws IOException {
736+
if (!batch.getValue().iterator().hasNext()) {
737+
return;
738+
}
739+
if (table == null) {
740+
table = catalogConfig.catalog().loadTable(TableIdentifier.parse(identifier));
741+
}
742+
743+
PartitionSpec spec = checkStateNotNull(table.specs().get(batch.getKey().getKey()));
744+
745+
String manifestPath =
746+
String.format(
747+
"%s/metadata/%s-%s-m0.avro", table.location(), MANIFEST_PREFIX, UUID.randomUUID());
748+
OutputFile outputFile = table.io().newOutputFile(manifestPath);
749+
750+
int numDataFiles = 0;
751+
ManifestFile manifestFile;
752+
try (ManifestWriter<DataFile> writer = ManifestFiles.write(spec, outputFile)) {
753+
for (SerializableDataFile sdf : batch.getValue()) {
754+
DataFile df = sdf.createDataFile(table.specs());
755+
writer.add(df);
756+
numDataFiles++;
757+
}
758+
writer.close();
759+
manifestFile = writer.toManifestFile();
760+
761+
// Provide a non-null dummy Snapshot ID to avoid encoding/decoding Null exceptions.
762+
// The snapshot ID will be overwritten when the file is committed.
763+
((GenericManifestFile) manifestFile).set(6, -1L);
764+
}
765+
766+
output.output(KV.of(identifier, ManifestFiles.encode(manifestFile)));
767+
numDataFilesAdded.inc(numDataFiles);
768+
}
769+
}
770+
698771
/**
699772
* A stateful {@link DoFn} that commits batches of files to an Iceberg table.
700773
*
@@ -707,13 +780,13 @@ static String getPartitionFromMetrics(Metrics metrics, InputFile inputFile, Tabl
707780
* <li><b>Idempotency:</b> Prevents duplicate commits during bundle failures by calculating a
708781
* deterministic hash for the file set. This ID is stored in the Iceberg {@code Snapshot}
709782
* summary, under the key {@code "beam.add-files-commit-id"}. Before committing, the DoFn
710-
* travereses backwards through recent snapshots to check if the current batch's ID is
783+
* traverses backwards through recent snapshots to check if the current batch's ID is
711784
* already present.
712785
* </ul>
713786
*
714787
* <p>Outputs the resulting Iceberg {@link Snapshot} information.
715788
*/
716-
static class CommitFilesDoFn extends DoFn<KV<Void, Iterable<SerializableDataFile>>, Row> {
789+
static class CommitManifestFilesDoFn extends DoFn<KV<String, Iterable<byte[]>>, Row> {
717790
private final IcebergCatalogConfig catalogConfig;
718791
private final String identifier;
719792
private transient @MonotonicNonNull Table table = null;
@@ -723,17 +796,22 @@ static class CommitFilesDoFn extends DoFn<KV<Void, Iterable<SerializableDataFile
723796
private final StateSpec<ValueState<Long>> lastCommitTimestamp =
724797
StateSpecs.value(VarLongCoder.of());
725798

726-
public CommitFilesDoFn(IcebergCatalogConfig catalogConfig, String identifier) {
799+
public CommitManifestFilesDoFn(IcebergCatalogConfig catalogConfig, String identifier) {
727800
this.catalogConfig = catalogConfig;
728801
this.identifier = identifier;
729802
}
730803

731804
@ProcessElement
732805
public void process(
733-
@Element KV<Void, Iterable<SerializableDataFile>> files,
806+
@Element KV<String, Iterable<byte[]>> batch,
734807
@AlwaysFetched @StateId("lastCommitTimestamp") ValueState<Long> lastCommitTimestamp,
735-
OutputReceiver<Row> output) {
736-
String commitId = commitHash(files.getValue());
808+
OutputReceiver<Row> output)
809+
throws IOException {
810+
List<ManifestFile> manifests = new ArrayList<>();
811+
for (byte[] bytes : batch.getValue()) {
812+
manifests.add(ManifestFiles.decode(bytes));
813+
}
814+
String commitId = commitHash(manifests);
737815
if (table == null) {
738816
table = catalogConfig.catalog().loadTable(TableIdentifier.parse(identifier));
739817
}
@@ -743,30 +821,29 @@ public void process(
743821
return;
744822
}
745823

746-
int numFiles = 0;
824+
int numManifests = 0;
747825
AppendFiles appendFiles = table.newFastAppend();
748-
for (SerializableDataFile file : files.getValue()) {
749-
DataFile df = file.createDataFile(table.specs());
750-
appendFiles.appendFile(df);
751-
numFiles++;
826+
for (ManifestFile file : manifests) {
827+
appendFiles.appendManifest(file);
828+
numManifests++;
752829
}
753830
appendFiles.set(COMMIT_ID_KEY, commitId);
754-
LOG.info("Committing {} files, with commit ID: {}", numFiles, commitId);
831+
LOG.info("Committing {} files, with commit ID: {}", numManifests, commitId);
755832
appendFiles.commit();
756833

757834
Snapshot snapshot = table.currentSnapshot();
758835
output.output(SnapshotInfo.fromSnapshot(snapshot).toRow());
759836
lastCommitTimestamp.write(snapshot.timestampMillis());
760-
numFilesAdded.inc(numFiles);
837+
numManifestFilesAdded.inc(numManifests);
761838
}
762839

763-
private String commitHash(Iterable<SerializableDataFile> files) {
840+
private String commitHash(Iterable<ManifestFile> files) {
764841
Hasher hasher = Hashing.sha256().newHasher();
765842

766843
// Extract, sort, and hash to ensure deterministic output
767844
List<String> paths = new ArrayList<>();
768-
for (SerializableDataFile file : files) {
769-
paths.add(file.getPath());
845+
for (ManifestFile file : files) {
846+
paths.add(file.path());
770847
}
771848
Collections.sort(paths);
772849

0 commit comments

Comments (0)