diff --git a/Cargo.toml b/Cargo.toml index db1997e..b75d4fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ workspace = { members = ["libshvproto-macros"] } [package] name = "shvproto" -version = "3.0.10" +version = "3.0.11" edition = "2021" [dependencies] diff --git a/src/chainpack.rs b/src/chainpack.rs index ff83ae2..7980f14 100644 --- a/src/chainpack.rs +++ b/src/chainpack.rs @@ -193,9 +193,8 @@ impl<'a, W> ChainPackWriter<'a, W> } fn write_decimal(&mut self, decimal: &Decimal) -> WriteResult { let cnt = self.write_byte(PackingSchema::Decimal as u8)?; - let (mantisa, exponent) = decimal.decode(); - self.write_int_data(mantisa)?; - self.write_int_data(exponent as i64)?; + self.write_int_data(decimal.mantissa())?; + self.write_int_data(decimal.exponent() as i64)?; Ok(self.byte_writer.count() - cnt) } fn write_datetime(&mut self, dt: &DateTime) -> WriteResult { diff --git a/src/cpon.rs b/src/cpon.rs index fb4440a..fa18cf2 100644 --- a/src/cpon.rs +++ b/src/cpon.rs @@ -299,8 +299,7 @@ impl<'a, W> CponWriter<'a, W> impl<'a, W> Writer for CponWriter<'a, W> where W: Write { - fn write(&mut self, val: &RpcValue) -> WriteResult - { + fn write(&mut self, val: &RpcValue) -> WriteResult { let cnt = self.byte_writer.count(); let mm = val.meta(); if !mm.is_empty() { @@ -309,8 +308,7 @@ impl<'a, W> Writer for CponWriter<'a, W> self.write_value(val.value())?; Ok(self.byte_writer.count() - cnt) } - fn write_meta(&mut self, map: &MetaMap) -> WriteResult - { + fn write_meta(&mut self, map: &MetaMap) -> WriteResult { let cnt: usize = self.byte_writer.count(); let is_oneliner = Self::is_oneliner_meta(map); self.write_byte(b'<')?; @@ -335,8 +333,7 @@ impl<'a, W> Writer for CponWriter<'a, W> self.write_byte(b'>')?; Ok(self.byte_writer.count() - cnt) } - fn write_value(&mut self, val: &Value) -> WriteResult - { + fn write_value(&mut self, val: &Value) -> WriteResult { let cnt: usize = self.byte_writer.count(); match val { Value::Null => self.write_bytes("null".as_bytes()), @@ -538,13 +535,25 @@ impl<'a, R> CponReader<'a, R> } Ok(Value::from(buff)) } - fn read_int(&mut self, no_signum: bool) -> Result<(u64, bool, i32), ReadError> - { + fn read_int(&mut self, no_signum: bool) -> Result<(i64, bool, i32), ReadError> { let mut base = 10; - let mut val: u64 = 0; + let mut val: i64 = 0; let mut neg = false; let mut n = 0; let mut digit_cnt = 0; + fn add_digit(val: &mut i64, base: i64, digit: i64) -> i32 { + if let Some(val1) = val.checked_mul(base) { + *val = val1; + } else { + return 0; + } + if let Some(val1) = val.checked_add(digit) { + *val = val1; + 1 + } else { + 0 + } + } loop { let b = self.peek_byte(); match b { @@ -573,28 +582,21 @@ impl<'a, R> CponReader<'a, R> } b'0' ..= b'9' => { self.get_byte()?; - val *= base; - //log::debug!("val: {:x} {}", val, (b as i64)); - val += (b - b'0') as u64; - digit_cnt += 1; + digit_cnt += add_digit(&mut val, base, (b - b'0') as i64); } b'A' ..= b'F' => { if base != 16 { break; } self.get_byte()?; - val *= base; - val += (b - b'A') as u64 + 10; - digit_cnt += 1; + digit_cnt += add_digit(&mut val, base, (b - b'A') as i64 + 10); } b'a' ..= b'f' => { if base != 16 { break; } self.get_byte()?; - val *= base; - val += (b - b'a') as u64 + 10; - digit_cnt += 1; + digit_cnt += add_digit(&mut val, base, (b - b'a') as i64 + 10); } _ => break, } @@ -602,8 +604,7 @@ impl<'a, R> CponReader<'a, R> } Ok((val, neg, digit_cnt)) } - fn read_number(&mut self) -> Result - { + fn read_number(&mut self) -> Result { let mut mantisa; let mut exponent = 0; let mut decimals = 0; @@ -628,8 +629,8 @@ impl<'a, R> CponReader<'a, R> } mantisa = n; #[derive(PartialEq)] - enum State { Mantisa, Decimals, } - let mut state = State::Mantisa; + enum State { Mantissa, Decimals, } + let mut state = State::Mantissa; loop { let b = self.peek_byte(); match b { @@ -639,7 +640,7 @@ impl<'a, R> CponReader<'a, R> break; } b'.' => { - if state != State::Mantisa { + if state != State::Mantissa { return Err(self.make_error("Unexpected decimal point.", ReadErrorReason::InvalidCharacter)) } state = State::Decimals; @@ -650,17 +651,17 @@ impl<'a, R> CponReader<'a, R> dec_cnt = digit_cnt as i64; } b'e' | b'E' => { - if state != State::Mantisa && state != State::Decimals { - return Err(self.make_error("Unexpected exponet mark.", ReadErrorReason::InvalidCharacter)) + if state != State::Mantissa && state != State::Decimals { + return Err(self.make_error("Unexpected exponent mark.", ReadErrorReason::InvalidCharacter)) } //state = State::Exponent; is_decimal = true; self.get_byte()?; let (n, neg, digit_cnt) = self.read_int(false)?; - exponent = n as i64; + exponent = n; if neg { exponent = -exponent; } if digit_cnt == 0 { - return Err(self.make_error("Malformed number exponetional part.", ReadErrorReason::InvalidCharacter)) + return Err(self.make_error("Malformed number exponential part.", ReadErrorReason::InvalidCharacter)) } break; } @@ -669,22 +670,21 @@ impl<'a, R> CponReader<'a, R> } if is_decimal { for _i in 0 .. dec_cnt { - mantisa *= 10; + mantisa = mantisa.checked_mul(10).unwrap_or(mantisa); } - mantisa += decimals; - let mut snum = mantisa as i64; + mantisa = mantisa.checked_add(decimals).unwrap_or(mantisa); + let mut snum = mantisa; if is_neg { snum = -snum } return Ok(Value::from(Decimal::new(snum, (exponent - dec_cnt) as i8))) } if is_uint { - return Ok(Value::from(mantisa)) + return Ok(Value::from(mantisa as u64)) } - let mut snum = mantisa as i64; + let mut snum = mantisa; if is_neg { snum = -snum } Ok(Value::from(snum)) } - fn read_list(&mut self) -> Result - { + fn read_list(&mut self) -> Result { let mut lst = Vec::new(); self.get_byte()?; // eat '[' loop { @@ -743,7 +743,7 @@ impl<'a, R> CponReader<'a, R> break; } let (k, neg, _) = self.read_int(false)?; - let key = if neg { -(k as i64) } else { k as i64 }; + let key = if neg { -{ k } } else { k }; self.skip_white_insignificant()?; let val = self.read()?; map.insert(key as i32, val); @@ -850,25 +850,32 @@ mod test use crate::cpon::CponReader; use crate::reader::Reader; use crate::rpcvalue::Map; - #[test] fn test_read() { - assert!(RpcValue::from_cpon("null").unwrap().is_null()); - assert!(!RpcValue::from_cpon("false").unwrap().as_bool()); - assert!(RpcValue::from_cpon("true").unwrap().as_bool()); + fn test_cpon_round_trip(cpon: &str, val: T) where RpcValue: From { + let rv1 = RpcValue::from_cpon(cpon).unwrap(); + let rv2 = RpcValue::from(val); + assert_eq!(rv1, rv2); + let cpon2 = rv1.to_cpon(); + assert_eq!(cpon, &cpon2); + } + test_cpon_round_trip("null", RpcValue::null()); + test_cpon_round_trip("false", false); + test_cpon_round_trip("true", true); assert_eq!(RpcValue::from_cpon("0").unwrap().as_i32(), 0); assert_eq!(RpcValue::from_cpon("123").unwrap().as_i32(), 123); - assert_eq!(RpcValue::from_cpon("-123").unwrap().as_i32(), -123); + test_cpon_round_trip("-123", -123); assert_eq!(RpcValue::from_cpon("+123").unwrap().as_i32(), 123); - assert_eq!(RpcValue::from_cpon("123u").unwrap().as_u32(), 123u32); + test_cpon_round_trip("123u", 123u32); assert_eq!(RpcValue::from_cpon("0xFF").unwrap().as_i32(), 255); assert_eq!(RpcValue::from_cpon("-0x1000").unwrap().as_i32(), -4096); - assert_eq!(RpcValue::from_cpon("123.4").unwrap().as_decimal(), Decimal::new(1234, -1)); - assert_eq!(RpcValue::from_cpon("0.123").unwrap().as_decimal(), Decimal::new(123, -3)); + test_cpon_round_trip("123.4", Decimal::new(1234, -1)); + test_cpon_round_trip("0.123", Decimal::new(123, -3)); assert_eq!(RpcValue::from_cpon("-0.123").unwrap().as_decimal(), Decimal::new(-123, -3)); assert_eq!(RpcValue::from_cpon("0e0").unwrap().as_decimal(), Decimal::new(0, 0)); assert_eq!(RpcValue::from_cpon("0.123e3").unwrap().as_decimal(), Decimal::new(123, 0)); - assert_eq!(RpcValue::from_cpon("1000000.").unwrap().as_decimal(), Decimal::new(1000000, 0)); + test_cpon_round_trip("1000000.", Decimal::new(1000000, 0)); + test_cpon_round_trip("50.031387414025325", Decimal::new(50031387414025325, -15)); assert_eq!(RpcValue::from_cpon(r#""foo""#).unwrap().as_str(), "foo"); assert_eq!(RpcValue::from_cpon(r#""ěščřžýáí""#).unwrap().as_str(), "ěščřžýáí"); assert_eq!(RpcValue::from_cpon("b\"foo\tbar\nbaz\"").unwrap().as_blob(), b"foo\tbar\nbaz"); @@ -952,5 +959,12 @@ mod test //let cpon2 = rv.to_cpon_string().unwrap(); //assert_eq!(cpon1, cpon2); } + #[test] + fn test_read_too_long_numbers() { + // read very long decimal without overflow error, value is capped + assert!(RpcValue::from_cpon("123456789012345678901234567890123456789012345678901234567890").unwrap().is_int()); + assert!(RpcValue::from_cpon("1.23456789012345678901234567890123456789012345678901234567890").unwrap().is_decimal()); + assert!(RpcValue::from_cpon("123456789012345678901234567890123456789012345678901234567890.").unwrap().is_decimal()); + } } diff --git a/src/decimal.rs b/src/decimal.rs index e11f66c..d321279 100644 --- a/src/decimal.rs +++ b/src/decimal.rs @@ -3,38 +3,31 @@ /// mantisa: 56, exponent: 8; /// I'm storing whole Decimal in one i64 to keep size_of RpcValue == 24 #[derive(Debug, Copy, Clone, PartialEq)] -pub struct Decimal (i64); +pub struct Decimal { + mantissa: i64, + exponent: i8, +} impl Decimal { - pub fn new(mantisa: i64, exponent: i8) -> Decimal { - //log::debug!("\t mantisa: {} {:b}", mantisa, mantisa); - let mut n = mantisa << 8; - //log::debug!("\t 1antisa: {} {:b}", n, n); - n |= (exponent as i64) & 0xff; - //log::debug!("\t 2antisa: {} {:b}", n, n); - Decimal(n) - } - pub fn decode(&self) -> (i64, i8) { - let m = self.0 >> 8; - let e = self.0 as i8; - (m, e) + pub fn new(mantissa: i64, exponent: i8) -> Decimal { + Decimal{ mantissa, exponent } } pub fn mantissa(&self) -> i64 { - self.decode().0 + self.mantissa } pub fn exponent(&self) -> i8 { - self.decode().1 + self.exponent } pub fn to_cpon_string(&self) -> String { let mut neg = false; - let (mut mantisa, exponent) = self.decode(); - if mantisa < 0 { - mantisa = -mantisa; + let (mut mantissa, exponent) = (self.mantissa, self.exponent); + if mantissa < 0 { + mantissa = -mantissa; neg = true; } //let buff: Vec = Vec::new(); - let mut s = mantisa.to_string(); + let mut s = mantissa.to_string(); let n = s.len() as i8; let dec_places = -exponent; @@ -70,17 +63,16 @@ impl Decimal { s } pub fn to_f64(&self) -> f64 { - let (m, e) = self.decode(); - let mut d = m as f64; + let mut d = self.mantissa as f64; // We probably don't want to call .cmp() because of performance loss #[allow(clippy::comparison_chain)] - if e < 0 { - for _ in e .. 0 { + if self.exponent < 0 { + for _ in self.exponent .. 0 { d /= 10.; } } - else if e > 0 { - for _ in 0 .. e { + else if self.exponent > 0 { + for _ in 0 .. self.exponent { d *= 10.; } } diff --git a/src/rpcvalue.rs b/src/rpcvalue.rs index f8fe9c5..45a6187 100644 --- a/src/rpcvalue.rs +++ b/src/rpcvalue.rs @@ -946,6 +946,7 @@ impl RpcValue { is_xxx!(is_null, Value::Null); is_xxx!(is_bool, Value::Bool(_)); is_xxx!(is_int, Value::Int(_)); + is_xxx!(is_decimal, Value::Decimal(_)); is_xxx!(is_string, Value::String(_)); is_xxx!(is_blob, Value::Blob(_)); is_xxx!(is_list, Value::List(_));