From 4ba74f79bd0409a0cc5ca9b8e1555db72f5ed417 Mon Sep 17 00:00:00 2001 From: Yuya Shiraki Date: Sun, 4 Sep 2022 01:45:58 -0700 Subject: [PATCH] split S3 files into smaller files to send large union file Summary: # Context We found that the AWS-SDK S3 API would fail when we try to write more than 5GB of data. It is blocking us from doing capacity testing for a larger FARGATE container. In this diff, as mentioned in [the post](https://fb.workplace.com/groups/pidmatchingxfn/posts/493743615908631), we are splitting the union file based on the number of rows. # Description We have made the following changes. - Added a new arg `s3api_max_rows` in the private-id-multi-key-client and private-id-multi-key-server binaries. We will use this to split a file for S3 upload. - Added an optional arg `num_split` in save_id_map() and writer_helper(). When `num_split` is specified, it uses the arg `path` as a prefix and saves files in `{path}_0`, `{path}_1`, etc. - In rpc_server.rs and client.rs, we calculate num_split based on s3api_max_rows, and pass the num_split arg for S3 and local filesystem outputs (GCS output is unchanged and passes `None`). Then, for each split file, it calls copy_from_local(). 
Differential Revision: D39219674 fbshipit-source-id: 48b9c060853a5cce4b8d8d6aa4d5fc2b744f208f --- .../src/rpc/private-id-multi-key/client.rs | 30 +++++++--- .../rpc/private-id-multi-key/rpc_server.rs | 26 +++++--- .../src/rpc/private-id-multi-key/server.rs | 8 +++ protocol/src/private_id_multi_key/company.rs | 14 +++-- protocol/src/private_id_multi_key/mod.rs | 59 ++++++++++++------- protocol/src/private_id_multi_key/partner.rs | 15 +++-- protocol/src/private_id_multi_key/traits.rs | 6 +- 7 files changed, 114 insertions(+), 44 deletions(-) diff --git a/protocol-rpc/src/rpc/private-id-multi-key/client.rs b/protocol-rpc/src/rpc/private-id-multi-key/client.rs index 5d611b2..a0f3ce0 100644 --- a/protocol-rpc/src/rpc/private-id-multi-key/client.rs +++ b/protocol-rpc/src/rpc/private-id-multi-key/client.rs @@ -100,6 +100,11 @@ async fn main() -> Result<(), Box> { .long("run_id") .default_value("") .help("A run_id used to identify all the logs in a PL/PA run."), + Arg::with_name("s3api_max_rows") + .long("s3api_max_rows") + .takes_value(true) + .default_value("5000000") + .help("Number of rows per each output S3 file to split."), ]) .groups(&[ ArgGroup::with_name("tls") @@ -114,6 +119,8 @@ async fn main() -> Result<(), Box> { let global_timer = timer::Timer::new_silent("global"); let input_path_str = matches.value_of("input").unwrap_or("input.csv"); let mut input_path = input_path_str.to_string(); + let s3api_max_rows_str = matches.value_of("s3api_max_rows").unwrap_or("5000000"); + let s3_api_max_rows: usize = s3api_max_rows_str.to_string().parse().unwrap(); if let Ok(s3_path) = S3Path::from_str(input_path_str) { info!( "Reading {} from S3 and copying to local path", @@ -358,27 +365,36 @@ async fn main() -> Result<(), Box> { let s3_tempfile = tempfile::NamedTempFile::new().unwrap(); let (_file, path) = s3_tempfile.keep().unwrap(); let path = path.to_str().expect("Failed to convert path to str"); + let num_split = ((partner_protocol.get_id_map_size() as f32) + / 
(s3_api_max_rows as f32)) + .ceil() as usize; partner_protocol - .save_id_map(&String::from(path)) + .save_id_map(&String::from(path), Some(num_split)) .expect("Failed to save id map to tempfile"); - output_path_s3 - .copy_from_local(&path) - .await - .expect("Failed to write to S3"); + for n in 0..num_split { + let chunk_path = format!("{}_{}", path, n); + output_path_s3 + .copy_from_local(&chunk_path) + .await + .expect("Failed to write to S3"); + } } else if let Ok(output_path_gcp) = GCSPath::from_str(p) { let gcs_tempfile = tempfile::NamedTempFile::new().unwrap(); let (_file, path) = gcs_tempfile.keep().unwrap(); let path = path.to_str().expect("Failed to convert path to str"); partner_protocol - .save_id_map(&String::from(path)) + .save_id_map(&String::from(path), None) .expect("Failed to save id map to tempfile"); output_path_gcp .copy_from_local(&path) .await .expect("Failed to write to GCS"); } else { + let num_split = ((partner_protocol.get_id_map_size() as f32) + / (s3_api_max_rows as f32)) + .ceil() as usize; partner_protocol - .save_id_map(&String::from(p)) + .save_id_map(&String::from(p), Some(num_split)) .expect("Failed to save id map to output file"); } } diff --git a/protocol-rpc/src/rpc/private-id-multi-key/rpc_server.rs b/protocol-rpc/src/rpc/private-id-multi-key/rpc_server.rs index fd2235a..6144f14 100644 --- a/protocol-rpc/src/rpc/private-id-multi-key/rpc_server.rs +++ b/protocol-rpc/src/rpc/private-id-multi-key/rpc_server.rs @@ -43,6 +43,7 @@ pub struct PrivateIdMultiKeyService { input_with_headers: bool, metrics_path: Option, metrics_obj: metrics::Metrics, + s3_api_max_rows: usize, pub killswitch: Arc, } @@ -52,6 +53,7 @@ impl PrivateIdMultiKeyService { output_path: Option<&str>, input_with_headers: bool, metrics_path: Option, + s3_api_max_rows: usize, ) -> PrivateIdMultiKeyService { PrivateIdMultiKeyService { protocol: CompanyPrivateIdMultiKey::new(), @@ -60,6 +62,7 @@ impl PrivateIdMultiKeyService { input_with_headers, metrics_path, 
metrics_obj: metrics::Metrics::new("private-id-multi-key".to_string()), + s3_api_max_rows, killswitch: Arc::new(AtomicBool::new(false)), } } @@ -298,26 +301,35 @@ impl PrivateIdMultiKey for PrivateIdMultiKeyService { let s3_tempfile = tempfile::NamedTempFile::new().unwrap(); let (_file, path) = s3_tempfile.keep().unwrap(); let path = path.to_str().expect("Failed to convert path to str"); + let num_split = ((self.protocol.get_id_map_size() as f32) + / (self.s3_api_max_rows as f32)) + .ceil() as usize; self.protocol - .save_id_map(&String::from(path)) + .save_id_map(&String::from(path), Some(num_split)) .expect("Failed to save id map to tempfile"); - output_path_s3 - .copy_from_local(&path) - .await - .expect("Failed to write to S3"); + for n in 0..num_split { + let chunk_path = format!("{}_{}", path, n); + output_path_s3 + .copy_from_local(&chunk_path) + .await + .expect("Failed to write to S3"); + } } else if let Ok(output_path_gcp) = GCSPath::from_str(p) { let gcs_tempfile = tempfile::NamedTempFile::new().unwrap(); let (_file, path) = gcs_tempfile.keep().unwrap(); let path = path.to_str().expect("Failed to convert path to str"); self.protocol - .save_id_map(&String::from(path)) + .save_id_map(&String::from(path), None) .expect("Failed to save id map to tempfile"); output_path_gcp .copy_from_local(&path) .await .expect("Failed to write to GCS"); } else { - self.protocol.save_id_map(p).unwrap(); + let num_split = ((self.protocol.get_id_map_size() as f32) + / (self.s3_api_max_rows as f32)) + .ceil() as usize; + self.protocol.save_id_map(p, Some(num_split)).unwrap(); } } None => self.protocol.print_id_map(), diff --git a/protocol-rpc/src/rpc/private-id-multi-key/server.rs b/protocol-rpc/src/rpc/private-id-multi-key/server.rs index ecd9bfd..a8b1a41 100644 --- a/protocol-rpc/src/rpc/private-id-multi-key/server.rs +++ b/protocol-rpc/src/rpc/private-id-multi-key/server.rs @@ -99,6 +99,11 @@ async fn main() -> Result<(), Box> { .long("run_id") .default_value("") .help("A 
run_id used to identify all the logs in a PL/PA run."), + Arg::with_name("s3api_max_rows") + .long("s3api_max_rows") + .takes_value(true) + .default_value("5000000") + .help("Number of rows per each output S3 file to split."), ]) .groups(&[ ArgGroup::with_name("tls") @@ -129,6 +134,8 @@ async fn main() -> Result<(), Box> { let input_with_headers = matches.is_present("input-with-headers"); let output_path = matches.value_of("output"); let metric_path = matches.value_of("metric-path"); + let s3api_max_rows_str = matches.value_of("s3api_max_rows").unwrap_or("5000000"); + let s3_api_max_rows: usize = s3api_max_rows_str.to_string().parse().unwrap(); let no_tls = matches.is_present("no-tls"); let host = matches.value_of("host"); @@ -167,6 +174,7 @@ async fn main() -> Result<(), Box> { output_path, input_with_headers, metrics_output_path, + s3_api_max_rows, ); let ks = service.killswitch.clone(); diff --git a/protocol/src/private_id_multi_key/company.rs b/protocol/src/private_id_multi_key/company.rs index c223ece..224467b 100644 --- a/protocol/src/private_id_multi_key/company.rs +++ b/protocol/src/private_id_multi_key/company.rs @@ -5,7 +5,6 @@ use std::collections::HashMap; use std::sync::Arc; use std::sync::RwLock; -use common::files; use common::permutations::gen_permute_pattern; use common::permutations::permute; use common::permutations::undo_permute; @@ -481,16 +480,16 @@ impl CompanyPrivateIdMultiKeyProtocol for CompanyPrivateIdMultiKey { fn print_id_map(&self) { match (self.plaintext.clone().read(), self.id_map.clone().read()) { (Ok(data), Ok(id_map)) => { - writer_helper(&data, &id_map, None); + writer_helper(&data, &id_map, None, None); } _ => panic!("Cannot print id_map"), } } - fn save_id_map(&self, path: &str) -> Result<(), ProtocolError> { + fn save_id_map(&self, path: &str, num_split: Option) -> Result<(), ProtocolError> { match (self.plaintext.clone().read(), self.id_map.clone().read()) { (Ok(data), Ok(id_map)) => { - writer_helper(&data, &id_map, 
Some(path.to_string())); + writer_helper(&data, &id_map, Some(path.to_string()), num_split); Ok(()) } _ => Err(ProtocolError::ErrorIO( @@ -498,6 +497,13 @@ impl CompanyPrivateIdMultiKeyProtocol for CompanyPrivateIdMultiKey { )), } } + + fn get_id_map_size(&self) -> usize { + match self.id_map.clone().read() { + Ok(id_map) => id_map.len(), + _ => panic!("Cannot get id_map size"), + } + } } #[cfg(test)] diff --git a/protocol/src/private_id_multi_key/mod.rs b/protocol/src/private_id_multi_key/mod.rs index 850e717..0820981 100644 --- a/protocol/src/private_id_multi_key/mod.rs +++ b/protocol/src/private_id_multi_key/mod.rs @@ -62,20 +62,41 @@ fn load_data(plaintext: Arc>>>, path: &str, input_with_he t.qps("text read", text_len); } -fn writer_helper(data: &[Vec], id_map: &[(String, usize, bool)], path: Option) { - let mut device = match path { - Some(path) => { - let wr = csv::WriterBuilder::new() - .flexible(true) - .buffer_capacity(1024) - .from_path(path) - .unwrap(); - Some(wr) - } - None => None, - }; +fn writer_helper( + data: &[Vec], + id_map: &[(String, usize, bool)], + path: Option, + num_split: Option, +) { + let mut device_list = Vec::new(); + let mut chunk_size = id_map.len(); + match path { + Some(path) => match num_split { + Some(num_split) => { + for n in 0..num_split { + let chunk_path = format!("{}_{}", path, n); + let wr = csv::WriterBuilder::new() + .flexible(true) + .buffer_capacity(1024) + .from_path(chunk_path) + .unwrap(); + device_list.push(wr); + chunk_size = ((id_map.len() as f32) / (num_split as f32)).ceil() as usize; + } + } + None => { + let wr = csv::WriterBuilder::new() + .flexible(true) + .buffer_capacity(1024) + .from_path(path) + .unwrap(); + device_list.push(wr); + } + }, + None => (), + } - for (key, idx, flag) in id_map.iter() { + for (pos, (key, idx, flag)) in id_map.iter().enumerate() { let mut v = vec![(*key).clone()]; match flag { @@ -83,13 +104,11 @@ fn writer_helper(data: &[Vec], id_map: &[(String, usize, bool)], path: O false 
=> v.push("NA".to_string()), } - match device { - Some(ref mut wr) => { - wr.write_record(v.as_slice()).unwrap(); - } - None => { - println!("{}", v.join(",")); - } + if device_list.is_empty() { + println!("{}", v.join(",")); + } else { + let device = &mut device_list[pos / chunk_size]; + device.write_record(v.as_slice()).unwrap(); } } } diff --git a/protocol/src/private_id_multi_key/partner.rs b/protocol/src/private_id_multi_key/partner.rs index f1ca921..04d1ae8 100644 --- a/protocol/src/private_id_multi_key/partner.rs +++ b/protocol/src/private_id_multi_key/partner.rs @@ -236,16 +236,16 @@ impl PartnerPrivateIdMultiKeyProtocol for PartnerPrivateIdMultiKey { fn print_id_map(&self) { match (self.plaintext.clone().read(), self.id_map.clone().read()) { (Ok(data), Ok(id_map)) => { - writer_helper(&data, &id_map, None); + writer_helper(&data, &id_map, None, None); } _ => panic!("Cannot print id_map"), } } - fn save_id_map(&self, path: &str) -> Result<(), ProtocolError> { + fn save_id_map(&self, path: &str, num_split: Option) -> Result<(), ProtocolError> { match (self.plaintext.clone().read(), self.id_map.clone().read()) { (Ok(data), Ok(id_map)) => { - writer_helper(&data, &id_map, Some(path.to_string())); + writer_helper(&data, &id_map, Some(path.to_string()), num_split); Ok(()) } _ => Err(ProtocolError::ErrorIO( @@ -253,6 +253,13 @@ impl PartnerPrivateIdMultiKeyProtocol for PartnerPrivateIdMultiKey { )), } } + + fn get_id_map_size(&self) -> usize { + match self.id_map.clone().read() { + Ok(id_map) => id_map.len(), + _ => panic!("Cannot get id_map size"), + } + } } #[cfg(test)] @@ -454,7 +461,7 @@ mod tests { // Create a file inside of `std::env::temp_dir()`. 
let mut file1 = NamedTempFile::new().unwrap(); let p = file1.path().to_str().unwrap(); - partner.save_id_map(p).unwrap(); + partner.save_id_map(p, None).unwrap(); let mut actual_result = String::new(); file1.read_to_string(&mut actual_result).unwrap(); let expected_result = "08FCF66A09440EFCB475BBCFA5915648A9A7DD0F2D0B75E965EDBAEC249D7D,NA\n30A397CD5C79AB7D6FBD59BF191326BAC43983497C81E1E2F109B3252EACE5F,email1,phone1\n7E105B924F454CF6E0BB4DC7158003A5647DC64A08FDC58BFCC03BDFF85718,email3\nD69F32E652AED8427DAACF74D57B807714160D7454310BF3515DD5AA5F98F4F,phone2\n"; diff --git a/protocol/src/private_id_multi_key/traits.rs b/protocol/src/private_id_multi_key/traits.rs index e3b400f..b8ef5ff 100644 --- a/protocol/src/private_id_multi_key/traits.rs +++ b/protocol/src/private_id_multi_key/traits.rs @@ -13,7 +13,8 @@ pub trait PartnerPrivateIdMultiKeyProtocol { fn create_id_map(&self, partner: TPayload, company: TPayload); fn print_id_map(&self); - fn save_id_map(&self, path: &str) -> Result<(), ProtocolError>; + fn save_id_map(&self, path: &str, num_split: Option) -> Result<(), ProtocolError>; + fn get_id_map_size(&self) -> usize; } pub trait CompanyPrivateIdMultiKeyProtocol { @@ -38,5 +39,6 @@ pub trait CompanyPrivateIdMultiKeyProtocol { fn write_company_to_id_map(&self); fn print_id_map(&self); - fn save_id_map(&self, path: &str) -> Result<(), ProtocolError>; + fn save_id_map(&self, path: &str, num_split: Option) -> Result<(), ProtocolError>; + fn get_id_map_size(&self) -> usize; }