From ede9f485bbf7db2a38e29bfb714c90049302a101 Mon Sep 17 00:00:00 2001 From: yjhmelody <465402634@qq.com> Date: Fri, 18 Dec 2020 12:17:10 +0800 Subject: [PATCH] trait: add ToLuaValue and FromLuaValue --- src/de.rs | 186 ++++++++++++++++++++++++++++++++------------------- src/error.rs | 9 ++- src/lib.rs | 23 ++++++- src/ser.rs | 153 ++++++++++++++++++++++++++---------------- 4 files changed, 237 insertions(+), 134 deletions(-) diff --git a/src/de.rs b/src/de.rs index 0387c17..338ad06 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1,11 +1,10 @@ use serde; use serde::de::IntoDeserializer; -use rlua::{Value, TablePairs, TableSequence}; +use rlua::{TablePairs, TableSequence, Value}; use error::{Error, Result}; - pub struct Deserializer<'lua> { pub value: Value<'lua>, } @@ -15,7 +14,8 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { #[inline] fn deserialize_any(self, visitor: V) -> Result - where V: serde::de::Visitor<'de> + where + V: serde::de::Visitor<'de>, { match self.value { Value::Nil => visitor.visit_unit(), @@ -31,16 +31,20 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { if remaining == 0 { Ok(map) } else { - Err(serde::de::Error::invalid_length(len, &"fewer elements in array")) + Err(serde::de::Error::invalid_length( + len, + &"fewer elements in array", + )) } - }, + } _ => Err(serde::de::Error::custom("invalid value type")), } } #[inline] fn deserialize_option(self, visitor: V) -> Result - where V: serde::de::Visitor<'de> + where + V: serde::de::Visitor<'de>, { match self.value { Value::Nil => visitor.visit_none(), @@ -50,26 +54,32 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { #[inline] fn deserialize_enum( - self, _name: &str, _variants: &'static [&'static str], visitor: V + self, + _name: &str, + _variants: &'static [&'static str], + visitor: V, ) -> Result - where V: serde::de::Visitor<'de> + where + V: serde::de::Visitor<'de>, { let (variant, value) = match self.value { Value::Table(value) => { let mut iter = value.pairs::(); let (variant, value) = match iter.next() { Some(v) => v?, - None => return Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Map, - &"map with a single key", - )), + None => { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Map, + &"map with a single key", + )) + } }; if iter.next().is_some() { return Err(serde::de::Error::invalid_value( serde::de::Unexpected::Map, &"map with a single key", - )) + )); } (variant, Some(value)) } @@ -82,7 +92,8 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { #[inline] fn deserialize_seq(self, visitor: V) -> Result - where V: serde::de::Visitor<'de> + where + V: serde::de::Visitor<'de>, { match self.value { Value::Table(v) => { @@ -93,7 +104,10 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { if remaining == 0 { Ok(seq) } else { - Err(serde::de::Error::invalid_length(len, &"fewer elements in array")) + Err(serde::de::Error::invalid_length( + len, + &"fewer elements in array", + )) } } _ => Err(serde::de::Error::custom("invalid value type")), @@ -102,14 +116,21 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { #[inline] fn deserialize_tuple(self, _len: usize, visitor: V) -> Result - where V: serde::de::Visitor<'de> + where + V: serde::de::Visitor<'de>, { self.deserialize_seq(visitor) } #[inline] - fn deserialize_tuple_struct(self, _name: &'static str, _len: usize, visitor: V) -> Result - where V: serde::de::Visitor<'de> + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, { self.deserialize_seq(visitor) } @@ -121,19 +142,18 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { } } - struct SeqDeserializer<'lua>(TableSequence<'lua, Value<'lua>>); impl<'lua, 'de> serde::de::SeqAccess<'de> for SeqDeserializer<'lua> { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result> - where T: serde::de::DeserializeSeed<'de> + where + T: serde::de::DeserializeSeed<'de>, { match self.0.next() { - Some(value) => seed.deserialize(Deserializer { value: value? }) - .map(Some), - None => Ok(None) + Some(value) => seed.deserialize(Deserializer { value: value? }).map(Some), + None => Ok(None), } } @@ -145,17 +165,17 @@ impl<'lua, 'de> serde::de::SeqAccess<'de> for SeqDeserializer<'lua> { } } - struct MapDeserializer<'lua>( TablePairs<'lua, Value<'lua>, Value<'lua>>, - Option> + Option>, ); impl<'lua, 'de> serde::de::MapAccess<'de> for MapDeserializer<'lua> { type Error = Error; fn next_key_seed(&mut self, seed: T) -> Result> - where T: serde::de::DeserializeSeed<'de> + where + T: serde::de::DeserializeSeed<'de>, { match self.0.next() { Some(item) => { @@ -163,13 +183,14 @@ impl<'lua, 'de> serde::de::MapAccess<'de> for MapDeserializer<'lua> { self.1 = Some(value); let key_de = Deserializer { value: key }; seed.deserialize(key_de).map(Some) - }, + } None => Ok(None), } } fn next_value_seed(&mut self, seed: T) -> Result - where T: serde::de::DeserializeSeed<'de> + where + T: serde::de::DeserializeSeed<'de>, { match self.1.take() { Some(value) => seed.deserialize(Deserializer { value }), @@ -185,7 +206,6 @@ impl<'lua, 'de> serde::de::MapAccess<'de> for MapDeserializer<'lua> { } } - struct EnumDeserializer<'lua> { variant: String, value: Option>, @@ -196,7 +216,8 @@ impl<'lua, 'de> serde::de::EnumAccess<'de> for EnumDeserializer<'lua> { type Variant = VariantDeserializer<'lua>; fn variant_seed(self, seed: T) -> Result<(T::Value, Self::Variant)> - where T: serde::de::DeserializeSeed<'de> + where + T: serde::de::DeserializeSeed<'de>, { let variant = self.variant.into_deserializer(); let variant_access = VariantDeserializer { value: self.value }; @@ -204,7 +225,6 @@ impl<'lua, 'de> serde::de::EnumAccess<'de> for EnumDeserializer<'lua> { } } - struct VariantDeserializer<'lua> { value: Option>, } @@ -218,49 +238,46 @@ impl<'lua, 'de> serde::de::VariantAccess<'de> for VariantDeserializer<'lua> { serde::de::Unexpected::NewtypeVariant, &"unit variant", )), - None => Ok(()) + None => Ok(()), } } fn newtype_variant_seed(self, seed: T) -> Result - where T: serde::de::DeserializeSeed<'de> + where + T: serde::de::DeserializeSeed<'de>, { match self.value { Some(value) => seed.deserialize(Deserializer { value }), None => Err(serde::de::Error::invalid_type( serde::de::Unexpected::UnitVariant, &"newtype variant", - )) + )), } } fn tuple_variant(self, _len: usize, visitor: V) -> Result - where V: serde::de::Visitor<'de> + where + V: serde::de::Visitor<'de>, { match self.value { - Some(value) => serde::Deserializer::deserialize_seq( - Deserializer { value }, visitor - ), + Some(value) => serde::Deserializer::deserialize_seq(Deserializer { value }, visitor), None => Err(serde::de::Error::invalid_type( serde::de::Unexpected::UnitVariant, &"tuple variant", - )) + )), } } - fn struct_variant( - self, _fields: &'static [&'static str], visitor: V - ) -> Result - where V: serde::de::Visitor<'de> + fn struct_variant(self, _fields: &'static [&'static str], visitor: V) -> Result + where + V: serde::de::Visitor<'de>, { match self.value { - Some(value) => serde::Deserializer::deserialize_map( - Deserializer { value }, visitor - ), + Some(value) => serde::Deserializer::deserialize_map(Deserializer { value }, visitor), None => Err(serde::de::Error::invalid_type( serde::de::Unexpected::UnitVariant, &"struct variant", - )) + )), } } } @@ -270,6 +287,7 @@ mod tests { use rlua::Lua; use from_value; + use FromLuaValue; #[test] fn test_struct() { @@ -285,20 +303,25 @@ mod tests { int: 1, seq: vec!["a".to_owned(), "b".to_owned()], map: vec![(1, 2), (4, 1)].into_iter().collect(), - empty: vec![] + empty: vec![], }; println!("{:?}", expected); let lua = Lua::new(); lua.context(|lua| { - let value = lua.load(r#" + let value = lua + .load( + r#" a = {} a.int = 1 a.seq = {"a", "b"} a.map = {2, [4]=1} a.empty = {} return a - "#).eval().unwrap(); + "#, + ) + .eval() + .unwrap(); let got = from_value(value).unwrap(); assert_eq!(expected, got); }); @@ -312,20 +335,28 @@ mod tests { let lua = Lua::new(); lua.context(|lua| { let expected = Rgb(1, 2, 3); - let value = lua.load( - r#" + let value = lua + .load( + r#" a = {1, 2, 3} return a - "#).eval().unwrap(); + "#, + ) + .eval() + .unwrap(); let got = from_value(value).unwrap(); assert_eq!(expected, got); let expected = (1, 2, 3); - let value = lua.load( - r#" + let value = lua + .load( + r#" a = {1, 2, 3} return a - "#).eval().unwrap(); + "#, + ) + .eval() + .unwrap(); let got = from_value(value).unwrap(); assert_eq!(expected, got); }); @@ -344,43 +375,58 @@ mod tests { let lua = Lua::new(); lua.context(|lua| { let expected = E::Unit; - let value = lua.load( - r#" + let value = lua + .load( + r#" return "Unit" - "#).eval().unwrap(); + "#, + ) + .eval() + .unwrap(); let got = from_value(value).unwrap(); assert_eq!(expected, got); - let expected = E::Newtype(1); - let value = lua.load( - r#" + let value = lua + .load( + r#" a = {} a["Newtype"] = 1 return a - "#).eval().unwrap(); + "#, + ) + .eval() + .unwrap(); let got = from_value(value).unwrap(); assert_eq!(expected, got); let expected = E::Tuple(1, 2); - let value = lua.load( - r#" + let value = lua + .load( + r#" a = {} a["Tuple"] = {1, 2} return a - "#).eval().unwrap(); + "#, + ) + .eval() + .unwrap(); let got = from_value(value).unwrap(); assert_eq!(expected, got); let expected = E::Struct { a: 1 }; - let value = lua.load( - r#" + let value: rlua::Value = lua + .load( + r#" a = {} a["Struct"] = {} a["Struct"]["a"] = 1 return a - "#).eval().unwrap(); - let got = from_value(value).unwrap(); + "#, + ) + .eval() + .unwrap(); + let got = value.from_value().unwrap(); assert_eq!(expected, got); }); } diff --git a/src/error.rs b/src/error.rs index aa9a234..07c021f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,10 +1,9 @@ -use std::fmt; use std::error::Error as StdError; +use std::fmt; use std::result::Result as StdResult; -use serde; use rlua::Error as LuaError; - +use serde; #[derive(Debug)] pub struct Error(LuaError); @@ -40,7 +39,7 @@ impl serde::ser::Error for Error { Error(LuaError::ToLuaConversionError { from: "serialize", to: "value", - message: Some(format!("{}", msg)) + message: Some(format!("{}", msg)), }) } } @@ -50,7 +49,7 @@ impl serde::de::Error for Error { Error(LuaError::FromLuaConversionError { from: "value", to: "deserialize", - message: Some(format!("{}", msg)) + message: Some(format!("{}", msg)), }) } } diff --git a/src/lib.rs b/src/lib.rs index f69a349..ce59067 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,20 +44,37 @@ extern crate serde; #[macro_use] extern crate serde_derive; +pub mod de; pub mod error; pub mod ser; -pub mod de; +use rlua::{Context, Error, Value}; -use rlua::{Context, Value, Error}; +pub trait ToLuaValue<'lua> { + fn to_value(self, t: T) -> Result, Error>; +} +pub trait FromLuaValue<'lua> { + fn from_value>(self) -> Result; +} + +impl<'lua> ToLuaValue<'lua> for Context<'lua> { + fn to_value(self, t: T) -> Result, Error> { + crate::to_value(self, t) + } +} + +impl<'lua> FromLuaValue<'lua> for Value<'lua> { + fn from_value>(self) -> Result { + crate::from_value::<'lua>(self) + } +} pub fn to_value(lua: Context, t: T) -> Result { let serializer = ser::Serializer { lua }; Ok(t.serialize(serializer)?) } - pub fn from_value<'de, T: serde::Deserialize<'de>>(value: Value<'de>) -> Result { let deserializer = de::Deserializer { value }; Ok(T::deserialize(deserializer)?) diff --git a/src/ser.rs b/src/ser.rs index a12e837..498ca97 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -1,10 +1,9 @@ use serde; -use rlua::{Context, Value, Table, String as LuaString}; +use rlua::{Context, String as LuaString, Table, Value}; -use to_value; use error::{Error, Result}; - +use to_value; pub struct Serializer<'lua> { pub lua: Context<'lua>, @@ -14,12 +13,12 @@ impl<'lua> serde::Serializer for Serializer<'lua> { type Ok = Value<'lua>; type Error = Error; - type SerializeSeq = SerializeVec<'lua>; - type SerializeTuple = SerializeVec<'lua>; - type SerializeTupleStruct = SerializeVec<'lua>; - type SerializeTupleVariant = SerializeTupleVariant<'lua>; - type SerializeMap = SerializeMap<'lua>; - type SerializeStruct = SerializeMap<'lua>; + type SerializeSeq = SerializeVec<'lua>; + type SerializeTuple = SerializeVec<'lua>; + type SerializeTupleStruct = SerializeVec<'lua>; + type SerializeTupleVariant = SerializeTupleVariant<'lua>; + type SerializeMap = SerializeMap<'lua>; + type SerializeStruct = SerializeMap<'lua>; type SerializeStructVariant = SerializeStructVariant<'lua>; #[inline] @@ -91,7 +90,9 @@ impl<'lua> serde::Serializer for Serializer<'lua> { #[inline] fn serialize_bytes(self, value: &[u8]) -> Result> { - Ok(Value::Table(self.lua.create_sequence_from(value.iter().cloned())?)) + Ok(Value::Table( + self.lua.create_sequence_from(value.iter().cloned())?, + )) } #[inline] @@ -106,25 +107,31 @@ impl<'lua> serde::Serializer for Serializer<'lua> { #[inline] fn serialize_unit_variant( - self, _name: &'static str, _variant_index: u32, variant: &'static str + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, ) -> Result> { self.serialize_str(variant) } #[inline] - fn serialize_newtype_struct( - self, _name: &'static str, value: &T - ) -> Result> - where T: ?Sized + serde::Serialize, + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result> + where + T: ?Sized + serde::Serialize, { value.serialize(self) } fn serialize_newtype_variant( - self, _name: &'static str, _variant_index: u32, - variant: &'static str, value: &T, + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, ) -> Result> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { let table = self.lua.create_table()?; let variant = self.lua.create_string(variant)?; @@ -140,7 +147,8 @@ impl<'lua> serde::Serializer for Serializer<'lua> { #[inline] fn serialize_some(self, value: &T) -> Result> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { value.serialize(self) } @@ -159,14 +167,19 @@ impl<'lua> serde::Serializer for Serializer<'lua> { } fn serialize_tuple_struct( - self, _name: &'static str, len: usize, + self, + _name: &'static str, + len: usize, ) -> Result { self.serialize_seq(Some(len)) } fn serialize_tuple_variant( - self, _name: &'static str, _variant_index: u32, - variant: &'static str, _len: usize, + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, ) -> Result { let name = self.lua.create_string(variant)?; let table = self.lua.create_table()?; @@ -174,7 +187,7 @@ impl<'lua> serde::Serializer for Serializer<'lua> { lua: self.lua, idx: 1, name, - table + table, }) } @@ -192,8 +205,11 @@ impl<'lua> serde::Serializer for Serializer<'lua> { } fn serialize_struct_variant( - self, _name: &'static str, _variant_index: u32, - variant: &'static str, _len: usize, + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, ) -> Result { let name = self.lua.create_string(variant)?; let table = self.lua.create_table()?; @@ -203,10 +219,8 @@ impl<'lua> serde::Serializer for Serializer<'lua> { table, }) } - } - pub struct SerializeVec<'lua> { lua: Context<'lua>, table: Table<'lua>, @@ -218,7 +232,8 @@ impl<'lua> serde::ser::SerializeSeq for SerializeVec<'lua> { type Error = Error; fn serialize_element(&mut self, value: &T) -> Result<()> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { self.table.set(self.idx, to_value(self.lua, value)?)?; self.idx += 1; @@ -235,7 +250,8 @@ impl<'lua> serde::ser::SerializeTuple for SerializeVec<'lua> { type Error = Error; fn serialize_element(&mut self, value: &T) -> Result<()> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { serde::ser::SerializeSeq::serialize_element(self, value) } @@ -250,7 +266,8 @@ impl<'lua> serde::ser::SerializeTupleStruct for SerializeVec<'lua> { type Error = Error; fn serialize_field(&mut self, value: &T) -> Result<()> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { serde::ser::SerializeSeq::serialize_element(self, value) } @@ -260,7 +277,6 @@ impl<'lua> serde::ser::SerializeTupleStruct for SerializeVec<'lua> { } } - pub struct SerializeTupleVariant<'lua> { lua: Context<'lua>, name: LuaString<'lua>, @@ -273,7 +289,8 @@ impl<'lua> serde::ser::SerializeTupleVariant for SerializeTupleVariant<'lua> { type Error = Error; fn serialize_field(&mut self, value: &T) -> Result<()> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { self.table.set(self.idx, to_value(self.lua, value)?)?; self.idx += 1; @@ -287,11 +304,10 @@ impl<'lua> serde::ser::SerializeTupleVariant for SerializeTupleVariant<'lua> { } } - pub struct SerializeMap<'lua> { lua: Context<'lua>, table: Table<'lua>, - next_key: Option> + next_key: Option>, } impl<'lua> serde::ser::SerializeMap for SerializeMap<'lua> { @@ -299,14 +315,16 @@ impl<'lua> serde::ser::SerializeMap for SerializeMap<'lua> { type Error = Error; fn serialize_key(&mut self, key: &T) -> Result<()> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { self.next_key = Some(to_value(self.lua, key)?); Ok(()) } fn serialize_value(&mut self, value: &T) -> Result<()> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { let key = self.next_key.take(); // Panic because this indicates a bug in the program rather than an @@ -326,7 +344,8 @@ impl<'lua> serde::ser::SerializeStruct for SerializeMap<'lua> { type Error = Error; fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { serde::ser::SerializeMap::serialize_key(self, key)?; serde::ser::SerializeMap::serialize_value(self, value) @@ -337,7 +356,6 @@ impl<'lua> serde::ser::SerializeStruct for SerializeMap<'lua> { } } - pub struct SerializeStructVariant<'lua> { lua: Context<'lua>, name: LuaString<'lua>, @@ -349,10 +367,10 @@ impl<'lua> serde::ser::SerializeStructVariant for SerializeStructVariant<'lua> { type Error = Error; fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> - where T: ?Sized + serde::Serialize, + where + T: ?Sized + serde::Serialize, { - self.table - .set(key, to_value(self.lua, value)?)?; + self.table.set(key, to_value(self.lua, value)?)?; Ok(()) } @@ -365,8 +383,9 @@ impl<'lua> serde::ser::SerializeStructVariant for SerializeStructVariant<'lua> { #[cfg(test)] mod tests { - use rlua::Lua; use super::*; + use rlua::Lua; + use ToLuaValue; #[test] fn test_struct() { @@ -376,7 +395,10 @@ mod tests { seq: Vec<&'static str>, } - let test = Test { int: 1, seq: vec!["a", "b"] }; + let test = Test { + int: 1, + seq: vec!["a", "b"], + }; let lua = Lua::new(); lua.context(|lua| { @@ -387,8 +409,11 @@ mod tests { assert(value["int"] == 1) assert(value["seq"][1] == "a") assert(value["seq"][2] == "b") - "#).exec() - }).unwrap() + "#, + ) + .exec() + }) + .unwrap() } #[test] @@ -398,7 +423,7 @@ mod tests { Unit, Newtype(u32), Tuple(u32, u32), - Struct { a: u32}, + Struct { a: u32 }, } let lua = Lua::new(); @@ -407,31 +432,47 @@ mod tests { let u = E::Unit; let value = to_value(lua, &u).unwrap(); lua.globals().set("value", value).unwrap(); - lua.load(r#" + lua.load( + r#" assert(value == "Unit") - "#).exec().unwrap(); + "#, + ) + .exec() + .unwrap(); let n = E::Newtype(1); let value = to_value(lua, &n).unwrap(); lua.globals().set("value", value).unwrap(); - lua.load(r#" + lua.load( + r#" assert(value["Newtype"] == 1) - "#).exec().unwrap(); + "#, + ) + .exec() + .unwrap(); let t = E::Tuple(1, 2); - let value = to_value(lua, &t).unwrap(); + let value = lua.to_value(t).unwrap(); lua.globals().set("value", value).unwrap(); - lua.load(r#" + lua.load( + r#" assert(value["Tuple"][1] == 1) assert(value["Tuple"][2] == 2) - "#).exec().unwrap(); + "#, + ) + .exec() + .unwrap(); let s = E::Struct { a: 1 }; let value = to_value(lua, &s).unwrap(); lua.globals().set("value", value).unwrap(); - lua.load(r#" + lua.load( + r#" assert(value["Struct"]["a"] == 1) - "#).exec() - }).unwrap(); + "#, + ) + .exec() + }) + .unwrap(); } }