diff --git a/sip7/src/lib.rs b/sip7/src/lib.rs index 3455288..3b4f098 100644 --- a/sip7/src/lib.rs +++ b/sip7/src/lib.rs @@ -6,8 +6,9 @@ use alloc::string::String; use alloc::vec::Vec; use core::fmt; -const TYPE_TXT: u8 = 0x00; -const TYPE_BLOB: u8 = 0x01; +const TYPE_SEQ: u8 = 0x00; +const TYPE_TXT: u8 = 0x01; +const TYPE_BLOB: u8 = 0x02; /// Errors that can occur during record parsing or construction. #[derive(Clone, Debug, PartialEq, Eq)] @@ -18,17 +19,25 @@ pub enum Error { KeyTooLong, InvalidKey, InvalidUtf8, + SeqNotFirst, + DuplicateSeq, } /// A single record in a SIP-7 record set. #[derive(Clone, Debug, PartialEq, Eq)] pub enum Record { + Seq(u64), Txt { key: String, value: String }, Blob { key: String, value: Vec }, Unknown { rtype: u8, rdata: Vec }, } impl Record { + /// Creates a Seq record with the given version. + pub fn seq(version: u64) -> Self { + Record::Seq(version) + } + /// Creates a TXT record. pub fn txt(key: &str, value: &str) -> Self { Record::Txt { @@ -70,7 +79,7 @@ impl Record { let rtype = data[pos]; pos += 1; - let len = read_compact_size(data, &mut pos)?; + let len = read_compact_size(data, &mut pos)? as usize; if pos + len > data.len() { return Err(Error::DataOverflow); } @@ -78,6 +87,11 @@ impl Record { pos += len; let record = match rtype { + TYPE_SEQ => { + let mut rpos = 0; + let version = read_compact_size(rdata, &mut rpos)?; + Record::Seq(version) + } TYPE_TXT => { let (key, val_bytes) = parse_kv(rdata)?; let value = @@ -105,11 +119,18 @@ impl Record { fn pack_into(&self, buf: &mut Vec) -> Result<(), Error> { match self { + Record::Seq(version) => { + buf.push(TYPE_SEQ); + let mut version_buf = Vec::new(); + write_compact_size(&mut version_buf, *version); + write_compact_size(buf, version_buf.len() as u64); + buf.extend_from_slice(&version_buf); + } Record::Txt { key, value } => { validate_key(key)?; buf.push(TYPE_TXT); let data_len = 1 + key.len() + value.len(); - write_compact_size(buf, data_len); + write_compact_size(buf, data_len as u64); buf.push(key.len() as u8); buf.extend_from_slice(key.as_bytes()); buf.extend_from_slice(value.as_bytes()); @@ -118,14 +139,14 @@ impl Record { validate_key(key)?; buf.push(TYPE_BLOB); let data_len = 1 + key.len() + value.len(); - write_compact_size(buf, data_len); + write_compact_size(buf, data_len as u64); buf.push(key.len() as u8); buf.extend_from_slice(key.as_bytes()); buf.extend_from_slice(value); } Record::Unknown { rtype, rdata } => { buf.push(*rtype); - write_compact_size(buf, rdata.len()); + write_compact_size(buf, rdata.len() as u64); buf.extend_from_slice(rdata); } } @@ -147,10 +168,25 @@ impl RecordSet { } /// Packs a collection of records into a record set. + /// + /// If a `Seq` record is present it must be the first element + /// and there must be at most one. pub fn pack(records: impl IntoIterator) -> Result { let mut data = Vec::new(); + let mut seen_seq = false; + let mut index = 0usize; for record in records { + if matches!(record, Record::Seq(_)) { + if seen_seq { + return Err(Error::DuplicateSeq); + } + if index > 0 { + return Err(Error::SeqNotFirst); + } + seen_seq = true; + } record.pack_into(&mut data)?; + index += 1; } Ok(Self(data)) } @@ -164,6 +200,8 @@ impl RecordSet { pub fn iter(&self) -> RecordIter<'_> { RecordIter { data: self.0.as_slice(), + index: 0, + seen_seq: false, } } @@ -181,11 +219,22 @@ impl RecordSet { pub fn is_empty(&self) -> bool { self.0.is_empty() } + + /// Returns the Seq version if the first record is a Seq. + /// Only parses the first record; does not validate the rest. + pub fn seq(&self) -> Option { + match Record::unpack(self.0.as_slice()) { + Ok(Some((Record::Seq(version), _))) => Some(version), + _ => None, + } + } } /// An iterator that lazily unpacks records from a byte slice. pub struct RecordIter<'a> { data: &'a [u8], + index: usize, + seen_seq: bool, } impl<'a> Iterator for RecordIter<'a> { @@ -197,7 +246,19 @@ impl<'a> Iterator for RecordIter<'a> { } match Record::unpack(self.data) { Ok(Some((record, consumed))) => { + if matches!(record, Record::Seq(_)) { + if self.seen_seq { + self.data = &[]; + return Some(Err(Error::DuplicateSeq)); + } + if self.index > 0 { + self.data = &[]; + return Some(Err(Error::SeqNotFirst)); + } + self.seen_seq = true; + } self.data = &self.data[consumed..]; + self.index += 1; Some(Ok(record)) } Ok(None) => None, @@ -218,6 +279,8 @@ impl fmt::Display for Error { Error::KeyTooLong => write!(f, "key length exceeds 255 bytes"), Error::InvalidKey => write!(f, "key must be lowercase ascii, digits, or hyphens"), Error::InvalidUtf8 => write!(f, "invalid UTF-8 in text value"), + Error::SeqNotFirst => write!(f, "seq record must be the first record"), + Error::DuplicateSeq => write!(f, "only one seq record is allowed"), } } } @@ -238,19 +301,19 @@ fn validate_key(key: &str) -> Result<(), Error> { Ok(()) } -fn read_compact_size(data: &[u8], pos: &mut usize) -> Result { +fn read_compact_size(data: &[u8], pos: &mut usize) -> Result { if *pos >= data.len() { return Err(Error::UnexpectedEof); } let first = data[*pos]; *pos += 1; match first { - 0x00..=0xFC => Ok(first as usize), + 0x00..=0xFC => Ok(first as u64), 0xFD => { if *pos + 2 > data.len() { return Err(Error::UnexpectedEof); } - let v = u16::from_le_bytes([data[*pos], data[*pos + 1]]) as usize; + let v = u16::from_le_bytes([data[*pos], data[*pos + 1]]) as u64; *pos += 2; Ok(v) } @@ -258,7 +321,7 @@ fn read_compact_size(data: &[u8], pos: &mut usize) -> Result { if *pos + 4 > data.len() { return Err(Error::UnexpectedEof); } - let v = u32::from_le_bytes(data[*pos..*pos + 4].try_into().unwrap()) as usize; + let v = u32::from_le_bytes(data[*pos..*pos + 4].try_into().unwrap()) as u64; *pos += 4; Ok(v) } @@ -267,14 +330,13 @@ fn read_compact_size(data: &[u8], pos: &mut usize) -> Result { return Err(Error::UnexpectedEof); } let v = u64::from_le_bytes(data[*pos..*pos + 8].try_into().unwrap()); - let v: usize = v.try_into().map_err(|_| Error::DataOverflow)?; *pos += 8; Ok(v) } } } -fn write_compact_size(buf: &mut Vec, value: usize) { +fn write_compact_size(buf: &mut Vec, value: u64) { if value <= 0xFC { buf.push(value as u8); } else if value <= 0xFFFF { @@ -285,7 +347,7 @@ fn write_compact_size(buf: &mut Vec, value: usize) { buf.extend_from_slice(&(value as u32).to_le_bytes()); } else { buf.push(0xFF); - buf.extend_from_slice(&(value as u64).to_le_bytes()); + buf.extend_from_slice(&value.to_le_bytes()); } } @@ -310,6 +372,11 @@ mod serde_impl { use serde::ser::SerializeSeq; use serde::{Deserialize, Deserializer, Serialize, Serializer}; + #[derive(Serialize, Deserialize)] + struct SeqJson { + version: u64, + } + #[derive(Serialize, Deserialize)] struct TxtJson { key: String, @@ -331,6 +398,8 @@ mod serde_impl { #[derive(Serialize, Deserialize)] #[serde(tag = "type")] enum RecordJson { + #[serde(rename = "seq")] + Seq(SeqJson), #[serde(rename = "txt")] Txt(TxtJson), #[serde(rename = "blob")] @@ -342,6 +411,7 @@ mod serde_impl { impl From<&Record> for RecordJson { fn from(r: &Record) -> Self { match r { + Record::Seq(version) => RecordJson::Seq(SeqJson { version: *version }), Record::Txt { key, value } => RecordJson::Txt(TxtJson { key: key.clone(), value: value.clone(), @@ -363,6 +433,7 @@ mod serde_impl { fn try_from(j: RecordJson) -> Result { match j { + RecordJson::Seq(s) => Ok(Record::seq(s.version)), RecordJson::Txt(t) => Ok(Record::txt(&t.key, &t.value)), RecordJson::Blob(b) => { let value = BASE64_STANDARD @@ -452,6 +523,91 @@ mod tests { use alloc::vec; use super::*; + #[test] + fn pack_unpack_seq() { + let rs = RecordSet::pack(vec![ + Record::seq(1), + Record::txt("btc", "bc1qtest"), + ]).unwrap(); + + let records = rs.unpack().unwrap(); + assert_eq!(records.len(), 2); + assert_eq!(records[0], Record::Seq(1)); + } + + #[test] + fn seq_large_version() { + let rs = RecordSet::pack(vec![Record::seq(1000)]).unwrap(); + let records = rs.unpack().unwrap(); + assert_eq!(records[0], Record::Seq(1000)); + } + + #[test] + fn seq_u64_max() { + let rs = RecordSet::pack(vec![Record::seq(u64::MAX)]).unwrap(); + let records = rs.unpack().unwrap(); + assert_eq!(records[0], Record::Seq(u64::MAX)); + } + + #[test] + fn seq_must_be_first_pack() { + let err = RecordSet::pack(vec![ + Record::txt("btc", "bc1q"), + Record::seq(1), + ]).unwrap_err(); + assert_eq!(err, Error::SeqNotFirst); + } + + #[test] + fn seq_must_be_first_unpack() { + // Manually craft bytes: TXT record then SEQ record + let mut data = Record::txt("a", "b").pack().unwrap(); + data.extend_from_slice(&Record::seq(1).pack().unwrap()); + let rs = RecordSet::new(data); + assert_eq!(rs.unpack(), Err(Error::SeqNotFirst)); + } + + #[test] + fn duplicate_seq_pack() { + let err = RecordSet::pack(vec![ + Record::seq(1), + Record::seq(2), + ]).unwrap_err(); + assert_eq!(err, Error::DuplicateSeq); + } + + #[test] + fn duplicate_seq_unpack() { + // Manually craft bytes: two SEQ records + let mut data = Record::seq(1).pack().unwrap(); + data.extend_from_slice(&Record::seq(2).pack().unwrap()); + let rs = RecordSet::new(data); + assert_eq!(rs.unpack(), Err(Error::DuplicateSeq)); + } + + #[test] + fn seq_helper_returns_version() { + let rs = RecordSet::pack(vec![ + Record::seq(42), + Record::txt("btc", "bc1q"), + ]).unwrap(); + assert_eq!(rs.seq(), Some(42)); + } + + #[test] + fn no_seq_returns_none() { + let rs = RecordSet::pack(vec![ + Record::txt("btc", "bc1q"), + ]).unwrap(); + assert_eq!(rs.seq(), None); + } + + #[test] + fn empty_set_seq_returns_none() { + let rs = RecordSet::default(); + assert_eq!(rs.seq(), None); + } + #[test] fn pack_unpack_txt() { let rs = RecordSet::pack(vec![ @@ -503,6 +659,7 @@ mod tests { #[test] fn pack_unpack_mixed() { let rs = RecordSet::pack(vec![ + Record::seq(1), Record::txt("btc", "bc1qtest"), Record::blob("data", vec![0xFF, 0x00]), Record::unknown(0x10, vec![0xAB]), @@ -510,7 +667,8 @@ mod tests { ]).unwrap(); let records = rs.unpack().unwrap(); - assert_eq!(records.len(), 4); + assert_eq!(records.len(), 5); + assert_eq!(records[0], Record::Seq(1)); } #[test] @@ -661,6 +819,21 @@ mod tests { use alloc::vec; use super::*; + #[test] + fn json_round_trip_seq() { + let rs = RecordSet::pack(vec![ + Record::seq(1), + Record::txt("btc", "bc1qtest"), + ]).unwrap(); + + let json = serde_json::to_string(&rs).unwrap(); + assert!(json.contains("\"type\":\"seq\"")); + assert!(json.contains("\"version\":1")); + + let decoded: RecordSet = serde_json::from_str(&json).unwrap(); + assert_eq!(rs.unpack().unwrap(), decoded.unpack().unwrap()); + } + #[test] fn json_round_trip_txt() { let rs = RecordSet::pack(vec![ @@ -702,6 +875,7 @@ mod tests { #[test] fn json_matches_spec_format() { let json = r#"[ + {"type":"seq","version":1}, {"type":"txt","key":"btc","value":"bc1q..."}, {"type":"blob","key":"some-data","value":"aGVsbG8="}, {"type":"unknown","rtype":42,"rdata":"AQID"} @@ -709,9 +883,10 @@ mod tests { let rs: RecordSet = serde_json::from_str(json).unwrap(); let records = rs.unpack().unwrap(); - assert_eq!(records.len(), 3); + assert_eq!(records.len(), 4); + assert_eq!(records[0], Record::Seq(1)); - match &records[0] { + match &records[1] { Record::Txt { key, value } => { assert_eq!(key, "btc"); assert_eq!(value, "bc1q..."); @@ -719,7 +894,7 @@ mod tests { _ => panic!("expected txt"), } - match &records[1] { + match &records[2] { Record::Blob { key, value } => { assert_eq!(key, "some-data"); assert_eq!(value, b"hello"); @@ -727,7 +902,7 @@ mod tests { _ => panic!("expected blob"), } - match &records[2] { + match &records[3] { Record::Unknown { rtype, rdata } => { assert_eq!(*rtype, 42); assert_eq!(rdata, &[1, 2, 3]); @@ -748,5 +923,15 @@ mod tests { let decoded: RecordSet = serde_json::from_str(&json).unwrap(); assert_eq!(rs.unpack().unwrap(), decoded.unpack().unwrap()); } + + #[test] + fn json_seq_not_first_rejected() { + let json = r#"[ + {"type":"txt","key":"btc","value":"bc1q..."}, + {"type":"seq","version":1} + ]"#; + let err = serde_json::from_str::(json); + assert!(err.is_err()); + } } }