diff --git a/rust/lakesoul-io/src/lakesoul_writer.rs b/rust/lakesoul-io/src/lakesoul_writer.rs
index b98edfb8a..9d51dcf35 100644
--- a/rust/lakesoul-io/src/lakesoul_writer.rs
+++ b/rust/lakesoul-io/src/lakesoul_writer.rs
@@ -38,24 +38,10 @@ impl SyncSendableMutableLakeSoulWriter {
     pub fn try_new(config: LakeSoulIOConfig, runtime: Runtime) -> Result<Self> {
         let runtime = Arc::new(runtime);
         runtime.clone().block_on(async move {
-            // if aux sort cols exist, we need to adjust the schema of final writer
-            // to exclude all aux sort cols
-            let writer_schema: SchemaRef = if !config.aux_sort_cols.is_empty() {
-                let schema = config.target_schema.0.clone();
-                // O(nm), n = number of target schema fields, m = number of aux sort cols
-                let proj_indices = schema
-                    .fields
-                    .iter()
-                    .filter(|f| !config.aux_sort_cols.contains(f.name()))
-                    .map(|f| schema.index_of(f.name().as_str()).map_err(DataFusionError::ArrowError))
-                    .collect::<Result<Vec<usize>>>()?;
-                Arc::new(schema.project(proj_indices.borrow())?)
-            } else {
-                config.target_schema.0.clone()
-            };
+            let writer_config = config.clone();
             let mut config = config.clone();
-            let writer = Self::create_writer(writer_schema, writer_config).await?;
+            let writer = Self::create_writer(writer_config).await?;
             let schema = writer.schema();
 
             if let Some(mem_limit) = config.mem_limit() {
@@ -76,7 +62,7 @@ impl SyncSendableMutableLakeSoulWriter {
         })
     }
 
-    async fn create_writer(writer_schema: SchemaRef, config: LakeSoulIOConfig) -> Result<Box<dyn AsyncBatchWriter>> {
+    async fn create_writer(config: LakeSoulIOConfig) -> Result<Box<dyn AsyncBatchWriter>> {
         // if aux sort cols exist, we need to adjust the schema of final writer
         // to exclude all aux sort cols
         let writer_schema: SchemaRef = if !config.aux_sort_cols.is_empty() {
@@ -149,61 +135,60 @@ impl SyncSendableMutableLakeSoulWriter {
     #[async_recursion::async_recursion(?Send)]
    async fn write_batch_async(&mut self, record_batch: RecordBatch, do_spill: bool) -> Result<()> {
        debug!(record_batch_row=?record_batch.num_rows(), do_spill=?do_spill, "write_batch_async");
-            let schema = self.schema();
-            let config = self.config().clone();
-            if let Some(max_file_size) = self.config().max_file_size {
-                // if max_file_size is set, we need to split batch into multiple files
-                let in_progress_writer = match &mut self.in_progress {
-                    Some(writer) => writer,
-                    x =>
-                        x.insert(
-                            Arc::new(Mutex::new(
-                                Self::create_writer(schema, config).await?
-                            ))
-                        )
-                };
-                let mut guard = in_progress_writer.lock().await;
-
-                let batch_memory_size = get_batch_memory_size(&record_batch)? as u64;
-                let batch_rows = record_batch.num_rows() as u64;
-                // If would exceed max_file_size, split batch
-                if !do_spill && guard.buffered_size() + batch_memory_size > max_file_size {
-                    let to_write = (batch_rows * (max_file_size - guard.buffered_size())) / batch_memory_size;
-                    if to_write + 1 < batch_rows {
-                        let to_write = to_write as usize + 1;
-                        let a = record_batch.slice(0, to_write);
-                        let b = record_batch.slice(to_write, record_batch.num_rows() - to_write);
-                        drop(guard);
-                        self.write_batch_async(a, true).await?;
-                        return self.write_batch_async(b, false).await;
-                    }
-                }
-                let rb_schema = record_batch.schema();
-                guard.write_record_batch(record_batch).await.map_err(|e| DataFusionError::Internal(format!("err={}, config={:?}, batch_schema={:?}", e, self.config.clone(), rb_schema)))?;
-
-                if do_spill {
-                    dbg!(format!("spilling writer with size: {}", guard.buffered_size()));
+        let config = self.config().clone();
+        if let Some(max_file_size) = self.config().max_file_size {
+            // if max_file_size is set, we need to split batch into multiple files
+            let in_progress_writer = match &mut self.in_progress {
+                Some(writer) => writer,
+                x =>
+                    x.insert(
+                        Arc::new(Mutex::new(
+                            Self::create_writer(config).await?
+                        ))
+                    )
+            };
+            let mut guard = in_progress_writer.lock().await;
+
+            let batch_memory_size = get_batch_memory_size(&record_batch)? as u64;
+            let batch_rows = record_batch.num_rows() as u64;
+            // If would exceed max_file_size, split batch
+            if !do_spill && guard.buffered_size() + batch_memory_size > max_file_size {
+                let to_write = (batch_rows * (max_file_size - guard.buffered_size())) / batch_memory_size;
+                if to_write + 1 < batch_rows {
+                    let to_write = to_write as usize + 1;
+                    let a = record_batch.slice(0, to_write);
+                    let b = record_batch.slice(to_write, record_batch.num_rows() - to_write);
                    drop(guard);
-                    if let Some(writer) = self.in_progress.take() {
-                        let inner_writer = match Arc::try_unwrap(writer) {
-                            Ok(inner) => inner,
-                            Err(_) => {
-                                return Err(DataFusionError::Internal("Cannot get ownership of inner writer".to_string()))
-                            },
-                        };
-                        let writer = inner_writer.into_inner();
-                        let results = writer.flush_and_close().await.map_err(|e| DataFusionError::Internal(format!("err={}, config={:?}, batch_schema={:?}", e, self.config.clone(), rb_schema)))?;
-                        self.flush_results.extend(results);
-                    }
+                    self.write_batch_async(a, true).await?;
+                    return self.write_batch_async(b, false).await;
+                }
+            }
+            let rb_schema = record_batch.schema();
+            guard.write_record_batch(record_batch).await.map_err(|e| DataFusionError::Internal(format!("err={}, config={:?}, batch_schema={:?}", e, self.config.clone(), rb_schema)))?;
+
+            if do_spill {
+                dbg!(format!("spilling writer with size: {}", guard.buffered_size()));
+                drop(guard);
+                if let Some(writer) = self.in_progress.take() {
+                    let inner_writer = match Arc::try_unwrap(writer) {
+                        Ok(inner) => inner,
+                        Err(_) => {
+                            return Err(DataFusionError::Internal("Cannot get ownership of inner writer".to_string()))
+                        },
+                    };
+                    let writer = inner_writer.into_inner();
+                    let results = writer.flush_and_close().await.map_err(|e| DataFusionError::Internal(format!("err={}, config={:?}, batch_schema={:?}", e, self.config.clone(), rb_schema)))?;
+                    self.flush_results.extend(results);
                }
-                Ok(())
-            } else if let Some(inner_writer) = &self.in_progress {
-                let inner_writer = inner_writer.clone();
-                let mut writer = inner_writer.lock().await;
-                writer.write_record_batch(record_batch).await
-            } else {
-                Err(DataFusionError::Internal("Invalid state of inner writer".to_string()))
            }
+            Ok(())
+        } else if let Some(inner_writer) = &self.in_progress {
+            let inner_writer = inner_writer.clone();
+            let mut writer = inner_writer.lock().await;
+            writer.write_record_batch(record_batch).await
+        } else {
+            Err(DataFusionError::Internal("Invalid state of inner writer".to_string()))
+        }
    }
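
Note for reviewers: the first two hunks deduplicate the aux-sort-cols handling. `try_new` used to build the projected schema itself and pass it in; now `create_writer` derives it from the config alone. Below is a minimal standalone sketch of that projection step, assuming the `arrow-schema` crate; the field names and the `project_out_aux_cols` helper are illustrative, not the crate's API. Using `enumerate()` also sidesteps the O(nm) `index_of` lookups the removed copy's comment pointed out:

```rust
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaRef};

// Keep every target-schema field whose name is not an aux sort col,
// then project the schema down to those indices.
fn project_out_aux_cols(target: &SchemaRef, aux_sort_cols: &[String]) -> SchemaRef {
    let proj_indices: Vec<usize> = target
        .fields()
        .iter()
        .enumerate()
        .filter(|(_, f)| !aux_sort_cols.contains(f.name()))
        .map(|(i, _)| i)
        .collect();
    // project() only errors on out-of-range indices, which enumerate() rules out.
    Arc::new(target.project(&proj_indices).unwrap())
}

fn main() {
    let target: SchemaRef = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("value", DataType::Utf8, true),
        Field::new("__sort_key", DataType::Int64, false), // hypothetical aux sort col
    ]));
    let writer_schema = project_out_aux_cols(&target, &["__sort_key".to_string()]);
    assert_eq!(writer_schema.fields().len(), 2);
}
```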
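The batch-splitting arithmetic in `write_batch_async` is easy to misread, so here is the row-count computation in isolation: when the incoming batch would push the current file past `max_file_size`, the writer estimates how many rows still fit (assuming memory use is roughly uniform per row), writes that prefix with `do_spill = true` to close out the file, and recurses on the remainder. The `split_point` helper below is hypothetical, and the `saturating_sub` is a defensive tweak not present in the diff (the original `max_file_size - guard.buffered_size()` can underflow a `u64` if the buffer already exceeds the budget):

```rust
// All sizes are in bytes; returns the row index at which to split, if any.
fn split_point(batch_rows: u64, batch_memory_size: u64, buffered_size: u64, max_file_size: u64) -> Option<usize> {
    if buffered_size + batch_memory_size <= max_file_size {
        return None; // the whole batch fits into the current file
    }
    // Rows that still fit: batch_rows * remaining_budget / batch_size.
    let to_write = batch_rows * max_file_size.saturating_sub(buffered_size) / batch_memory_size;
    // +1 slightly overfills the current file rather than underfilling it; if
    // that would consume the whole batch anyway, splitting is pointless.
    if to_write + 1 < batch_rows {
        Some(to_write as usize + 1)
    } else {
        None
    }
}

fn main() {
    // A 1000-row, ~8 KiB batch against a file with 6 KiB of budget left:
    // roughly three quarters of the rows go into the current file.
    assert_eq!(split_point(1000, 8192, 2048, 8192), Some(751));
    // A batch that fits entirely is not split.
    assert_eq!(split_point(1000, 1024, 0, 8192), None);
}
```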
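The spill path's `Arc::try_unwrap` dance is the other subtle bit: `flush_and_close` consumes the writer by value, so the code must first prove the `Arc<Mutex<...>>` has no other owners before it can flush. A toy reproduction of that ownership transfer, with a stand-in `Writer` struct replacing the crate's writer trait object:

```rust
use std::sync::Arc;
use tokio::sync::Mutex;

struct Writer {
    buffered: u64,
}

impl Writer {
    // Consumes self, mirroring how flushing closes the writer for good.
    async fn flush_and_close(self) -> Vec<String> {
        vec![format!("flushed {} bytes", self.buffered)]
    }
}

#[tokio::main]
async fn main() {
    let mut in_progress = Some(Arc::new(Mutex::new(Writer { buffered: 42 })));
    if let Some(writer) = in_progress.take() {
        // Fails if another clone of the Arc is still alive, e.g. held by a
        // concurrent task - the same error path the diff reports as
        // "Cannot get ownership of inner writer".
        let inner = match Arc::try_unwrap(writer) {
            Ok(inner) => inner,
            Err(_) => panic!("Cannot get ownership of inner writer"),
        };
        // Mutex::into_inner needs the owned Mutex, which is why try_unwrap comes first.
        let results = inner.into_inner().flush_and_close().await;
        assert_eq!(results.len(), 1);
    }
}
```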