Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,4 @@ This work is freely available under the terms of the MIT license.
## Contributors

* [Ryan Marcus](https://rmarcus.info)
* [Richard Barnes](https://richard.science)
98 changes: 51 additions & 47 deletions rmi_lib/src/codegen.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// < begin copyright >
// < begin copyright >
// Copyright Ryan Marcus 2020
//
//
// See root directory of this project for license terms.
//
// < end copyright >
//
// < end copyright >



use crate::models::Model;
use crate::models::*;
Expand Down Expand Up @@ -61,7 +61,7 @@ impl LayerParams {

return LayerParams::Constant(idx, params);
}

fn to_code<T: Write>(&self, target: &mut T) -> Result<(), std::io::Error> {
match self {
LayerParams::Constant(idx, params) => {
Expand Down Expand Up @@ -109,7 +109,7 @@ impl LayerParams {
},
LayerParams::MixedArray(_, _, _) => true,
LayerParams::Constant(_, _) => false,
};
};
}

fn pointer_type(&self) -> &'static str {
Expand All @@ -120,7 +120,7 @@ impl LayerParams {
LayerParams::Constant(_, _) => panic!("No pointer type for constant params")
};
}

fn to_decl<T: Write>(&self, target: &mut T) -> Result<(), std::io::Error> {
match self {
LayerParams::Constant(_, _) => {
Expand All @@ -137,7 +137,7 @@ impl LayerParams {
array_name!(idx),
num_items
)?;
} else {
} else {
writeln!(
target,
"{}* {};",
Expand All @@ -162,7 +162,7 @@ impl LayerParams {


fn write_to<T: Write>(&self, target: &mut T) -> Result<(), std::io::Error> {
match self {
match self {
LayerParams::Array(_idx, _, params) |
LayerParams::MixedArray(_idx, _, params) => {
let (first, rest) = params.split_first().unwrap();
Expand Down Expand Up @@ -238,7 +238,7 @@ impl LayerParams {
write!(target, "{}", array_name!(self.index()))?;
return Result::Ok(());
}

match self {
LayerParams::Constant(idx, _) => {
panic!(
Expand Down Expand Up @@ -267,30 +267,30 @@ impl LayerParams {
for item in params.iter().take(parameter_index) {
offset += item.size();
}

// we have to determine the type of the index being accessed
// and add the appropriate cast.
let c_type = params[parameter_index].c_type();
let ptr_expr = format!("{} + ({} * {}) + {}",
array_name!(idx),
model_index, bytes_per_model,
offset);

write!(target, "*(({new_type}*) ({ptr_expr}))",
new_type=c_type, ptr_expr=ptr_expr)?;

}
};

return Result::Ok(());
}

fn with_zipped_errors(&self, lle: &[u64]) -> LayerParams {

let params = self.params();
// integrate the errors into the model parameters of the last
// layer to save a cache miss.

// TODO: we should add padding to make sure each of these is
// cache-aligned. Also, there is a lot of unneeded copying going on here...
let combined_lle_params: Vec<ModelParam> =
Expand All @@ -308,10 +308,10 @@ impl LayerParams {
} else {
false
};

return LayerParams::new(self.index(), is_constant, self.params_per_model() + 1,
combined_lle_params);

}
}

Expand All @@ -327,7 +327,7 @@ impl fmt::Display for LayerParams {
LayerParams::MixedArray(idx, ppm, params) =>
write!(f, "MixedArray(idx: {}, ppm: {}, len: {}, malloc: {})",
idx, ppm, params.len(), self.requires_malloc())

}
}
}
Expand Down Expand Up @@ -377,7 +377,7 @@ pub fn rmi_size(rmi: &TrainedRMI) -> u64 {
let mut num_total_bytes = 0;
for layer in rmi.rmi.iter() {
let model_on_this_layer_size: usize = layer[0].params().iter().map(|p| p.size()).sum();

// assume all models on this layer have the same size
num_total_bytes += model_on_this_layer_size * layer.len();
}
Expand All @@ -389,7 +389,7 @@ pub fn rmi_size(rmi: &TrainedRMI) -> u64 {
if rmi.cache_fix.is_some() {
num_total_bytes += rmi.cache_fix.as_ref().unwrap().1.len() * 16;
}

return num_total_bytes as u64;
}

Expand Down Expand Up @@ -423,8 +423,8 @@ uint64_t lookup(uint64_t key, size_t* err) {{
? num_spline_pts : start + error_on_spline_search);
size_t lower = (error_on_spline_search > start
? 0 : start - error_on_spline_search);


struct SplinePoint* res = std::lower_bound(begin + lower,
begin + upper,
key,
Expand All @@ -442,7 +442,7 @@ uint64_t lookup(uint64_t key, size_t* err) {{
auto t = ((double)(key - pt1.key)) / (double)(pt2.key - pt1.key);
return (((uint64_t) std::fma(1.0 - t, v0, t * v1)) / {3}) * {3};
}}", num_splines, total_keys, array_name, line_size)?;


return Ok(());
}
Expand All @@ -462,7 +462,7 @@ fn generate_code<T: Write>(
.enumerate()
.map(|(layer_idx, models)| params_for_layer(layer_idx, models))
.collect();

let report_last_layer_errors = !rmi.last_layer_max_l1s.is_empty();

let mut report_lle: Vec<u8> = Vec::new();
Expand All @@ -471,14 +471,14 @@ fn generate_code<T: Write>(
if lle.len() > 1 {
let old_last = layer_params.pop().unwrap();
let new_last = old_last.with_zipped_errors(lle);

write!(report_lle, " *err = ")?;
new_last.access_by_ref(&mut report_lle, "modelIndex",
new_last.params_per_model() - 1)?;
writeln!(report_lle, ";")?;

layer_params.push(new_last);

} else {
write!(report_lle, " *err = {};", lle[0])?;
}
Expand All @@ -500,27 +500,28 @@ fn generate_code<T: Write>(
trace!("{}", lps);
}

writeln!(data_output, "namespace {} {{", namespace)?;

writeln!(data_output, "#pragma once\n")?;
writeln!(data_output, "namespace {} {{", namespace)?;

let mut read_code = Vec::new();
read_code.push("bool load(char const* dataPath) {".to_string());

for lp in layer_params.iter() {
match lp {
// constants are put directly in the header
// constants are put directly in the header
LayerParams::Constant(_idx, _) => lp.to_code(data_output)?,

LayerParams::Array(idx, _, _) |
LayerParams::MixedArray(idx, _, _) => {
let data_path = Path::new(&data_dir)
.join(format!("{}_{}", namespace, array_name!(idx)));
let f = File::create(data_path)
.expect("Could not write data file to RMI directory");
let mut bw = BufWriter::new(f);

lp.write_to(&mut bw)?; // write to data file
lp.to_decl(data_output)?; // write to source code

read_code.push(" {".to_string());
read_code.push(format!(" std::ifstream infile(std::filesystem::path(dataPath) / \"{ns}_{fn}\", std::ios::in | std::ios::binary);",
ns=namespace, fn=array_name!(idx)));
Expand Down Expand Up @@ -554,7 +555,7 @@ fn generate_code<T: Write>(
}
panic!();
}

free_code.push("}".to_string());

writeln!(data_output, "}} // namespace")?;
Expand All @@ -574,9 +575,10 @@ fn generate_code<T: Write>(
writeln!(code_output, "#include \"{}_data.h\"", namespace)?;
writeln!(code_output, "#include <math.h>")?;
writeln!(code_output, "#include <cmath>")?;
writeln!(code_output, "#include <fstream>")?;
writeln!(code_output, "#include <filesystem>")?;
writeln!(code_output, "#include <fstream>")?;
writeln!(code_output, "#include <iostream>")?;
writeln!(header_output, "")?; //Blank line after headers
if rmi.cache_fix.is_some() {
writeln!(code_output, "#include <algorithm>")?;
}
Expand All @@ -590,7 +592,7 @@ fn generate_code<T: Write>(
for ln in free_code {
writeln!(code_output, "{}", ln)?;
}

for decl in decls {
writeln!(code_output, "{}", decl)?;
}
Expand Down Expand Up @@ -623,7 +625,7 @@ inline size_t FCLAMP(double inp, double bound) {{
} else {
"_rmi_lookup_pre_cachefix"
};

let lookup_sig = if report_last_layer_errors {
format!("uint64_t {}({} key, size_t* err)", rmi_lookup_name, key_type.c_type())
} else {
Expand Down Expand Up @@ -720,12 +722,14 @@ inline size_t FCLAMP(double inp, double bound) {{
if rmi.cache_fix.is_some() {
generate_cache_fix_code(code_output, &rmi, array_name!(layer_params.len()-1))?;
}

writeln!(code_output, "}} // namespace")?;

// write out our forward declarations
writeln!(header_output, "#pragma once\n")?;
writeln!(header_output, "#include <cstddef>")?;
writeln!(header_output, "#include <cstdint>")?;
writeln!(header_output, "")?; //Blank line after headers
writeln!(header_output, "namespace {} {{", namespace)?;

writeln!(header_output, "bool load(char const* dataPath);")?;
Expand All @@ -748,7 +752,7 @@ inline size_t FCLAMP(double inp, double bound) {{
} else {
writeln!(header_output, "uint64_t lookup(uint64_t key, size_t* err);")?;
}
writeln!(header_output, "}}")?;
writeln!(header_output, "}} // namespace")?;

return Result::Ok(());
}
Expand All @@ -759,14 +763,14 @@ pub fn output_rmi(namespace: &str,
data_dir: &str,
key_type: KeyType,
include_errors: bool) -> Result<(), std::io::Error> {

let f1 = File::create(format!("{}.cpp", namespace)).expect("Could not write RMI CPP file");
let mut bw1 = BufWriter::new(f1);

let f2 =
File::create(format!("{}_data.h", namespace)).expect("Could not write RMI data file");
let mut bw2 = BufWriter::new(f2);

let f3 = File::create(format!("{}.h", namespace)).expect("Could not write RMI header file");
let mut bw3 = BufWriter::new(f3);

Expand All @@ -783,6 +787,6 @@ pub fn output_rmi(namespace: &str,
data_dir,
key_type
);


}