diff --git a/src/types/array.rs b/src/types/array.rs index 269e502023..089a792b6f 100644 --- a/src/types/array.rs +++ b/src/types/array.rs @@ -610,6 +610,159 @@ impl<'data> JanetArray<'data> { ret } + /// Retains only the elements specified by the predicate. + /// + /// In other words, remove all elements `e` for which `f(&e)` returns `false`. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. + /// + /// # Examples + /// + /// ``` + /// let mut vec = vec![1, 2, 3, 4]; + /// vec.retain(|&x| x % 2 == 0); + /// assert_eq!(vec, [2, 4]); + /// ``` + /// + /// Because the elements are visited exactly once in the original order, + /// external state may be used to decide which elements to keep. + /// + /// ``` + /// let mut vec = vec![1, 2, 3, 4, 5]; + /// let keep = [false, true, true, false, true]; + /// let mut iter = keep.iter(); + /// vec.retain(|_| *iter.next().unwrap()); + /// assert_eq!(vec, [2, 3, 5]); + /// ``` + pub fn retain(&mut self, mut f: F) + where F: FnMut(&Janet) -> bool { + self.retain_mut(|elem| f(elem)); + } + + /// Retains only the elements specified by the predicate, passing a mutable reference + /// to it. + /// + /// In other words, remove all elements `e` such that `f(&mut e)` returns `false`. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. + /// + /// # Examples + /// + /// ``` + /// use janetrs::{array, Janet}; + /// # let _client = janetrs::client::JanetClient::init().unwrap(); + /// + /// let mut array = array![1, 2, 3, 4]; + /// array.retain_mut(|x| { + /// let val = match x.try_unwrap::() { + /// Ok(x) => x, + /// _ => return false, + /// }; + /// + /// if val <= 3 { + /// *x = Janet::integer(val + 1); + /// true + /// } else { + /// false + /// } + /// }); + /// + /// assert!(array.deep_eq(&array![2, 3, 4])); + /// ``` + pub fn retain_mut(&mut self, mut f: F) + where F: FnMut(&mut Janet) -> bool { + // Array: [Kept, Kept, Hole, Hole, Hole, Hole, Unchecked, Unchecked] + // |<- processed len ->| ^- next to check + // |<- deleted cnt ->| + // |<- original_len ->| + // Kept: Elements which predicate returns true on. + // Hole: Moved or dropped element slot. + // Unchecked: Unchecked valid elements. + // + // This drop guard will be invoked when predicate or `drop` of element panicked. + // It shifts unchecked elements to cover holes and `set_len` to the correct length. + // In cases when predicate and `drop` never panick, it will be optimized out. + struct BackshiftOnDrop<'a, 'data> { + v: &'a mut JanetArray<'data>, + processed_len: usize, + deleted_cnt: usize, + original_len: usize, + } + + impl<'a, 'data> Drop for BackshiftOnDrop<'a, 'data> { + fn drop(&mut self) { + if self.deleted_cnt > 0 { + // SAFETY: Trailing unchecked items must be valid since we never touch them. + unsafe { + ptr::copy( + self.v.as_ptr().add(self.processed_len), + self.v + .as_mut_ptr() + .add(self.processed_len - self.deleted_cnt), + self.original_len - self.processed_len, + ); + } + } + self.v + .set_len((self.original_len - self.deleted_cnt) as i32); + } + } + + fn process_loop( + original_len: usize, + f: &mut F, + g: &mut BackshiftOnDrop<'_, '_>, + ) where + F: FnMut(&mut Janet) -> bool, + { + while g.processed_len != original_len { + // SAFETY: Unchecked element must be valid. + let cur = unsafe { &mut *g.v.as_mut_ptr().add(g.processed_len) }; + if !f(cur) { + // Advance early to avoid double drop if `drop_in_place` panicked. + g.processed_len += 1; + g.deleted_cnt += 1; + // SAFETY: We never touch this element again after dropped. + // unsafe { ptr::drop_in_place(cur) }; + // We already advanced the counter. + if DELETED { + continue; + } else { + break; + } + } + if DELETED { + // SAFETY: `deleted_cnt` > 0, so the hole slot must not overlap with + // current element. We use copy for move, and + // never touch this element again. + // + unsafe { + let hole_slot = g.v.as_mut_ptr().add(g.processed_len - g.deleted_cnt); + ptr::copy_nonoverlapping(cur, hole_slot, 1); + } + } + g.processed_len += 1; + } + } + + let original_len = self.len() as usize; + let mut g = BackshiftOnDrop { + v: self, + processed_len: 0, + deleted_cnt: 0, + original_len, + }; + + // Stage 1: Nothing was deleted. + process_loop::(original_len, &mut f, &mut g); + + // Stage 2: Some elements were deleted. + process_loop::(original_len, &mut f, &mut g); + + // All item are processed. + drop(g); + } + /// Shortens the array, keeping the first `len` elements and dropping the rest. /// /// If `len` is greater than the array's current length or `len` is lesser than 0, @@ -3066,6 +3219,41 @@ mod tests { Ok(()) } + #[test] + fn retain() -> Result<(), crate::client::Error> { + let _client = JanetClient::init()?; + + let mut array = array![1, 2, 3, 4]; + array.retain(|&x| x.try_unwrap::().map(|x| x % 2 == 0).unwrap_or(false)); + assert_deep_eq!(array, array![2, 4]); + + Ok(()) + } + + #[test] + fn retain_mut() -> Result<(), crate::client::Error> { + let _client = JanetClient::init()?; + + let mut array = array![1, 2, 3, 4]; + array.retain_mut(|x| { + let val = match x.try_unwrap::() { + Ok(x) => x, + _ => return false, + }; + + if val <= 3 { + *x = Janet::integer(val + 1); + true + } else { + false + } + }); + + assert_deep_eq!(array, array![2, 3, 4]); + + Ok(()) + } + #[test] fn dedup_by() -> Result<(), crate::client::Error> { let _client = JanetClient::init().unwrap();