11"""Utilities for Apache Arrow serialization."""
22
33import logging
4+ import threading
45import os
56from typing import Optional
67
@@ -22,23 +23,26 @@ def __init__(self, file_name: str):
2223 self .new_rows : list [dict ] = []
2324 self .schema : Optional [pa .Schema ] = None # haven't yet learned the schema
2425 self .writer : Optional [pa .RecordBatchStreamWriter ] = None
26+ self ._lock = threading .Condition () # Ensure only one thread writes at a time
2527
2628 def close (self ):
2729 """Close the stream and write the file as needed."""
28- self ._write ()
29- if self .writer :
30- self .writer .close ()
31- self .sink .close ()
30+ with self ._lock :
31+ self ._write ()
32+ if self .writer :
33+ self .writer .close ()
34+ self .sink .close ()
3235
3336 def set_schema (self , schema : pa .Schema ):
3437 """Set the schema for the file.
3538 Only needed for datasets where the schema cannot be learned from the first record written.
3639
3740 schema (pa.Schema): The schema to use.
3841 """
39- assert self .schema is None
40- self .schema = schema
41- self .writer = pa .ipc .new_stream (self .sink , schema )
42+ with self ._lock :
43+ assert self .schema is None
44+ self .schema = schema
45+ self .writer = pa .ipc .new_stream (self .sink , schema )
4246
4347 def _write (self ):
4448 """Write the new rows to the file."""
@@ -56,9 +60,10 @@ def add_row(self, row_dict: dict):
5660 """Add a row to the arrow file.
5761 The schema is learned automatically from the first row; all later rows must use that same schema.
5862 """
59- self .new_rows .append (row_dict )
60- if len (self .new_rows ) >= chunk_size :
61- self ._write ()
63+ with self ._lock :
64+ self .new_rows .append (row_dict )
65+ if len (self .new_rows ) >= chunk_size :
66+ self ._write ()
6267
6368
6469class FeatherWriter (ArrowWriter ):
0 commit comments