diff --git a/clarity/src/vm/ast/mod.rs b/clarity/src/vm/ast/mod.rs index b09de172e6..371adedbcc 100644 --- a/clarity/src/vm/ast/mod.rs +++ b/clarity/src/vm/ast/mod.rs @@ -17,6 +17,8 @@ pub mod definition_sorter; pub mod expression_identifier; pub mod parser; +#[cfg(feature = "developer-mode")] +pub mod static_cost; pub mod traits_resolver; pub mod errors; diff --git a/clarity/src/vm/ast/static_cost/mod.rs b/clarity/src/vm/ast/static_cost/mod.rs new file mode 100644 index 0000000000..25449360d6 --- /dev/null +++ b/clarity/src/vm/ast/static_cost/mod.rs @@ -0,0 +1,218 @@ +mod trait_counter; +use std::collections::HashMap; + +use clarity_types::representations::SymbolicExpression; +use clarity_types::types::{CharType, SequenceData}; +use stacks_common::types::StacksEpochId; +pub use trait_counter::{ + TraitCount, TraitCountCollector, TraitCountContext, TraitCountPropagator, TraitCountVisitor, +}; + +use crate::vm::callables::CallableType; +use crate::vm::costs::analysis::{ + CostAnalysisNode, CostExprNode, StaticCost, SummingExecutionCost, +}; +use crate::vm::costs::cost_functions::linear; +use crate::vm::costs::ExecutionCost; +use crate::vm::errors::VmExecutionError; +use crate::vm::functions::{lookup_reserved_functions, NativeFunctions}; +use crate::vm::representations::ClarityName; +use crate::vm::{ClarityVersion, Value}; +use crate::vm::functions::special_costs; + +const STRING_COST_BASE: u64 = 36; +const STRING_COST_MULTIPLIER: u64 = 3; + +pub(crate) fn calculate_function_cost( + function_name: String, + cost_map: &HashMap>, + _clarity_version: &ClarityVersion, +) -> Result { + match cost_map.get(&function_name) { + Some(Some(cost)) => { + // Cost already computed + Ok(cost.clone()) + } + Some(None) => { + // Should be impossible.. + // Function exists but cost not yet computed, circular dependency? + // For now, return zero cost to avoid infinite recursion + println!( + "Circular dependency detected for function: {}", + function_name + ); + Ok(StaticCost::ZERO) + } + None => { + // Function not found + Ok(StaticCost::ZERO) + } + } +} + +/// Determine if a function name represents a branching function +pub(crate) fn is_branching_function(function_name: &ClarityName) -> bool { + match function_name.as_str() { + "if" | "match" => true, + "unwrap!" | "unwrap-err!" => false, // XXX: currently unwrap and + // unwrap-err traverse both branches regardless of result, so until this is + // fixed in clarity we'll set this to false + _ => false, + } +} + +pub(crate) fn is_node_branching(node: &CostAnalysisNode) -> bool { + match &node.expr { + CostExprNode::NativeFunction(NativeFunctions::If) + | CostExprNode::NativeFunction(NativeFunctions::Match) => true, + CostExprNode::UserFunction(name) => is_branching_function(name), + _ => false, + } +} + +/// string cost based on length +fn string_cost(length: usize) -> StaticCost { + let cost = linear(length as u64, STRING_COST_BASE, STRING_COST_MULTIPLIER); + let execution_cost = ExecutionCost::runtime(cost); + StaticCost { + min: execution_cost.clone(), + max: execution_cost, + } +} + +/// Strings are the only Value's with costs associated +pub(crate) fn calculate_value_cost(value: &Value) -> Result { + match value { + Value::Sequence(SequenceData::String(CharType::UTF8(data))) => { + Ok(string_cost(data.data.len())) + } + Value::Sequence(SequenceData::String(CharType::ASCII(data))) => { + Ok(string_cost(data.data.len())) + } + _ => Ok(StaticCost::ZERO), + } +} + +pub(crate) fn calculate_function_cost_from_native_function( + native_function: NativeFunctions, + arg_count: u64, + args: &[SymbolicExpression], + epoch: StacksEpochId, +) -> Result { + // Derive clarity_version from epoch for lookup_reserved_functions + let clarity_version = ClarityVersion::default_for_epoch(epoch); + match lookup_reserved_functions(native_function.to_string().as_str(), &clarity_version) { + Some(CallableType::NativeFunction(_, _, cost_fn)) => { + let cost = cost_fn + .eval_for_epoch(arg_count, epoch) + .map_err(|e| format!("Cost calculation error: {:?}", e))?; + Ok(StaticCost { + min: cost.clone(), + max: cost, + }) + } + Some(CallableType::NativeFunction205(_, _, cost_fn, _)) => { + let cost = cost_fn + .eval_for_epoch(arg_count, epoch) + .map_err(|e| format!("Cost calculation error: {:?}", e))?; + Ok(StaticCost { + min: cost.clone(), + max: cost, + }) + } + Some(CallableType::SpecialFunction(_, _)) => { + let cost = special_costs::get_cost_for_special_function(native_function, args, epoch); + Ok(StaticCost { + min: cost.clone(), + max: cost, + }) + } + Some(CallableType::UserFunction(_)) => Ok(StaticCost::ZERO), // TODO ? + None => Ok(StaticCost::ZERO), + } +} + +/// total cost handling branching +pub(crate) fn calculate_total_cost_with_summing(node: &CostAnalysisNode) -> SummingExecutionCost { + let mut summing_cost = SummingExecutionCost::from_single(node.cost.min.clone()); + + for child in &node.children { + let child_summing = calculate_total_cost_with_summing(child); + summing_cost.add_summing(&child_summing); + } + + summing_cost +} + +pub(crate) fn calculate_total_cost_with_branching(node: &CostAnalysisNode) -> SummingExecutionCost { + let mut summing_cost = SummingExecutionCost::new(); + + // Check if this is a branching function by examining the node's expression + let is_branching = is_node_branching(node); + + if is_branching { + match &node.expr { + CostExprNode::NativeFunction(NativeFunctions::If) + | CostExprNode::NativeFunction(NativeFunctions::Match) => { + // TODO match? + if node.children.len() >= 2 { + let condition_cost = calculate_total_cost_with_summing(&node.children[0]); + let condition_total = condition_cost.add_all(); + + // Add the root cost + condition cost to each branch + let mut root_and_condition = node.cost.min.clone(); + let _ = root_and_condition.add(&condition_total); + + for child_cost_node in node.children.iter().skip(1) { + let branch_cost = calculate_total_cost_with_summing(child_cost_node); + let branch_total = branch_cost.add_all(); + + let mut path_cost = root_and_condition.clone(); + let _ = path_cost.add(&branch_total); + + summing_cost.add_cost(path_cost); + } + } + } + _ => { + // For other branching functions, fall back to sequential processing + let mut total_cost = node.cost.min.clone(); + for child_cost_node in &node.children { + let child_summing = calculate_total_cost_with_summing(child_cost_node); + let combined_cost = child_summing.add_all(); + let _ = total_cost.add(&combined_cost); + } + summing_cost.add_cost(total_cost); + } + } + } else { + // For non-branching, add all costs sequentially + let mut total_cost = node.cost.min.clone(); + for child_cost_node in &node.children { + let child_summing = calculate_total_cost_with_summing(child_cost_node); + let combined_cost = child_summing.add_all(); + let _ = total_cost.add(&combined_cost); + } + summing_cost.add_cost(total_cost); + } + + summing_cost +} + +impl From for StaticCost { + fn from(summing: SummingExecutionCost) -> Self { + StaticCost { + min: summing.min(), + max: summing.max(), + } + } +} + +/// get min & max costs for a given cost function +fn get_costs( + cost_fn: fn(u64) -> Result, + arg_count: u64, +) -> Result { + let cost = cost_fn(arg_count).map_err(|e| format!("Cost calculation error: {:?}", e))?; + Ok(cost) +} diff --git a/clarity/src/vm/ast/static_cost/trait_counter.rs b/clarity/src/vm/ast/static_cost/trait_counter.rs new file mode 100644 index 0000000000..7cd615fd76 --- /dev/null +++ b/clarity/src/vm/ast/static_cost/trait_counter.rs @@ -0,0 +1,418 @@ +use std::collections::HashMap; + +use clarity_types::representations::ClarityName; +use clarity_types::types::Value; + +use crate::vm::ast::static_cost::{CostAnalysisNode, CostExprNode}; +use crate::vm::costs::analysis::is_function_definition; +use crate::vm::functions::NativeFunctions; +use crate::vm::representations::{SymbolicExpression, SymbolicExpressionType}; + +type MinMaxTraitCount = (u64, u64); +pub type TraitCount = HashMap; + +/// Context passed to visitors during trait count analysis +pub struct TraitCountContext { + containing_fn_name: String, + multiplier: (u64, u64), +} + +impl TraitCountContext { + pub fn new(containing_fn_name: String, multiplier: (u64, u64)) -> Self { + Self { + containing_fn_name, + multiplier, + } + } + + fn with_multiplier(&self, multiplier: (u64, u64)) -> Self { + Self { + containing_fn_name: self.containing_fn_name.clone(), + multiplier, + } + } + + fn with_fn_name(&self, fn_name: String) -> Self { + Self { + containing_fn_name: fn_name, + multiplier: self.multiplier, + } + } +} + +/// Extract the list size multiplier from a list expression (for map/filter/fold operations) +/// Expects a list in the form `(list )` where size is an integer literal +fn extract_list_multiplier(list: &[SymbolicExpression]) -> (u64, u64) { + if list.is_empty() { + return (1, 1); + } + + let is_list_atom = list[0] + .match_atom() + .map(|a| a.as_str() == "list") + .unwrap_or(false); + if !is_list_atom || list.len() < 2 { + return (1, 1); + } + + match &list[1].expr { + SymbolicExpressionType::LiteralValue(Value::Int(value)) => (0, *value as u64), + _ => (1, 1), + } +} + +/// Increment trait count for a function +fn increment_trait_count(trait_counts: &mut TraitCount, fn_name: &str, multiplier: (u64, u64)) { + trait_counts + .entry(fn_name.to_string()) + .and_modify(|(min, max)| { + *min += multiplier.0; + *max += multiplier.1; + }) + .or_insert(multiplier); +} + +/// Propagate trait count from one function to another with a multiplier +fn propagate_trait_count( + trait_counts: &mut TraitCount, + from_fn: &str, + to_fn: &str, + multiplier: (u64, u64), +) { + if let Some(called_trait_count) = trait_counts.get(from_fn).cloned() { + trait_counts + .entry(to_fn.to_string()) + .and_modify(|(min, max)| { + *min += called_trait_count.0 * multiplier.0; + *max += called_trait_count.1 * multiplier.1; + }) + .or_insert(( + called_trait_count.0 * multiplier.0, + called_trait_count.1 * multiplier.1, + )); + } +} + +/// Visitor trait for traversing cost analysis nodes and collecting/propagating trait counts +pub trait TraitCountVisitor { + fn visit_user_argument( + &mut self, + node: &CostAnalysisNode, + arg_name: &ClarityName, + arg_type: &SymbolicExpressionType, + context: &TraitCountContext, + ); + fn visit_native_function( + &mut self, + node: &CostAnalysisNode, + native_function: &NativeFunctions, + context: &TraitCountContext, + ); + fn visit_atom_value(&mut self, node: &CostAnalysisNode, context: &TraitCountContext); + fn visit_atom( + &mut self, + node: &CostAnalysisNode, + atom: &ClarityName, + context: &TraitCountContext, + ); + fn visit_field_identifier(&mut self, node: &CostAnalysisNode, context: &TraitCountContext); + fn visit_trait_reference( + &mut self, + node: &CostAnalysisNode, + trait_name: &ClarityName, + context: &TraitCountContext, + ); + fn visit_user_function( + &mut self, + node: &CostAnalysisNode, + user_function: &ClarityName, + context: &TraitCountContext, + ); + + fn visit(&mut self, node: &CostAnalysisNode, context: &TraitCountContext) { + match &node.expr { + CostExprNode::UserArgument(arg_name, arg_type) => { + self.visit_user_argument(node, arg_name, arg_type, context); + } + CostExprNode::NativeFunction(native_function) => { + self.visit_native_function(node, native_function, context); + } + CostExprNode::AtomValue(_atom_value) => { + self.visit_atom_value(node, context); + } + CostExprNode::Atom(atom) => { + self.visit_atom(node, atom, context); + } + CostExprNode::FieldIdentifier(_field_identifier) => { + self.visit_field_identifier(node, context); + } + CostExprNode::TraitReference(trait_name) => { + self.visit_trait_reference(node, trait_name, context); + } + CostExprNode::UserFunction(user_function) => { + self.visit_user_function(node, user_function, context); + } + } + } +} + +pub struct TraitCountCollector { + pub trait_counts: TraitCount, + pub trait_names: HashMap, +} + +impl TraitCountCollector { + pub fn new() -> Self { + Self { + trait_counts: HashMap::new(), + trait_names: HashMap::new(), + } + } +} + +impl TraitCountVisitor for TraitCountCollector { + fn visit_user_argument( + &mut self, + _node: &CostAnalysisNode, + arg_name: &ClarityName, + arg_type: &SymbolicExpressionType, + _context: &TraitCountContext, + ) { + if let SymbolicExpressionType::TraitReference(name, _) = arg_type { + self.trait_names + .insert(arg_name.clone(), name.clone().to_string()); + } + } + + fn visit_native_function( + &mut self, + node: &CostAnalysisNode, + native_function: &NativeFunctions, + context: &TraitCountContext, + ) { + match native_function { + NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { + if node.children.len() > 1 { + let list_node = &node.children[1]; + let multiplier = + if let CostExprNode::UserArgument(_, SymbolicExpressionType::List(list)) = + &list_node.expr + { + extract_list_multiplier(list) + } else { + (1, 1) + }; + let new_context = context.with_multiplier(multiplier); + for child in &node.children { + self.visit(child, &new_context); + } + } + } + _ => { + for child in &node.children { + self.visit(child, context); + } + } + } + } + + fn visit_atom_value(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) { + // No action needed for atom values + } + + fn visit_atom( + &mut self, + _node: &CostAnalysisNode, + atom: &ClarityName, + context: &TraitCountContext, + ) { + if self.trait_names.contains_key(atom) { + increment_trait_count( + &mut self.trait_counts, + &context.containing_fn_name, + context.multiplier, + ); + } + } + + fn visit_field_identifier(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) { + // No action needed for field identifiers + } + + fn visit_trait_reference( + &mut self, + _node: &CostAnalysisNode, + _trait_name: &ClarityName, + context: &TraitCountContext, + ) { + increment_trait_count( + &mut self.trait_counts, + &context.containing_fn_name, + context.multiplier, + ); + } + + fn visit_user_function( + &mut self, + node: &CostAnalysisNode, + user_function: &ClarityName, + context: &TraitCountContext, + ) { + // Check if this is a trait call (the function name is a trait argument) + if self.trait_names.contains_key(user_function) { + increment_trait_count( + &mut self.trait_counts, + &context.containing_fn_name, + context.multiplier, + ); + } + + // Determine the containing function name for children + let fn_name = if is_function_definition(user_function.as_str()) { + context.containing_fn_name.clone() + } else { + user_function.to_string() + }; + let child_context = context.with_fn_name(fn_name); + + for child in &node.children { + self.visit(child, &child_context); + } + } +} + +/// Second pass visitor: propagates trait counts through function calls +pub struct TraitCountPropagator<'a> { + trait_counts: &'a mut TraitCount, + trait_names: &'a HashMap, +} + +impl<'a> TraitCountPropagator<'a> { + pub fn new( + trait_counts: &'a mut TraitCount, + trait_names: &'a HashMap, + ) -> Self { + Self { + trait_counts, + trait_names, + } + } +} + +impl<'a> TraitCountVisitor for TraitCountPropagator<'a> { + fn visit_user_argument( + &mut self, + _node: &CostAnalysisNode, + _arg_name: &ClarityName, + _arg_type: &SymbolicExpressionType, + _context: &TraitCountContext, + ) { + // No propagation needed for arguments + } + + fn visit_native_function( + &mut self, + node: &CostAnalysisNode, + native_function: &NativeFunctions, + context: &TraitCountContext, + ) { + match native_function { + NativeFunctions::Map | NativeFunctions::Filter | NativeFunctions::Fold => { + if node.children.len() > 1 { + let list_node = &node.children[1]; + let multiplier = + if let CostExprNode::UserArgument(_, SymbolicExpressionType::List(list)) = + &list_node.expr + { + extract_list_multiplier(list) + } else { + (1, 1) + }; + + // Process the function being called in map/filter/fold + let mut skip_first_child = false; + if let Some(function_node) = node.children.get(0) { + if let CostExprNode::UserFunction(function_name) = &function_node.expr { + if !self.trait_names.contains_key(function_name) { + // This is a regular function call, not a trait call + propagate_trait_count( + self.trait_counts, + &function_name.to_string(), + &context.containing_fn_name, + multiplier, + ); + skip_first_child = true; + } + } + } + + // Continue traversing children, but skip the function node if we already propagated it + for (idx, child) in node.children.iter().enumerate() { + if idx == 0 && skip_first_child { + continue; + } + let new_context = context.with_multiplier(multiplier); + self.visit(child, &new_context); + } + } + } + _ => { + for child in &node.children { + self.visit(child, context); + } + } + } + } + + fn visit_atom_value(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) {} + + fn visit_atom( + &mut self, + _node: &CostAnalysisNode, + _atom: &ClarityName, + _context: &TraitCountContext, + ) { + } + + fn visit_field_identifier(&mut self, _node: &CostAnalysisNode, _context: &TraitCountContext) {} + + fn visit_trait_reference( + &mut self, + _node: &CostAnalysisNode, + _trait_name: &ClarityName, + _context: &TraitCountContext, + ) { + // No propagation needed for trait references (already counted in first pass) + } + + fn visit_user_function( + &mut self, + node: &CostAnalysisNode, + user_function: &ClarityName, + context: &TraitCountContext, + ) { + if !is_function_definition(user_function.as_str()) + && !self.trait_names.contains_key(user_function) + { + // This is a regular function call, not a trait call or function definition + propagate_trait_count( + self.trait_counts, + &user_function.to_string(), + &context.containing_fn_name, + context.multiplier, + ); + } + + // Determine the containing function name for children + let fn_name = if is_function_definition(user_function.as_str()) { + context.containing_fn_name.clone() + } else { + user_function.to_string() + }; + let child_context = context.with_fn_name(fn_name); + + for child in &node.children { + self.visit(child, &child_context); + } + } +} diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs new file mode 100644 index 0000000000..06b04c12d1 --- /dev/null +++ b/clarity/src/vm/costs/analysis.rs @@ -0,0 +1,956 @@ +// Static cost analysis for Clarity contracts + +use std::collections::HashMap; + +use clarity_types::types::TraitIdentifier; +use stacks_common::types::StacksEpochId; + +use crate::vm::ast::build_ast; +// #[cfg(feature = "developer-mode")] +use crate::vm::ast::static_cost::{ + calculate_function_cost, calculate_function_cost_from_native_function, + calculate_total_cost_with_branching, calculate_value_cost, TraitCount, TraitCountCollector, + TraitCountContext, TraitCountPropagator, TraitCountVisitor, +}; +use crate::vm::contexts::Environment; +use crate::vm::costs::ExecutionCost; +use crate::vm::functions::NativeFunctions; +use crate::vm::representations::{ClarityName, SymbolicExpression, SymbolicExpressionType}; +use crate::vm::types::QualifiedContractIdentifier; +use crate::vm::{ClarityVersion, Value}; +// TODO: +// contract-call? - get source from database +// type-checking +// lookups +// unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59) +// split up trait counting and expr node tree impl into separate module? + +const STRING_COST_BASE: u64 = 36; +const STRING_COST_MULTIPLIER: u64 = 3; + +/// Functions where string arguments have zero cost because the function +/// cost includes their processing +const FUNCTIONS_WITH_ZERO_STRING_ARG_COST: &[&str] = &["concat", "len"]; + +const FUNCTION_DEFINITION_KEYWORDS: &[&str] = + &["define-public", "define-private", "define-read-only"]; + +pub(crate) fn is_function_definition(function_name: &str) -> bool { + FUNCTION_DEFINITION_KEYWORDS.contains(&function_name) +} + +#[derive(Debug, Clone)] +pub enum CostExprNode { + // Native Clarity functions + NativeFunction(NativeFunctions), + // Non-native expressions + AtomValue(Value), + Atom(ClarityName), + FieldIdentifier(TraitIdentifier), + TraitReference(ClarityName), + // User function arguments + UserArgument(ClarityName, SymbolicExpressionType), // (argument_name, argument_type) + // User-defined functions + UserFunction(ClarityName), +} + +#[derive(Debug, Clone)] +pub struct CostAnalysisNode { + pub expr: CostExprNode, + pub cost: StaticCost, + pub children: Vec, +} + +impl CostAnalysisNode { + pub fn new(expr: CostExprNode, cost: StaticCost, children: Vec) -> Self { + Self { + expr, + cost, + children, + } + } + + pub fn leaf(expr: CostExprNode, cost: StaticCost) -> Self { + Self { + expr, + cost, + children: vec![], + } + } +} + +#[derive(Debug, Clone)] +pub struct StaticCost { + pub min: ExecutionCost, + pub max: ExecutionCost, +} + +impl StaticCost { + pub const ZERO: StaticCost = StaticCost { + min: ExecutionCost::ZERO, + max: ExecutionCost::ZERO, + }; +} + +#[derive(Debug, Clone)] +pub struct UserArgumentsContext { + /// Map from argument name to argument type + pub arguments: HashMap, +} + +impl UserArgumentsContext { + pub fn new() -> Self { + Self { + arguments: HashMap::new(), + } + } + + pub fn add_argument(&mut self, name: ClarityName, arg_type: SymbolicExpressionType) { + self.arguments.insert(name, arg_type); + } + + pub fn is_user_argument(&self, name: &ClarityName) -> bool { + self.arguments.contains_key(name) + } + + pub fn get_argument_type(&self, name: &ClarityName) -> Option<&SymbolicExpressionType> { + self.arguments.get(name) + } +} + +/// A type to track summed execution costs for different paths +/// This allows us to compute min and max costs across different execution paths +#[derive(Debug, Clone)] +pub struct SummingExecutionCost { + pub costs: Vec, +} + +impl SummingExecutionCost { + pub fn new() -> Self { + Self { costs: Vec::new() } + } + + pub fn from_single(cost: ExecutionCost) -> Self { + Self { costs: vec![cost] } + } + + pub fn add_cost(&mut self, cost: ExecutionCost) { + self.costs.push(cost); + } + + pub fn add_summing(&mut self, other: &SummingExecutionCost) { + self.costs.extend(other.costs.clone()); + } + + /// minimum cost across all paths + pub fn min(&self) -> ExecutionCost { + if self.costs.is_empty() { + ExecutionCost::ZERO + } else { + self.costs + .iter() + .fold(self.costs[0].clone(), |acc, cost| ExecutionCost { + runtime: acc.runtime.min(cost.runtime), + write_length: acc.write_length.min(cost.write_length), + write_count: acc.write_count.min(cost.write_count), + read_length: acc.read_length.min(cost.read_length), + read_count: acc.read_count.min(cost.read_count), + }) + } + } + + /// maximum cost across all paths + pub fn max(&self) -> ExecutionCost { + if self.costs.is_empty() { + ExecutionCost::ZERO + } else { + self.costs + .iter() + .fold(self.costs[0].clone(), |acc, cost| ExecutionCost { + runtime: acc.runtime.max(cost.runtime), + write_length: acc.write_length.max(cost.write_length), + write_count: acc.write_count.max(cost.write_count), + read_length: acc.read_length.max(cost.read_length), + read_count: acc.read_count.max(cost.read_count), + }) + } + } + + pub fn add_all(&self) -> ExecutionCost { + self.costs + .iter() + .fold(ExecutionCost::ZERO, |mut acc, cost| { + let _ = acc.add(cost); + acc + }) + } +} + +fn make_ast( + source: &str, + epoch: StacksEpochId, + clarity_version: &ClarityVersion, +) -> Result { + let contract_identifier = QualifiedContractIdentifier::transient(); + let mut cost_tracker = (); + let ast = build_ast( + &contract_identifier, + source, + &mut cost_tracker, + *clarity_version, + epoch, + ) + .map_err(|e| format!("Parse error: {:?}", e))?; + Ok(ast) +} + +/// STatic execution cost for functions within Environment +/// returns the top level cost for specific functions +/// {some-function-name: (CostAnalysisNode, Some({some-function-name: (1,1)}))} +pub fn static_cost( + env: &mut Environment, + contract_identifier: &QualifiedContractIdentifier, +) -> Result)>, String> { + let contract_source = env + .global_context + .database + .get_contract_src(contract_identifier) + .ok_or_else(|| { + format!( + "Contract source ({:?}) not found in database", + contract_identifier.to_string(), + ) + })?; + + let contract = env + .global_context + .database + .get_contract(contract_identifier) + .map_err(|e| format!("Failed to get contract: {:?}", e))?; + + let clarity_version = contract.contract_context.get_clarity_version(); + + let epoch = env.global_context.epoch_id; + let ast = make_ast(&contract_source, epoch, clarity_version)?; + + static_cost_tree_from_ast(&ast, clarity_version, epoch) +} + +/// same idea as `static_cost` but returns the root of the cost analysis tree for each function +/// Useful if you need to analyze specific nodes in the cost tree +pub fn static_cost_tree( + env: &mut Environment, + contract_identifier: &QualifiedContractIdentifier, +) -> Result)>, String> { + let contract_source = env + .global_context + .database + .get_contract_src(contract_identifier) + .ok_or_else(|| { + format!( + "Contract source ({:?}) not found in database", + contract_identifier.to_string(), + ) + })?; + + let contract = env + .global_context + .database + .get_contract(contract_identifier) + .map_err(|e| format!("Failed to get contract: {:?}", e))?; + + let clarity_version = contract.contract_context.get_clarity_version(); + + let epoch = env.global_context.epoch_id; + let ast = make_ast(&contract_source, epoch, clarity_version)?; + + static_cost_tree_from_ast(&ast, clarity_version, epoch) +} + +pub fn static_cost_from_ast( + contract_ast: &crate::vm::ast::ContractAST, + clarity_version: &ClarityVersion, + epoch: StacksEpochId, +) -> Result)>, String> { + let cost_trees_with_traits = static_cost_tree_from_ast(contract_ast, clarity_version, epoch)?; + + // Extract trait_count from the first entry (all entries have the same trait_count) + let trait_count = cost_trees_with_traits + .values() + .next() + .and_then(|(_, trait_count)| trait_count.clone()); + + // Convert CostAnalysisNode to StaticCost + let costs: HashMap = cost_trees_with_traits + .into_iter() + .map(|(name, (cost_analysis_node, _))| { + let summing_cost = calculate_total_cost_with_branching(&cost_analysis_node); + (name, summing_cost.into()) + }) + .collect(); + + Ok(costs + .into_iter() + .map(|(name, cost)| (name, (cost, trait_count.clone()))) + .collect()) +} + +pub(crate) fn static_cost_tree_from_ast( + ast: &crate::vm::ast::ContractAST, + clarity_version: &ClarityVersion, + epoch: StacksEpochId, +) -> Result)>, String> { + let exprs = &ast.expressions; + let user_args = UserArgumentsContext::new(); + let costs_map: HashMap> = HashMap::new(); + let mut costs: HashMap> = HashMap::new(); + // first pass extracts the function names + for expr in exprs { + if let Some(function_name) = extract_function_name(expr) { + costs.insert(function_name, None); + } + } + // second pass computes the cost + for expr in exprs { + if let Some(function_name) = extract_function_name(expr) { + let (_, cost_analysis_tree) = + build_cost_analysis_tree(expr, &user_args, &costs_map, clarity_version, epoch)?; + costs.insert(function_name, Some(cost_analysis_tree)); + } + } + + // Build the final map with cost analysis nodes + let cost_trees: HashMap = costs + .into_iter() + .filter_map(|(name, cost)| cost.map(|c| (name, c))) + .collect(); + + // Compute trait_count while creating the root CostAnalysisNode + let trait_count = get_trait_count(&cost_trees); + + // Return each node with its trait_count + Ok(cost_trees + .into_iter() + .map(|(name, node)| (name, (node, trait_count.clone()))) + .collect()) +} + +/// Extract function name from a symbolic expression +fn extract_function_name(expr: &SymbolicExpression) -> Option { + expr.match_list().and_then(|list| { + list.first() + .and_then(|first| first.match_atom()) + .filter(|atom| is_function_definition(atom.as_str())) + .and_then(|_| list.get(1)) + .and_then(|sig| sig.match_list()) + .and_then(|signature| signature.first()) + .and_then(|name| name.match_atom()) + .map(|name| name.to_string()) + }) +} + +pub fn build_cost_analysis_tree( + expr: &SymbolicExpression, + user_args: &UserArgumentsContext, + cost_map: &HashMap>, + clarity_version: &ClarityVersion, + epoch: StacksEpochId, +) -> Result<(Option, CostAnalysisNode), String> { + match &expr.expr { + SymbolicExpressionType::List(list) => { + if let Some(function_name) = list.first().and_then(|first| first.match_atom()) { + if is_function_definition(function_name.as_str()) { + let (returned_function_name, cost_analysis_tree) = + build_function_definition_cost_analysis_tree( + list, + user_args, + cost_map, + clarity_version, + epoch, + )?; + Ok((Some(returned_function_name), cost_analysis_tree)) + } else { + let cost_analysis_tree = build_listlike_cost_analysis_tree( + list, + user_args, + cost_map, + clarity_version, + epoch, + )?; + Ok((None, cost_analysis_tree)) + } + } else { + let cost_analysis_tree = build_listlike_cost_analysis_tree( + list, + user_args, + cost_map, + clarity_version, + epoch, + )?; + Ok((None, cost_analysis_tree)) + } + } + SymbolicExpressionType::AtomValue(value) => { + let cost = calculate_value_cost(value)?; + Ok(( + None, + CostAnalysisNode::leaf(CostExprNode::AtomValue(value.clone()), cost), + )) + } + SymbolicExpressionType::LiteralValue(value) => { + let cost = calculate_value_cost(value)?; + Ok(( + None, + CostAnalysisNode::leaf(CostExprNode::AtomValue(value.clone()), cost), + )) + } + SymbolicExpressionType::Atom(name) => { + let expr_node = parse_atom_expression(name, user_args)?; + Ok((None, CostAnalysisNode::leaf(expr_node, StaticCost::ZERO))) + } + SymbolicExpressionType::Field(field_identifier) => Ok(( + None, + CostAnalysisNode::leaf( + CostExprNode::FieldIdentifier(field_identifier.clone()), + StaticCost::ZERO, + ), + )), + SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => Ok(( + None, + CostAnalysisNode::leaf( + CostExprNode::TraitReference(trait_name.clone()), + StaticCost::ZERO, + ), + )), + } +} + +/// Parse an atom expression into an ExprNode +fn parse_atom_expression( + name: &ClarityName, + user_args: &UserArgumentsContext, +) -> Result { + // Check if this atom is a user-defined function argument + if user_args.is_user_argument(name) { + if let Some(arg_type) = user_args.get_argument_type(name) { + Ok(CostExprNode::UserArgument(name.clone(), arg_type.clone())) + } else { + Ok(CostExprNode::Atom(name.clone())) + } + } else { + Ok(CostExprNode::Atom(name.clone())) + } +} + +/// Build an expression tree for function definitions like (define-public (foo (a u64)) (ok a)) +fn build_function_definition_cost_analysis_tree( + list: &[SymbolicExpression], + _user_args: &UserArgumentsContext, + cost_map: &HashMap>, + clarity_version: &ClarityVersion, + epoch: StacksEpochId, +) -> Result<(String, CostAnalysisNode), String> { + let define_type = list[0] + .match_atom() + .ok_or("Expected atom for define type")?; + let signature = list[1] + .match_list() + .ok_or("Expected list for function signature")?; + println!("signature: {:?}", signature); + let body = &list[2]; + + let mut children = Vec::new(); + let mut function_user_args = UserArgumentsContext::new(); + + // Process function arguments: (a u64) + for arg_expr in signature.iter().skip(1) { + if let Some(arg_list) = arg_expr.match_list() { + if arg_list.len() == 2 { + let arg_name = arg_list[0] + .match_atom() + .ok_or("Expected atom for argument name")?; + + let arg_type = arg_list[1].clone(); + + // Add to function's user arguments context + function_user_args.add_argument(arg_name.clone(), arg_type.clone().expr); + + // Create UserArgument node + children.push(CostAnalysisNode::leaf( + CostExprNode::UserArgument(arg_name.clone(), arg_type.clone().expr), + StaticCost::ZERO, + )); + } + } + } + + // Process the function body with the function's user arguments context + let (_, body_tree) = + build_cost_analysis_tree(body, &function_user_args, cost_map, clarity_version, epoch)?; + children.push(body_tree); + + // Get the function name from the signature + let function_name = signature[0] + .match_atom() + .ok_or("Expected atom for function name")?; + + // Create the function definition node with zero cost (function definitions themselves don't have execution cost) + Ok(( + function_name.clone().to_string(), + CostAnalysisNode::new( + CostExprNode::UserFunction(define_type.clone()), + StaticCost::ZERO, + children, + ), + )) +} + +fn get_function_name(expr: &SymbolicExpression) -> Result { + match &expr.expr { + SymbolicExpressionType::Atom(name) => Ok(name.clone()), + _ => Err("First element must be an atom (function name)".to_string()), + } +} + +/// Helper function to build expression trees for both lists and tuples +fn build_listlike_cost_analysis_tree( + exprs: &[SymbolicExpression], + user_args: &UserArgumentsContext, + cost_map: &HashMap>, + clarity_version: &ClarityVersion, + epoch: StacksEpochId, +) -> Result { + let mut children = Vec::new(); + + // Build children for all exprs + for expr in exprs[1..].iter() { + let (_, child_tree) = + build_cost_analysis_tree(expr, user_args, cost_map, clarity_version, epoch)?; + children.push(child_tree); + } + + let (expr_node, cost) = match &exprs[0].expr { + SymbolicExpressionType::List(_) => { + // Recursively analyze the nested list structure + let (_, nested_tree) = + build_cost_analysis_tree(&exprs[0], user_args, cost_map, clarity_version, epoch)?; + // Add the nested tree as a child (its cost will be included when summing children) + children.insert(0, nested_tree); + // The root cost is zero - the actual cost comes from the nested expression + let expr_node = CostExprNode::Atom(ClarityName::from("nested-expression")); + (expr_node, StaticCost::ZERO) + } + SymbolicExpressionType::Atom(name) => { + // Try to get function name from first element + // lookup the function as a native function first + // special functions + // - let, etc use bindings lengths not argument lengths + if let Some(native_function) = + NativeFunctions::lookup_by_name_at_version(name.as_str(), clarity_version) + { + let cost = calculate_function_cost_from_native_function( + native_function, + children.len() as u64, + &exprs[1..], + epoch, + )?; + + (CostExprNode::NativeFunction(native_function), cost) + } else { + // If not a native function, treat as user-defined function and look it up + let expr_node = CostExprNode::UserFunction(name.clone()); + let cost = calculate_function_cost(name.to_string(), cost_map, clarity_version)?; + (expr_node, cost) + } + } + SymbolicExpressionType::AtomValue(value) => { + // It's an atom value - calculate its cost + let cost = calculate_value_cost(value)?; + (CostExprNode::AtomValue(value.clone()), cost) + } + SymbolicExpressionType::TraitReference(trait_name, _trait_definition) => ( + CostExprNode::TraitReference(trait_name.clone()), + StaticCost::ZERO, + ), + SymbolicExpressionType::Field(field_identifier) => ( + CostExprNode::FieldIdentifier(field_identifier.clone()), + StaticCost::ZERO, + ), + SymbolicExpressionType::LiteralValue(value) => { + let cost = calculate_value_cost(value)?; + // TODO not sure if LiteralValue is needed in the CostExprNode types + (CostExprNode::AtomValue(value.clone()), cost) + } + }; + + Ok(CostAnalysisNode::new(expr_node, cost, children)) +} + +pub(crate) fn get_trait_count(costs: &HashMap) -> Option { + // First pass: collect trait counts and trait names + let mut collector = TraitCountCollector::new(); + for (name, cost_analysis_node) in costs.iter() { + let context = TraitCountContext::new(name.clone(), (1, 1)); + collector.visit(cost_analysis_node, &context); + } + + // Second pass: propagate trait counts through function calls + // If function A calls function B and uses a map, filter, or fold with + // traits, the maximum will reflect that in A's trait call counts + let mut propagator = + TraitCountPropagator::new(&mut collector.trait_counts, &collector.trait_names); + for (name, cost_analysis_node) in costs.iter() { + let context = TraitCountContext::new(name.clone(), (1, 1)); + propagator.visit(cost_analysis_node, &context); + } + + Some(collector.trait_counts) +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::vm::ast::static_cost::is_node_branching; + + fn static_cost_native_test( + source: &str, + clarity_version: &ClarityVersion, + ) -> Result { + let cost_map: HashMap> = HashMap::new(); + + let epoch = StacksEpochId::latest(); // XXX this should be matched with the clarity version + let ast = make_ast(source, epoch, clarity_version)?; + let exprs = &ast.expressions; + let user_args = UserArgumentsContext::new(); + let expr = &exprs[0]; + let (_, cost_analysis_tree) = + build_cost_analysis_tree(&expr, &user_args, &cost_map, clarity_version, epoch)?; + + let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); + Ok(summing_cost.into()) + } + + fn static_cost_test( + source: &str, + clarity_version: &ClarityVersion, + ) -> Result, String> { + let epoch = StacksEpochId::latest(); + let ast = make_ast(source, epoch, clarity_version)?; + let costs = static_cost_from_ast(&ast, clarity_version, epoch)?; + Ok(costs + .into_iter() + .map(|(name, (cost, _trait_count))| (name, cost)) + .collect()) + } + + fn build_test_ast(src: &str) -> crate::vm::ast::ContractAST { + let contract_identifier = QualifiedContractIdentifier::transient(); + let mut cost_tracker = (); + let ast = build_ast( + &contract_identifier, + src, + &mut cost_tracker, + ClarityVersion::Clarity3, + StacksEpochId::latest(), + ) + .unwrap(); + ast + } + + #[test] + fn test_constant() { + let source = "9001"; + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); + assert_eq!(cost.min.runtime, 0); + assert_eq!(cost.max.runtime, 0); + } + + #[test] + fn test_simple_addition() { + let source = "(+ 1 2)"; + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); + + // Min: linear(2, 11, 125) = 11*2 + 125 = 147 + assert_eq!(cost.min.runtime, 147); + assert_eq!(cost.max.runtime, 147); + } + + #[test] + fn test_arithmetic() { + let source = "(- u4 (+ u1 u2))"; + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); + assert_eq!(cost.min.runtime, 147 + 147); + assert_eq!(cost.max.runtime, 147 + 147); + } + + #[test] + fn test_nested_operations() { + let source = "(* (+ u1 u2) (- u3 u4))"; + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); + // multiplication: 13*2 + 125 = 151 + assert_eq!(cost.min.runtime, 151 + 147 + 147); + assert_eq!(cost.max.runtime, 151 + 147 + 147); + } + + #[test] + fn test_string_concat_min_max() { + let source = r#"(concat "hello" "world")"#; + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); + + assert_eq!(cost.min.runtime, 366); + assert_eq!(cost.max.runtime, 366); + } + + #[test] + fn test_string_len_min_max() { + let source = r#"(len "hello")"#; + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); + + assert_eq!(cost.min.runtime, 612); + assert_eq!(cost.max.runtime, 612); + } + + #[test] + fn test_branching() { + let source = "(if (> 3 0) (ok (concat \"hello\" \"world\")) (ok \"asdf\"))"; + let cost = static_cost_native_test(source, &ClarityVersion::Clarity3).unwrap(); + // min: raw string + // max: concat + + assert_eq!(cost.min.runtime, 346); + assert_eq!(cost.max.runtime, 565); + } + + // ---- ExprTreee building specific tests + #[test] + fn test_build_cost_analysis_tree_if_expression() { + let src = "(if (> 3 0) (ok true) (ok false))"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); // Empty cost map for tests + let epoch = StacksEpochId::Epoch32; + let (_, cost_tree) = build_cost_analysis_tree( + expr, + &user_args, + &cost_map, + &ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); + + // Root should be an If node + assert!(matches!( + cost_tree.expr, + CostExprNode::NativeFunction(NativeFunctions::If) + )); + assert!(is_node_branching(&cost_tree)); + assert_eq!(cost_tree.children.len(), 3); + + let gt_node = &cost_tree.children[0]; + assert!(matches!( + gt_node.expr, + CostExprNode::NativeFunction(NativeFunctions::CmpGreater) + )); + assert_eq!(gt_node.children.len(), 2); + + // The comparison node has 3 children: the function name, left operand, right operand + let left_val = >_node.children[0]; + let right_val = >_node.children[1]; + assert!(matches!(left_val.expr, CostExprNode::AtomValue(_))); + assert!(matches!(right_val.expr, CostExprNode::AtomValue(_))); + + let ok_true_node = &cost_tree.children[2]; + assert!(matches!( + ok_true_node.expr, + CostExprNode::NativeFunction(NativeFunctions::ConsOkay) + )); + assert_eq!(ok_true_node.children.len(), 1); + + let ok_false_node = &cost_tree.children[2]; + assert!(matches!( + ok_false_node.expr, + CostExprNode::NativeFunction(NativeFunctions::ConsOkay) + )); + assert_eq!(ok_false_node.children.len(), 1); + } + + #[test] + fn test_build_cost_analysis_tree_arithmetic() { + let src = "(+ (* 2 3) (- 5 1))"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); // Empty cost map for tests + let epoch = StacksEpochId::Epoch32; + let (_, cost_tree) = build_cost_analysis_tree( + expr, + &user_args, + &cost_map, + &ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); + + assert!(matches!( + cost_tree.expr, + CostExprNode::NativeFunction(NativeFunctions::Add) + )); + assert!(!is_node_branching(&cost_tree)); + assert_eq!(cost_tree.children.len(), 2); + + let mul_node = &cost_tree.children[0]; + assert!(matches!( + mul_node.expr, + CostExprNode::NativeFunction(NativeFunctions::Multiply) + )); + assert_eq!(mul_node.children.len(), 2); + + let sub_node = &cost_tree.children[1]; + assert!(matches!( + sub_node.expr, + CostExprNode::NativeFunction(NativeFunctions::Subtract) + )); + assert_eq!(sub_node.children.len(), 2); + } + + #[test] + fn test_build_cost_analysis_tree_with_comments() { + let src = ";; This is a comment\n(+ 5 ;; another comment\n7)"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); // Empty cost map for tests + let epoch = StacksEpochId::Epoch32; + let (_, cost_tree) = build_cost_analysis_tree( + expr, + &user_args, + &cost_map, + &ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); + + assert!(matches!( + cost_tree.expr, + CostExprNode::NativeFunction(NativeFunctions::Add) + )); + assert!(!is_node_branching(&cost_tree)); + assert_eq!(cost_tree.children.len(), 2); + + for child in &cost_tree.children { + assert!(matches!(child.expr, CostExprNode::AtomValue(_))); + } + } + + #[test] + fn test_function_with_multiple_arguments() { + let src = r#"(define-public (add-two (x uint) (y uint)) (+ x y))"#; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); // Empty cost map for tests + let epoch = StacksEpochId::Epoch32; + let (_, cost_tree) = build_cost_analysis_tree( + expr, + &user_args, + &cost_map, + &ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); + + assert_eq!(cost_tree.children.len(), 3); + + // First child should be UserArgument for (x uint) + let user_arg_x = &cost_tree.children[0]; + assert!(matches!(user_arg_x.expr, CostExprNode::UserArgument(_, _))); + if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_x.expr { + assert_eq!(arg_name.as_str(), "x"); + assert!(matches!(arg_type, SymbolicExpressionType::Atom(_))); + } + + // Second child should be UserArgument for (y u64) + let user_arg_y = &cost_tree.children[1]; + assert!(matches!(user_arg_y.expr, CostExprNode::UserArgument(_, _))); + if let CostExprNode::UserArgument(arg_name, arg_type) = &user_arg_y.expr { + assert_eq!(arg_name.as_str(), "y"); + assert!(matches!(arg_type, SymbolicExpressionType::Atom(_))); + } + + // Third child should be the function body (+ x y) + let body_node = &cost_tree.children[2]; + assert!(matches!( + body_node.expr, + CostExprNode::NativeFunction(NativeFunctions::Add) + )); + assert_eq!(body_node.children.len(), 2); + + // Both arguments in the body should be UserArguments + let arg_x_ref = &body_node.children[0]; + let arg_y_ref = &body_node.children[1]; + assert!(matches!(arg_x_ref.expr, CostExprNode::UserArgument(_, _))); + assert!(matches!(arg_y_ref.expr, CostExprNode::UserArgument(_, _))); + + if let CostExprNode::UserArgument(name, arg_type) = &arg_x_ref.expr { + assert_eq!(name.as_str(), "x"); + assert!(matches!(arg_type, SymbolicExpressionType::Atom(_))); + } + if let CostExprNode::UserArgument(name, arg_type) = &arg_y_ref.expr { + assert_eq!(name.as_str(), "y"); + assert!(matches!(arg_type, SymbolicExpressionType::Atom(_))); + } + } + + #[test] + fn test_static_cost_simple_addition() { + let source = "(define-public (add (a uint) (b uint)) (+ a b))"; + let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap(); + + assert_eq!(ast_cost.len(), 1); + assert!(ast_cost.contains_key("add")); + + let add_cost = ast_cost.get("add").unwrap(); + assert!(add_cost.min.runtime > 0); + assert!(add_cost.max.runtime > 0); + } + + #[test] + fn test_static_cost_multiple_functions() { + let source = r#" + (define-public (func1 (x uint)) (+ x 1)) + (define-private (func2 (y uint)) (* y 2)) + "#; + let ast_cost = static_cost_test(source, &ClarityVersion::Clarity3).unwrap(); + + assert_eq!(ast_cost.len(), 2); + + assert!(ast_cost.contains_key("func1")); + assert!(ast_cost.contains_key("func2")); + + let func1_cost = ast_cost.get("func1").unwrap(); + let func2_cost = ast_cost.get("func2").unwrap(); + assert!(func1_cost.min.runtime > 0); + assert!(func2_cost.min.runtime > 0); + } + + #[test] + fn test_extract_function_name_define_public() { + let src = "(define-public (my-func (x uint)) (ok x))"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; + let result = extract_function_name(expr); + assert_eq!(result, Some("my-func".to_string())); + } + + #[test] + fn test_extract_function_name_function_call_not_definition() { + // function call (not a definition) should return None + let src = "(my-func arg1 arg2)"; + let ast = build_test_ast(src); + let expr = &ast.expressions[0]; + let result = extract_function_name(expr); + assert_eq!(result, None); + } +} diff --git a/clarity/src/vm/costs/cost_functions.rs b/clarity/src/vm/costs/cost_functions.rs index 9621d9cc8b..c6137ac598 100644 --- a/clarity/src/vm/costs/cost_functions.rs +++ b/clarity/src/vm/costs/cost_functions.rs @@ -14,6 +14,11 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . use super::ExecutionCost; +use super::costs_1::Costs1; +use super::costs_2::Costs2; +use super::costs_3::Costs3; +use super::costs_4::Costs4; +use stacks_common::types::StacksEpochId; use crate::vm::errors::{RuntimeError, VmExecutionError}; define_named_enum!(ClarityCostFunction { @@ -342,6 +347,31 @@ pub trait CostValues { } impl ClarityCostFunction { + /// shortcut to eval() + pub fn eval_for_epoch( + &self, + n: u64, + epoch: StacksEpochId, + ) -> Result { + match epoch { + StacksEpochId::Epoch20 => self.eval::(n), + StacksEpochId::Epoch2_05 => self.eval::(n), + StacksEpochId::Epoch21 + | StacksEpochId::Epoch22 + | StacksEpochId::Epoch23 + | StacksEpochId::Epoch24 + | StacksEpochId::Epoch25 + | StacksEpochId::Epoch30 + | StacksEpochId::Epoch31 + | StacksEpochId::Epoch32 => self.eval::(n), + StacksEpochId::Epoch33 => self.eval::(n), + StacksEpochId::Epoch10 => { + // fallback to costs 1 since epoch 1 doesn't have direct cost mapping + self.eval::(n) + } + } + } + pub fn eval(&self, n: u64) -> Result { match self { ClarityCostFunction::AnalysisTypeAnnotate => C::cost_analysis_type_annotate(n), diff --git a/clarity/src/vm/costs/mod.rs b/clarity/src/vm/costs/mod.rs index 4e006c890c..dadf218462 100644 --- a/clarity/src/vm/costs/mod.rs +++ b/clarity/src/vm/costs/mod.rs @@ -42,6 +42,8 @@ use crate::vm::types::{ FunctionType, PrincipalData, QualifiedContractIdentifier, TupleData, TypeSignature, }; use crate::vm::{CallStack, ClarityName, Environment, LocalContext, SymbolicExpression, Value}; +#[cfg(feature = "developer-mode")] +pub mod analysis; pub mod constants; pub mod cost_functions; #[allow(unused_variables)] diff --git a/clarity/src/vm/functions/mod.rs b/clarity/src/vm/functions/mod.rs index bcffec2b2c..5e7681dda8 100644 --- a/clarity/src/vm/functions/mod.rs +++ b/clarity/src/vm/functions/mod.rs @@ -79,6 +79,7 @@ mod options; mod post_conditions; pub mod principals; mod sequences; +pub mod special_costs; pub mod tuples; define_versioned_named_enum_with_max!(NativeFunctions(ClarityVersion) { diff --git a/clarity/src/vm/functions/special_costs.rs b/clarity/src/vm/functions/special_costs.rs new file mode 100644 index 0000000000..4b6f95c26a --- /dev/null +++ b/clarity/src/vm/functions/special_costs.rs @@ -0,0 +1,25 @@ +use clarity_types::execution_cost::ExecutionCost; +use clarity_types::representations::SymbolicExpression; +use stacks_common::types::StacksEpochId; +use crate::vm::{costs::cost_functions::ClarityCostFunction, functions::NativeFunctions}; + +pub fn get_cost_for_special_function(native_function: NativeFunctions, args: &[SymbolicExpression], epoch: StacksEpochId) -> ExecutionCost { + match native_function { + NativeFunctions::Let => cost_binding_list_len(args, epoch), + NativeFunctions::If => cost_binding_list_len(args, epoch), + NativeFunctions::TupleCons => cost_binding_list_len(args, epoch), + _ => ExecutionCost::ZERO, + } +} + +pub fn cost_binding_list_len(args: &[SymbolicExpression], epoch: StacksEpochId) -> ExecutionCost { + let binding_len = args.get(1).and_then(|e| e.match_list()).map(|binding_list| binding_list.len() as u64).unwrap_or(0); + ClarityCostFunction::Let.eval_for_epoch(binding_len, epoch).unwrap_or_else(|_| { + ExecutionCost::ZERO + }) +} + +pub fn noop(_args: &[SymbolicExpression], _epoch: StacksEpochId) -> ExecutionCost { + ExecutionCost::ZERO +} + diff --git a/clarity/src/vm/tests/analysis.rs b/clarity/src/vm/tests/analysis.rs new file mode 100644 index 0000000000..e2d927f65c --- /dev/null +++ b/clarity/src/vm/tests/analysis.rs @@ -0,0 +1,279 @@ +// TODO: This needs work to get the dynamic vs static testing working +use std::collections::HashMap; +use std::path::Path; + +use rstest::rstest; +use stacks_common::types::StacksEpochId; + +use crate::vm::contexts::OwnedEnvironment; +use crate::vm::costs::analysis::{ + build_cost_analysis_tree, static_cost_from_ast, static_cost_tree_from_ast, UserArgumentsContext, +}; +use crate::vm::costs::ExecutionCost; +use crate::vm::types::{PrincipalData, QualifiedContractIdentifier}; +use crate::vm::{ast, ClarityVersion}; + +#[test] +fn test_build_cost_analysis_tree_function_definition() { + let src = r#"(define-public (somefunc (a uint)) + (ok (+ a 1)) +)"#; + + let contract_id = QualifiedContractIdentifier::transient(); + let ast = ast::parse( + &contract_id, + src, + ClarityVersion::Clarity3, + StacksEpochId::Epoch32, + ) + .expect("Failed to parse"); + + let expr = &ast[0]; + let user_args = UserArgumentsContext::new(); + let cost_map = HashMap::new(); + + let clarity_version = ClarityVersion::Clarity3; + let epoch = StacksEpochId::Epoch32; + let result = build_cost_analysis_tree(expr, &user_args, &cost_map, &clarity_version, epoch); + + match result { + Ok((function_name, node)) => { + assert_eq!(function_name, Some("somefunc".to_string())); + assert!(matches!( + node.expr, + crate::vm::costs::analysis::CostExprNode::UserFunction(_) + )); + } + Err(e) => { + panic!("Expected Ok result, got error: {}", e); + } + } +} + +#[test] +fn test_let_cost() { + let src = "(let ((a 1) (b 2)) (+ a b))"; + let src2 = "(let ((a 1) (b 2) (c 3)) (+ a b))"; // should compute for 3 bindings not 2 + + let contract_id = QualifiedContractIdentifier::transient(); + let epoch = StacksEpochId::Epoch32; + let ast = crate::vm::ast::build_ast( + &QualifiedContractIdentifier::transient(), + src, + &mut (), + ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); + let function_map = static_cost_from_ast(&ast, &ClarityVersion::Clarity3, epoch).unwrap(); + let (let_cost, _) = function_map.get("let").unwrap(); + let (let2_cost, _) = function_map.get("let2").unwrap(); + assert_ne!(let2_cost.min.runtime, let_cost.min.runtime); +} + + +#[test] +fn test_dependent_function_calls() { + let src = r#"(define-public (add-one (a uint)) + (begin + (print "somefunc") + (somefunc a) + ) +) +(define-private (somefunc (a uint)) + (ok (+ a 1)) +)"#; + + let contract_id = QualifiedContractIdentifier::transient(); + let epoch = StacksEpochId::Epoch32; + let ast = crate::vm::ast::build_ast( + &QualifiedContractIdentifier::transient(), + src, + &mut (), + ClarityVersion::Clarity3, + epoch, + ) + .unwrap(); + let function_map = static_cost_from_ast(&ast, &ClarityVersion::Clarity3, epoch).unwrap(); + + let (add_one_cost, _) = function_map.get("add-one").unwrap(); + let (somefunc_cost, _) = function_map.get("somefunc").unwrap(); + + println!("add_one_cost: {:?}", add_one_cost); + println!("add_one_cost: {:?}", somefunc_cost); + assert!(add_one_cost.min.runtime >= somefunc_cost.min.runtime); + assert!(add_one_cost.max.runtime >= somefunc_cost.max.runtime); +} + +#[test] +fn test_get_trait_count_direct() { + let src = r#"(define-trait trait-name ( + (send (uint principal) (response uint uint)) +)) +(define-public (something (trait ) (addresses (list 10 principal))) + (map (send u500 trait) addresses) +) +(define-private (send (trait ) (addr principal)) (trait addr)) +"#; + + let contract_id = QualifiedContractIdentifier::transient(); + let ast = crate::vm::ast::build_ast( + &contract_id, + src, + &mut (), + ClarityVersion::Clarity3, + StacksEpochId::Epoch32, + ) + .unwrap(); + + let costs = + static_cost_tree_from_ast(&ast, &ClarityVersion::Clarity3, StacksEpochId::Epoch32).unwrap(); + + // Extract trait_count from the result (all entries have the same trait_count) + let trait_count = costs + .values() + .next() + .and_then(|(_, trait_count)| trait_count.clone()); + + let expected = { + let mut map = HashMap::new(); + map.insert("something".to_string(), (0, 10)); + map.insert("send".to_string(), (1, 1)); + Some(map) + }; + + assert_eq!(trait_count, expected); +} + +#[rstest] +fn test_trait_counting() { + // map, fold, filter over traits counting + let src = r#"(define-trait trait-name ( + (send (uint principal) (response uint uint)) +)) +(define-public (something (trait ) (addresses (list 10 principal))) + (map (send u500 trait) addresses) +) +(define-private (send (trait ) (addr principal)) (trait addr)) +"#; + let contract_id = QualifiedContractIdentifier::local("trait-counting").unwrap(); + let epoch = StacksEpochId::Epoch32; + let ast = + crate::vm::ast::build_ast(&contract_id, src, &mut (), ClarityVersion::Clarity3, epoch) + .unwrap(); + let static_cost = static_cost_from_ast(&ast, &ClarityVersion::Clarity3, epoch) + .unwrap() + .clone(); + let send_trait_count_map = static_cost.get("send").unwrap().1.clone().unwrap(); + let send_trait_count = send_trait_count_map.get("send").unwrap(); + assert_eq!(send_trait_count.0, 1); + assert_eq!(send_trait_count.1, 1); + + let something_trait_count_map = static_cost.get("something").unwrap().1.clone().unwrap(); + let something_trait_count = something_trait_count_map.get("something").unwrap(); + assert_eq!(something_trait_count.0, 0); + assert_eq!(something_trait_count.1, 10); +} + +/// Helper function to execute a contract function and return the execution cost +fn execute_contract_function_and_get_cost( + env: &mut OwnedEnvironment, + contract_id: &QualifiedContractIdentifier, + function_name: &str, + args: &[u64], + version: ClarityVersion, +) -> ExecutionCost { + let initial_cost = env.get_cost_total(); + + let sender = PrincipalData::parse_qualified_contract_principal( + "ST1PQHQKV0RJXZFY1DGX8MNSNYVE3VGZJSRTPGZGM.sender", + ) + .unwrap(); + + let arg_str = args + .iter() + .map(|a| format!("u{}", a)) + .collect::>() + .join(" "); + let function_call = format!("({} {})", function_name, arg_str); + + let ast = crate::vm::ast::parse( + &QualifiedContractIdentifier::transient(), + &function_call, + version, + StacksEpochId::Epoch21, + ) + .expect("Failed to parse function call"); + + if !ast.is_empty() { + let _result = env.execute_transaction( + sender, + None, + contract_id.clone(), + &function_call, + &ast[0..1], + ); + } + + let final_cost = env.get_cost_total(); + + ExecutionCost { + write_length: final_cost.write_length - initial_cost.write_length, + write_count: final_cost.write_count - initial_cost.write_count, + read_length: final_cost.read_length - initial_cost.read_length, + read_count: final_cost.read_count - initial_cost.read_count, + runtime: final_cost.runtime - initial_cost.runtime, + } +} + +#[test] +fn test_pox_4_costs() { + let workspace_root = Path::new(env!("CARGO_MANIFEST_DIR")).parent().unwrap(); + let pox_4_path = workspace_root + .join("contrib") + .join("boot-contracts-unit-tests") + .join("boot_contracts") + .join("pox-4.clar"); + let contract_source = std::fs::read_to_string(&pox_4_path) + .unwrap_or_else(|e| panic!("Failed to read pox-4.clar file at {:?}: {}", pox_4_path, e)); + + let contract_id = QualifiedContractIdentifier::transient(); + let epoch = StacksEpochId::Epoch32; + let clarity_version = ClarityVersion::Clarity3; + + let ast = crate::vm::ast::build_ast( + &contract_id, + &contract_source, + &mut (), + clarity_version, + epoch, + ) + .expect("Failed to build AST from pox-4.clar"); + + let cost_map = static_cost_from_ast(&ast, &clarity_version, epoch) + .expect("Failed to get static cost analysis for pox-4.clar"); + + // Check some functions in the cost map + let key_functions = vec![ + "stack-stx", + "delegate-stx", + "get-stacker-info", + "current-pox-reward-cycle", + "stack-aggregation-commit", + "stack-increase", + "stack-extend", + ]; + + for function_name in key_functions { + assert!( + cost_map.contains_key(function_name), + "Expected function '{}' to be present in cost map", + function_name + ); + + let (_cost, _trait_count) = cost_map.get(function_name).expect(&format!( + "Failed to get cost for function '{}'", + function_name + )); + } +} diff --git a/clarity/src/vm/tests/mod.rs b/clarity/src/vm/tests/mod.rs index 3d4408abec..bd2353273e 100644 --- a/clarity/src/vm/tests/mod.rs +++ b/clarity/src/vm/tests/mod.rs @@ -24,6 +24,8 @@ use crate::vm::contexts::OwnedEnvironment; pub use crate::vm::database::BurnStateDB; use crate::vm::database::MemoryBackingStore; +#[cfg(all(test, feature = "developer-mode"))] +mod analysis; mod assets; mod contracts; #[cfg(test)]