diff --git a/README.md b/README.md index 54f979c..c5b4ad8 100644 --- a/README.md +++ b/README.md @@ -174,10 +174,7 @@ GOAP needs to be benchmarked and monitored regularly because of exponential risk Though if well scoped you can manage hundred of Actions for 200µs per Agent. ## What's next? -- The simulateActionState functions, to check the effect of an action on a node, takes up to 40% of CPU and 40% of memory. -We need to refactorize this part, or find another logical path. -- Heuristic calculation in A* is done poorly, we need a better algorithm to improve performances. -- Benchmark a backward implementation like D*, to improve performances. +- Benchmark a backward implementation like D*. It might improve performances. ## Sources - https://web.archive.org/web/20230912145018/http://alumni.media.mit.edu/~jorkin/goap.html diff --git a/action.go b/action.go index 7918495..bae50cd 100644 --- a/action.go +++ b/action.go @@ -15,6 +15,11 @@ const ( DIVIDE ) +// Action represents a single action that an agent can perform to modify the world state. +// +// An action has preconditions (conditions) that must be met before it can be executed, +// and postconditions (effects) that describe how it modifies the world state. +// Actions have a cost that is used by the A* algorithm to find the optimal plan. type Action struct { name string cost float32 @@ -22,8 +27,18 @@ type Action struct { conditions Conditions effects Effects } + +// Actions is a collection of Action pointers. type Actions []*Action +// AddAction creates a new Action and appends it to the Actions collection. +// +// Parameters: +// - name: unique identifier for the action +// - cost: numeric cost used by pathfinding (lower costs are preferred) +// - repeatable: if false, the action can only be used once per plan +// - conditions: preconditions that must be satisfied before the action can be executed +// - effects: postconditions that describe how the action modifies the world state func (actions *Actions) AddAction(name string, cost float32, repeatable bool, conditions Conditions, effects Effects) { action := Action{ name: name, @@ -36,36 +51,50 @@ func (actions *Actions) AddAction(name string, cost float32, repeatable bool, co *actions = append(*actions, &action) } +// GetName returns the action's name identifier. func (action *Action) GetName() string { return action.name } +// GetEffects returns the action's effects (postconditions). func (action *Action) GetEffects() Effects { return action.effects } +// EffectInterface defines the interface that all effect types must implement. +// Effects describe how an action modifies the world state. type EffectInterface interface { - check(states states) bool - apply(data statesData) error + GetKey() StateKey + check(w world) bool + apply(w *world) error } +// Effect represents a numeric state modification for types constrained by Numeric. +// +// It supports arithmetic operators (SET, ADD, SUBTRACT, MULTIPLY, DIVIDE) to modify +// numeric state values. The effect is applied when an action is executed during planning. type Effect[T Numeric] struct { - Key StateKey - Operator arithmetic - Value T + Key StateKey // State key to modify + Operator arithmetic // Arithmetic operation to perform + Value T // Value to use in the operation +} + +// GetKey returns the state key that this effect modifies. +func (effect Effect[T]) GetKey() StateKey { + return effect.Key } -func (effect Effect[T]) check(states states) bool { - // Other operators than '=' mean the effect will have an impact of the states +func (effect Effect[T]) check(w world) bool { + // Other operators than '=' mean the effect will have an impact of the world if effect.Operator != SET { return false } - k := states.data.GetIndex(effect.Key) + k := w.states.GetIndex(effect.Key) if k < 0 { return false } - s := states.data[k] + s := w.states[k] if _, ok := s.(State[T]); !ok { return false @@ -74,23 +103,23 @@ func (effect Effect[T]) check(states states) bool { return s.(State[T]).Value == effect.Value } -func (effect Effect[T]) apply(data statesData) error { - k := data.GetIndex(effect.Key) +func (effect Effect[T]) apply(w *world) error { + k := w.states.GetIndex(effect.Key) if k < 0 { if slices.Contains([]arithmetic{SET, ADD}, effect.Operator) { - data = append(data, State[T]{Value: effect.Value}) + w.states = append(w.states, State[T]{Key: effect.Key, Value: effect.Value}) return nil } else if slices.Contains([]arithmetic{SUBSTRACT}, effect.Operator) { - data = append(data, State[T]{Value: -effect.Value}) + w.states = append(w.states, State[T]{Key: effect.Key, Value: -effect.Value}) return nil } - return fmt.Errorf("data does not exist") + return fmt.Errorf("w does not exist") } - if _, ok := data[k].(State[T]); !ok { + if _, ok := w.states[k].(State[T]); !ok { return fmt.Errorf("type does not match") } - state := data[k].(State[T]) + state := w.states[k].(State[T]) switch effect.Operator { case SET: state.Value = effect.Value @@ -104,112 +133,136 @@ func (effect Effect[T]) apply(data statesData) error { state.Value /= effect.Value } - data[k] = state + state.Store(w) return nil } +// EffectBool represents a boolean state modification. +// +// Only the SET operator is allowed for boolean effects. Attempting to use other +// operators (ADD, SUBTRACT, etc.) will result in an error when the effect is applied. type EffectBool struct { - Key StateKey - Value bool - Operator arithmetic + Key StateKey // State key to modify + Value bool // Boolean value to set + Operator arithmetic // Must be SET } -func (effectBool EffectBool) check(states states) bool { +// GetKey returns the state key that this effect modifies. +func (effectBool EffectBool) GetKey() StateKey { + return effectBool.Key +} + +func (effectBool EffectBool) check(w world) bool { // Other operators than '=' is not allowed if effectBool.Operator != SET { return false } - k := states.data.GetIndex(effectBool.Key) + k := w.states.GetIndex(effectBool.Key) if k < 0 { return false } - if _, ok := states.data[k].(State[bool]); !ok { + if _, ok := w.states[k].(State[bool]); !ok { return false } - s := states.data[k].(State[bool]) + s := w.states[k].(State[bool]) return s.Value == effectBool.Value } -func (effectBool EffectBool) apply(data statesData) error { +func (effectBool EffectBool) apply(w *world) error { if effectBool.Operator != SET { return fmt.Errorf("operation %v not allowed on bool type", effectBool.Operator) } - k := data.GetIndex(effectBool.Key) + k := w.states.GetIndex(effectBool.Key) if k < 0 { - data = append(data, State[bool]{Value: effectBool.Value}) + w.states = append(w.states, State[bool]{Key: effectBool.Key, Value: effectBool.Value}) return nil } - if _, ok := data[k].(State[bool]); !ok { + if _, ok := w.states[k].(State[bool]); !ok { return fmt.Errorf("type does not match") } - state := data[k].(State[bool]) + state := w.states[k].(State[bool]) state.Value = effectBool.Value - data[k] = state + + state.Store(w) return nil } +// EffectString represents a string state modification. +// +// Supports SET (replace string) and ADD (concatenate) operators. Other operators +// (SUBTRACT, MULTIPLY, DIVIDE) are not allowed and will result in an error. type EffectString struct { - Key StateKey - Value string - Operator arithmetic + Key StateKey // State key to modify + Value string // String value to use + Operator arithmetic // Allowed: SET, ADD } -func (effectString EffectString) check(states states) bool { - k := states.data.GetIndex(effectString.Key) +// GetKey returns the state key that this effect modifies. +func (effectString EffectString) GetKey() StateKey { + return effectString.Key +} + +func (effectString EffectString) check(w world) bool { + k := w.states.GetIndex(effectString.Key) if k < 0 { return false } - if _, ok := states.data[k].(State[string]); !ok { + if _, ok := w.states[k].(State[string]); !ok { return false } - s := states.data[k].(State[string]) + s := w.states[k].(State[string]) return s.Value == effectString.Value } -func (effectString EffectString) apply(data statesData) error { +func (effectString EffectString) apply(w *world) error { if !slices.Contains([]arithmetic{SET, ADD}, effectString.Operator) { return fmt.Errorf("arithmetic operation %v not allowed on string type", effectString.Operator) } - k := data.GetIndex(effectString.Key) + k := w.states.GetIndex(effectString.Key) if k < 0 { - data = append(data, State[string]{Value: effectString.Value}) + w.states = append(w.states, State[string]{Key: effectString.Key, Value: effectString.Value}) return nil } - if _, ok := data[k].(State[string]); !ok { + if _, ok := w.states[k].(State[string]); !ok { return fmt.Errorf("type does not match") } - state := data[k].(State[string]) + state := w.states[k].(State[string]) switch effectString.Operator { case SET: state.Value = effectString.Value case ADD: state.Value = fmt.Sprint(state.Value, effectString.Value) } - data[k] = state + + state.Store(w) return nil } +// EffectFn is a function type for custom procedural effects that directly modify the agent. +// This allows for effects that cannot be expressed through simple state modifications. type EffectFn func(agent *Agent) +// Effects is a collection of EffectInterface implementations that describe how +// an action modifies the world state. type Effects []EffectInterface -// If all the effects already exist in states, +// If all the effects already exist in world, // it is probably not a good path -func (effects Effects) satisfyStates(states states) bool { +func (effects Effects) satisfyStates(w world) bool { for _, effect := range effects { - if !effect.check(states) { + if !effect.check(w) { return false } } @@ -217,16 +270,14 @@ func (effects Effects) satisfyStates(states states) bool { return true } -func (effects Effects) apply(states states) (statesData, error) { - data := slices.Clone(states.data) - +func (effects Effects) apply(w *world) error { for _, effect := range effects { - err := effect.apply(data) + err := effect.apply(w) if err != nil { - return nil, err + return err } } - return data, nil + return nil } diff --git a/action_test.go b/action_test.go new file mode 100644 index 0000000..9f8485f --- /dev/null +++ b/action_test.go @@ -0,0 +1,167 @@ +package goapai + +import "testing" + +func TestActions_AddAction(t *testing.T) { + tests := []struct { + name string + actionsToAdd []struct{ name string; cost float32; repeatable bool } + wantCount int + checkFirst bool + wantFirstName string + wantFirstCost float32 + wantRepeatble bool + }{ + { + name: "single action", + actionsToAdd: []struct{ name string; cost float32; repeatable bool }{ + {name: "test", cost: 1.5, repeatable: false}, + }, + wantCount: 1, + checkFirst: true, + wantFirstName: "test", + wantFirstCost: 1.5, + wantRepeatble: false, + }, + { + name: "multiple actions", + actionsToAdd: []struct{ name string; cost float32; repeatable bool }{ + {name: "action1", cost: 1.0, repeatable: true}, + {name: "action2", cost: 2.0, repeatable: false}, + }, + wantCount: 2, + checkFirst: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actions := Actions{} + + for _, a := range tt.actionsToAdd { + actions.AddAction(a.name, a.cost, a.repeatable, Conditions{}, Effects{}) + } + + if len(actions) != tt.wantCount { + t.Errorf("Expected %d actions, got %d", tt.wantCount, len(actions)) + } + + if tt.checkFirst { + action := actions[0] + if action.name != tt.wantFirstName { + t.Errorf("Expected name '%s', got '%s'", tt.wantFirstName, action.name) + } + if action.cost != tt.wantFirstCost { + t.Errorf("Expected cost %f, got %f", tt.wantFirstCost, action.cost) + } + if action.repeatable != tt.wantRepeatble { + t.Errorf("Expected repeatable to be %v", tt.wantRepeatble) + } + } + }) + } +} + +func TestAction_GetName(t *testing.T) { + tests := []struct { + name string + actionName string + want string + }{ + { + name: "basic name", + actionName: "my_action", + want: "my_action", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actions := Actions{} + actions.AddAction(tt.actionName, 1.0, false, Conditions{}, Effects{}) + + got := actions[0].GetName() + if got != tt.want { + t.Errorf("Expected '%s', got '%s'", tt.want, got) + } + }) + } +} + +func TestAction_GetEffects(t *testing.T) { + tests := []struct { + name string + effects Effects + wantCount int + }{ + { + name: "single effect", + effects: Effects{ + Effect[int]{Key: 1, Value: 10, Operator: SET}, + }, + wantCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actions := Actions{} + actions.AddAction("test", 1.0, false, Conditions{}, tt.effects) + + retrieved := actions[0].GetEffects() + if len(retrieved) != tt.wantCount { + t.Errorf("Expected %d effect(s), got %d", tt.wantCount, len(retrieved)) + } + }) + } +} + +func TestEffect_Check(t *testing.T) { + tests := []struct { + name string + stateValue int + effectKey StateKey + effectVal int + operator arithmetic + want bool + }{ + { + name: "match with SET operator", + stateValue: 100, + effectKey: 1, + effectVal: 100, + operator: SET, + want: true, + }, + { + name: "no match with SET operator", + stateValue: 100, + effectKey: 1, + effectVal: 200, + operator: SET, + want: false, + }, + { + name: "non-SET operator always false", + stateValue: 100, + effectKey: 1, + effectVal: 100, + operator: ADD, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, tt.effectKey, tt.stateValue) + + effect := Effect[int]{Key: tt.effectKey, Value: tt.effectVal, Operator: tt.operator} + got := effect.check(agent.w) + + if got != tt.want { + t.Errorf("Expected %v, got %v", tt.want, got) + } + }) + } +} diff --git a/agent.go b/agent.go index fba27a1..9094ed1 100644 --- a/agent.go +++ b/agent.go @@ -1,8 +1,65 @@ +// Package goapai implements a microlithic Goal-Oriented Action Planning (GOAP) system for AI agents. +// +// GOAP is a planning technique where agents dynamically generate action sequences to achieve +// goals based on the current world state. This implementation uses A* pathfinding to find +// optimal plans. +// +// # Key Concepts +// +// Agent: The central entity that maintains world state, goals, actions, and sensors. +// +// State: Key-value pairs representing the world state. States support numeric types, +// booleans, and strings, identified by compact StateKey (uint16) values. +// +// Action: Operations that modify the world state. Each action has preconditions (Conditions) +// and postconditions (Effects), plus a cost for pathfinding optimization. +// +// Goal: Desired world states with priority functions. The planner always works on the +// highest priority goal. +// +// Sensor: External data sources used in goal prioritization and procedural conditions +// without duplicating data during planning. +// +// # Example Usage +// +// // Create an agent +// agent := goapai.CreateAgent( +// goapai.Goals{ +// "survive": { +// Conditions: goapai.Conditions{ +// &goapai.Condition[int]{Key: 1, Value: 50, Operator: goapai.UPPER_OR_EQUAL}, +// }, +// PriorityFn: func(sensors goapai.Sensors) float32 { +// return 1.0 +// }, +// }, +// }, +// goapai.Actions{}, +// ) +// +// // Set initial state +// goapai.SetState[int](&agent, 1, 20) // health = 20 +// +// // Add actions +// agent.actions.AddAction("heal", 1.0, false, +// goapai.Conditions{}, +// goapai.Effects{ +// goapai.Effect[int]{Key: 1, Operator: goapai.ADD, Value: 50}, +// }, +// ) +// +// // Generate plan +// goalName, plan := goapai.GetPlan(agent, 10) package goapai +// Agent represents an AI agent that uses GOAP (Goal-Oriented Action Planning) to make decisions. +// +// An agent maintains a world state, a set of goals with priorities, available actions, +// and sensors for external data. The agent uses A* pathfinding to generate optimal plans +// that achieve its highest priority goal. type Agent struct { actions Actions - states states + w world sensors Sensors goals Goals } @@ -11,9 +68,17 @@ type goalInterface struct { Conditions []ConditionInterface PriorityFn GoalPriorityFn } + +// GoalName is a unique identifier for a goal. type GoalName string + +// Goals is a map of goal names to their definitions (conditions and priority function). type Goals map[GoalName]goalInterface +// CreateAgent creates and initializes a new Agent with the given goals and actions. +// +// The agent is initialized with an empty world state and no sensors. Use SetState +// to initialize the world state and SetSensor to add sensor data. func CreateAgent(goals Goals, actions Actions) Agent { agent := Agent{ actions: actions, @@ -21,22 +86,43 @@ func CreateAgent(goals Goals, actions Actions) Agent { sensors: Sensors{}, } - states := states{ - Agent: &agent, - data: statesData{}, + states := world{ + Agent: &agent, + states: states{}, } - agent.states = states + agent.w = states return agent } +// SetState adds or updates a state value in the agent's world state. +// +// State values can be numeric types (int, int8, uint8, uint64, float64), bool, or string. +// Each state is identified by a unique StateKey. Multiple calls with the same key will +// create duplicate states; this is generally not recommended. +// +// Example: +// +// SetState[int](&agent, 1, 100) // Set state key 1 to integer 100 +// SetState[bool](&agent, 2, true) // Set state key 2 to boolean true +// SetState[string](&agent, 3, "foo") // Set state key 3 to string "foo" func SetState[T Numeric | bool | string](agent *Agent, key StateKey, value T) { - agent.states.data = append(agent.states.data, State[T]{ + agent.w.states = append(agent.w.states, State[T]{ Key: key, Value: value, }) } +// SetSensor adds or updates a sensor value for the agent. +// +// Sensors provide external data that can be used in goal priority functions and +// procedural conditions (ConditionFn) without duplicating data during planning. +// Unlike world state, sensors are not modified during plan simulation. +// +// Example: +// +// SetSensor(&agent, "health", 100) +// SetSensor(&agent, "enemy_visible", true) func SetSensor[T Sensor](agent *Agent, name string, value T) { agent.sensors[name] = value } diff --git a/agent_test.go b/agent_test.go new file mode 100644 index 0000000..30810de --- /dev/null +++ b/agent_test.go @@ -0,0 +1,190 @@ +package goapai + +import "testing" + +func TestCreateAgent(t *testing.T) { + tests := []struct { + name string + goals Goals + actions Actions + wantActionCnt int + wantGoalCnt int + }{ + { + name: "agent with goal and action", + goals: Goals{ + "test_goal": { + Conditions: Conditions{ + &ConditionBool{Key: 1, Value: true}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + }, + actions: func() Actions { + actions := Actions{} + actions.AddAction("test_action", 1.0, false, Conditions{}, Effects{}) + return actions + }(), + wantActionCnt: 1, + wantGoalCnt: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(tt.goals, tt.actions) + + if len(agent.actions) != tt.wantActionCnt { + t.Errorf("Expected %d action(s), got %d", tt.wantActionCnt, len(agent.actions)) + } + + if len(agent.goals) != tt.wantGoalCnt { + t.Errorf("Expected %d goal(s), got %d", tt.wantGoalCnt, len(agent.goals)) + } + + if agent.sensors == nil { + t.Error("Expected sensors to be initialized") + } + + if agent.w.Agent == nil { + t.Error("Expected world.Agent to be non-nil") + } + }) + } +} + +func TestSetState(t *testing.T) { + tests := []struct { + name string + setupFunc func(*Agent) + checkFunc func(*testing.T, Agent) + }{ + { + name: "int state", + setupFunc: func(a *Agent) { + SetState[int](a, 1, 42) + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.w.states) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.w.states)) + } + state := a.w.states[0].(State[int]) + if state.Key != 1 || state.Value != 42 { + t.Errorf("Expected key=1, value=42, got key=%d, value=%d", state.Key, state.Value) + } + }, + }, + { + name: "bool state", + setupFunc: func(a *Agent) { + SetState[bool](a, 2, true) + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.w.states) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.w.states)) + } + state := a.w.states[0].(State[bool]) + if state.Key != 2 || state.Value != true { + t.Errorf("Expected key=2, value=true, got key=%d, value=%v", state.Key, state.Value) + } + }, + }, + { + name: "string state", + setupFunc: func(a *Agent) { + SetState[string](a, 3, "test") + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.w.states) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.w.states)) + } + state := a.w.states[0].(State[string]) + if state.Key != 3 || state.Value != "test" { + t.Errorf("Expected key=3, value='test', got key=%d, value='%s'", state.Key, state.Value) + } + }, + }, + { + name: "multiple world", + setupFunc: func(a *Agent) { + SetState[int](a, 1, 100) + SetState[bool](a, 2, false) + SetState[string](a, 3, "hello") + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.w.states) != 3 { + t.Errorf("Expected 3 world, got %d", len(a.w.states)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + tt.setupFunc(&agent) + tt.checkFunc(t, agent) + }) + } +} + +func TestSetSensor(t *testing.T) { + type TestEntity struct { + health int + } + + tests := []struct { + name string + setupFunc func(*Agent) + checkFunc func(*testing.T, Agent) + }{ + { + name: "single sensor", + setupFunc: func(a *Agent) { + entity := &TestEntity{health: 100} + SetSensor(a, "entity", entity) + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.sensors) != 1 { + t.Errorf("Expected 1 sensor, got %d", len(a.sensors)) + } + retrieved := a.sensors.GetSensor("entity").(*TestEntity) + if retrieved.health != 100 { + t.Errorf("Expected health 100, got %d", retrieved.health) + } + }, + }, + { + name: "multiple sensors", + setupFunc: func(a *Agent) { + SetSensor(a, "sensor1", "value1") + SetSensor(a, "sensor2", 42) + SetSensor(a, "sensor3", true) + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.sensors) != 3 { + t.Errorf("Expected 3 sensors, got %d", len(a.sensors)) + } + if a.sensors.GetSensor("sensor1").(string) != "value1" { + t.Error("sensor1 value mismatch") + } + if a.sensors.GetSensor("sensor2").(int) != 42 { + t.Error("sensor2 value mismatch") + } + if a.sensors.GetSensor("sensor3").(bool) != true { + t.Error("sensor3 value mismatch") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + tt.setupFunc(&agent) + tt.checkFunc(t, agent) + }) + } +} diff --git a/astar.go b/astar.go index 1d358d2..310df65 100644 --- a/astar.go +++ b/astar.go @@ -1,63 +1,53 @@ package goapai import ( + "container/heap" "slices" - "sync" ) type node struct { *Action - states states + world world parentNode *node cost float32 totalCost float32 heuristic float32 depth uint16 + heapIndex int // Index in the heap, needed for heap.Fix + closed bool // true = closed node, false = open node } -var nodesPool = sync.Pool{ - New: func() any { - return make([]*node, 0, 32) - }, -} - -func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan { +func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { availableActions := getImpactingActions(from, actions) - openNodes := nodesPool.Get().([]*node) - closedNodes := nodesPool.Get().([]*node) - - defer func() { - nodesPool.Put(openNodes[:0]) - nodesPool.Put(closedNodes[:0]) - }() - data := slices.Clone(from.data) - data.sort() - openNodes = append(openNodes, &node{ + startNode := &node{ Action: &Action{}, - states: states{ - Agent: from.Agent, - data: data, - hash: data.hashStates(), + world: world{ + Agent: from.Agent, + states: slices.Clone(from.states), + hash: from.hash, }, parentNode: nil, - cost: 0, - totalCost: 0, - heuristic: 0, - depth: 0, - }) - - for openNodeKey := 0; openNodeKey != -1; openNodeKey = getLessCostlyNodeKey(openNodes) { - parentNode := openNodes[openNodeKey] + heapIndex: -1, + closed: false, + } + + nodesHeap := nodeHeap{} + heap.Init(&nodesHeap) + heap.Push(&nodesHeap, startNode) + + for nodesHeap.Len() > 0 { + parentNode := heap.Pop(&nodesHeap).(*node) + if parentNode.depth > uint16(maxDepth) { - openNodes = append(openNodes[:openNodeKey], openNodes[openNodeKey+1:]...) - closedNodes = append(closedNodes, parentNode) + parentNode.closed = true + heap.Fix(&nodesHeap, parentNode.heapIndex) continue } // Simulate world state, and check if we are at current state - if countMissingGoal(goal, parentNode.states) == 0 { + if countMissingGoal(goal, parentNode.world) == 0 { return buildPlanFromNode(parentNode) } @@ -66,64 +56,68 @@ func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan continue } - if !action.conditions.Check(parentNode.states) { + if !action.conditions.Check(parentNode.world) { continue } - simulatedStates, ok := simulateActionState(action, parentNode.states) + simulatedStates, ok := simulateActionState(action, parentNode.world) if !ok { continue } - if nodeKey, found := fetchNode(openNodes, simulatedStates); found { - node := openNodes[nodeKey] - if (parentNode.cost + action.cost) < node.cost { - node.Action = action - node.states = simulatedStates - node.parentNode = parentNode - node.cost = parentNode.cost + action.cost - node.totalCost = parentNode.cost + action.cost + node.heuristic - node.depth = parentNode.depth + 1 - - openNodes[nodeKey] = node + currentNode, found := fetchNodeInHeap(nodesHeap, simulatedStates) + // Check if node exists in open nodes (closed=false) + if found && !currentNode.closed { + if (parentNode.cost + action.cost) < currentNode.cost { + currentNode.Action = action + currentNode.world = simulatedStates + currentNode.parentNode = parentNode + currentNode.cost = parentNode.cost + action.cost + currentNode.totalCost = parentNode.cost + action.cost + currentNode.heuristic + currentNode.depth = parentNode.depth + 1 + + // Fix heap position after cost update + heap.Fix(&nodesHeap, currentNode.heapIndex) } - } else if nodeKey, found := fetchNode(closedNodes, simulatedStates); found { - node := closedNodes[nodeKey] - if (parentNode.cost + action.cost) < node.cost { - node.Action = action - node.states = simulatedStates - node.parentNode = parentNode - node.cost = parentNode.cost + action.cost - node.totalCost = parentNode.cost + action.cost + node.heuristic - node.depth = parentNode.depth + 1 - - openNodes[openNodeKey] = node - closedNodes = append(closedNodes[:nodeKey], closedNodes[nodeKey+1:]...) + } else if found && currentNode.closed { + // Node was closed, reopen it with better cost + if (parentNode.cost + action.cost) < currentNode.cost { + currentNode.Action = action + currentNode.world = simulatedStates + currentNode.parentNode = parentNode + currentNode.cost = parentNode.cost + action.cost + currentNode.totalCost = parentNode.cost + action.cost + currentNode.heuristic + currentNode.depth = parentNode.depth + 1 + currentNode.closed = false // Reopen + + // Fix heap position + heap.Fix(&nodesHeap, currentNode.heapIndex) } } else { + // New node heuristic := computeHeuristic(from, goal, simulatedStates) - openNodes = append(openNodes, &node{ + newNode := &node{ Action: action, - states: simulatedStates, + world: simulatedStates, parentNode: parentNode, cost: parentNode.cost + action.cost, totalCost: parentNode.cost + action.cost + heuristic, heuristic: heuristic, depth: parentNode.depth + 1, - }) + heapIndex: -1, + closed: false, + } + heap.Push(&nodesHeap, newNode) } } - - openNodes = append(openNodes[:openNodeKey], openNodes[openNodeKey+1:]...) - closedNodes = append(closedNodes, parentNode) } return Plan{} } -// All the actions similar to initial states are useless: +// All the actions similar to initial world are useless: // we consider they are not going towards the goal and are dead end -func getImpactingActions(from states, actions Actions) Actions { +func getImpactingActions(from world, actions Actions) Actions { var availableActions Actions for _, action := range actions { @@ -135,26 +129,13 @@ func getImpactingActions(from states, actions Actions) Actions { return availableActions } -func getLessCostlyNodeKey(openNodes []*node) int { - lowestKey := -1 - - for key, node := range openNodes { - if lowestKey < 0 || node.totalCost < openNodes[lowestKey].totalCost { - lowestKey = key +func fetchNodeInHeap(heap nodeHeap, w world) (*node, bool) { + for _, n := range heap { + if n.world.Check(w) { + return n, true } } - - return lowestKey -} - -func fetchNode(nodes []*node, states states) (int, bool) { - for k, n := range nodes { - if n.states.Check(states) { - return k, true - } - } - - return 0, false + return nil, false } func buildPlanFromNode(node *node) Plan { @@ -170,24 +151,20 @@ func buildPlanFromNode(node *node) Plan { return plan } -func simulateActionState(action *Action, nodeStates states) (states, bool) { +func simulateActionState(action *Action, w world) (world, bool) { /* If action effects implies no changes to current worldState, then avoid generating huge chunks of memory */ - if action.effects.satisfyStates(nodeStates) { - return states{}, false + if action.effects.satisfyStates(w) { + return world{}, false } - data, err := action.effects.apply(nodeStates) + w.states = slices.Clone(w.states) + err := action.effects.apply(&w) if err != nil { - return states{}, false + return world{}, false } - data.sort() - return states{ - Agent: nodeStates.Agent, - data: data, - hash: data.hashStates(), - }, true + return w, true } func allowedRepetition(action *Action, parentNode *node) bool { @@ -207,10 +184,10 @@ func allowedRepetition(action *Action, parentNode *node) bool { return true } -func countMissingGoal(goal goalInterface, states states) int { +func countMissingGoal(goal goalInterface, w world) int { count := 0 for _, condition := range goal.Conditions { - if !condition.Check(states) { + if !condition.Check(w) { count++ } } @@ -219,15 +196,31 @@ func countMissingGoal(goal goalInterface, states states) int { } /* -A very simple (empiristic) model for h using: - - how much required states are met - -We try to be conservative and reduce the number of steps +Improved heuristic using numeric distance calculation: + - For each goal condition, calculate the numeric distance between current value and target + - Sum all distances to get total heuristic + - This provides much better guidance than simple binary satisfied/unsatisfied check */ -func computeHeuristic(fromStates states, goal goalInterface, states states) float32 { - missingGoalsCount := float32(countMissingGoal(goal, states)) +func computeHeuristic(from world, goal goalInterface, w world) float32 { + var totalDistance float32 - h := missingGoalsCount + for _, condition := range goal.Conditions { + key := condition.GetKey() + stateIndex := w.states.GetIndex(key) + + if stateIndex >= 0 { + // State exists, calculate actual distance + state := w.states[stateIndex] + distance := state.Distance(condition) + totalDistance += distance + } else { + // State doesn't exist, use pessimistic estimate + // If the condition is not satisfied and state doesn't exist, assume distance of 1 + if !condition.Check(w) { + totalDistance += 1.0 + } + } + } - return h + return totalDistance } diff --git a/astar_test.go b/astar_test.go new file mode 100644 index 0000000..a42888f --- /dev/null +++ b/astar_test.go @@ -0,0 +1,468 @@ +package goapai + +import ( + "slices" + "testing" +) + +// Test getImpactingActions +func TestGetImpactingActions(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[bool](&agent, 2, false) + + actions := Actions{} + // Action that changes state - should be included + actions.AddAction("change", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 200, Operator: SET}, + }) + // Action with effects matching current state - should be excluded + actions.AddAction("no_change", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + EffectBool{Key: 2, Value: false, Operator: SET}, + }) + + impacting := getImpactingActions(agent.w, actions) + + if len(impacting) != 1 { + t.Errorf("Expected 1 impacting action, got %d", len(impacting)) + } + + if impacting[0].name != "change" { + t.Errorf("Expected 'change' action, got '%s'", impacting[0].name) + } +} + +func TestGetImpactingActions_AllImpacting(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + actions := Actions{} + actions.AddAction("action1", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 200, Operator: SET}, + }) + actions.AddAction("action2", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 50, Operator: SET}, + }) + + impacting := getImpactingActions(agent.w, actions) + + if len(impacting) != 2 { + t.Errorf("Expected 2 impacting actions, got %d", len(impacting)) + } +} + +// Test buildPlanFromNode +func TestBuildPlanFromNode(t *testing.T) { + action1 := &Action{name: "action1", cost: 1.0} + action2 := &Action{name: "action2", cost: 1.0} + action3 := &Action{name: "action3", cost: 1.0} + + // Build a chain: nil -> node1 -> node2 -> node3 + node1 := &node{Action: action1, parentNode: nil, depth: 1} + node2 := &node{Action: action2, parentNode: node1, depth: 2} + node3 := &node{Action: action3, parentNode: node2, depth: 3} + + plan := buildPlanFromNode(node3) + + if len(plan) != 3 { + t.Errorf("Expected plan length 3, got %d", len(plan)) + } + + if plan[0].name != "action1" { + t.Errorf("Expected first action to be 'action1', got '%s'", plan[0].name) + } + if plan[1].name != "action2" { + t.Errorf("Expected second action to be 'action2', got '%s'", plan[1].name) + } + if plan[2].name != "action3" { + t.Errorf("Expected third action to be 'action3', got '%s'", plan[2].name) + } +} + +func TestBuildPlanFromNode_SingleNode(t *testing.T) { + action1 := &Action{name: "action1", cost: 1.0} + node1 := &node{Action: action1, parentNode: nil, depth: 1} + + plan := buildPlanFromNode(node1) + + if len(plan) != 1 { + t.Errorf("Expected plan length 1, got %d", len(plan)) + } +} + +// Test simulateActionState +func TestSimulateActionState(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + action := &Action{ + effects: Effects{ + Effect[int]{Key: 1, Value: 50, Operator: ADD}, + }, + } + + newStates, ok := simulateActionState(action, agent.w) + if !ok { + t.Error("Expected simulation to succeed") + } + + idx := newStates.states.GetIndex(1) + if idx < 0 { + t.Error("Expected to find key 1 in new world") + } + + if newStates.states[idx].(State[int]).Value != 150 { + t.Errorf("Expected value 150, got %d", newStates.states[idx].(State[int]).Value) + } +} + +func TestSimulateActionState_NoChange(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + action := &Action{ + effects: Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + }, + } + + _, ok := simulateActionState(action, agent.w) + if ok { + t.Error("Expected simulation to fail when effects match current state") + } +} + +// Test allowedRepetition +func TestAllowedRepetition_Repeatable(t *testing.T) { + action := &Action{name: "test", repeatable: true} + parentNode := &node{Action: action} + + if !allowedRepetition(action, parentNode) { + t.Error("Expected repeatable action to be allowed") + } +} + +func TestAllowedRepetition_NonRepeatable_NotUsed(t *testing.T) { + action1 := &Action{name: "action1", repeatable: false} + action2 := &Action{name: "action2", repeatable: false} + + node1 := &node{Action: action2, parentNode: nil} + + if !allowedRepetition(action1, node1) { + t.Error("Expected non-repeated action to be allowed") + } +} + +func TestAllowedRepetition_NonRepeatable_AlreadyUsed(t *testing.T) { + action := &Action{name: "test", repeatable: false} + + node1 := &node{Action: action, parentNode: nil} + node2 := &node{Action: &Action{name: "other"}, parentNode: node1} + + if allowedRepetition(action, node2) { + t.Error("Expected repeated non-repeatable action to be disallowed") + } +} + +// Test countMissingGoal +func TestCountMissingGoal_AllMet(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[bool](&agent, 2, true) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, + } + + count := countMissingGoal(goal, agent.w) + if count != 0 { + t.Errorf("Expected 0 missing goals, got %d", count) + } +} + +func TestCountMissingGoal_OneMissing(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[bool](&agent, 2, false) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, + } + + count := countMissingGoal(goal, agent.w) + if count != 1 { + t.Errorf("Expected 1 missing goal, got %d", count) + } +} + +func TestCountMissingGoal_AllMissing(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + SetState[bool](&agent, 2, false) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, + } + + count := countMissingGoal(goal, agent.w) + if count != 2 { + t.Errorf("Expected 2 missing goals, got %d", count) + } +} + +// Test computeHeuristic +func TestComputeHeuristic(t *testing.T) { + fromAgent := CreateAgent(Goals{}, Actions{}) + SetState[int](&fromAgent, 1, 0) + + currentAgent := CreateAgent(Goals{}, Actions{}) + SetState[int](¤tAgent, 1, 50) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + heuristic := computeHeuristic(fromAgent.w, goal, currentAgent.w) + if heuristic <= 0 { + t.Error("Expected positive heuristic for unmet goal") + } +} + +func TestComputeHeuristic_GoalMet(t *testing.T) { + fromAgent := CreateAgent(Goals{}, Actions{}) + SetState[int](&fromAgent, 1, 0) + + currentAgent := CreateAgent(Goals{}, Actions{}) + SetState[int](¤tAgent, 1, 100) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + heuristic := computeHeuristic(fromAgent.w, goal, currentAgent.w) + if heuristic != 0 { + t.Errorf("Expected 0 heuristic for met goal, got %f", heuristic) + } +} + +// Test astar integration +func TestAstar_SimpleGoal(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + actions.AddAction("increment", 1.0, true, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 10, Operator: ADD}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 30, Operator: EQUAL}, + }, + } + + plan := astar(agent.w, goal, actions, 10) + + // Plan includes the root node with empty action + if len(plan) != 4 { + t.Errorf("Expected plan with 4 actions (including root), got %d", len(plan)) + } +} + +func TestAstar_UnreachableGoal(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + // No actions available + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + plan := astar(agent.w, goal, actions, 10) + + if len(plan) != 0 { + t.Errorf("Expected empty plan for unreachable goal, got %d actions", len(plan)) + } +} + +func TestAstar_MaxDepth(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + actions.AddAction("increment", 1.0, true, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 1, Operator: ADD}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + // Max depth of 5 should prevent reaching goal of 100 + plan := astar(agent.w, goal, actions, 5) + + if len(plan) > 5 { + t.Errorf("Expected plan to respect maxDepth of 5, got %d actions", len(plan)) + } +} + +func TestAstar_AlreadyAtGoal(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + actions := Actions{} + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + plan := astar(agent.w, goal, actions, 10) + + // Plan includes root node with empty action when already at goal + if len(plan) != 1 { + t.Errorf("Expected plan with 1 action (root) when already at goal, got %d actions", len(plan)) + } +} + +func TestAstar_PreferLowerCost(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + actions.AddAction("expensive", 10.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + }) + actions.AddAction("cheap", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + plan := astar(agent.w, goal, actions, 10) + + // Plan includes root + action + if len(plan) != 2 { + t.Errorf("Expected plan with 2 actions (root + 1), got %d", len(plan)) + } + + if plan[1].name != "cheap" { + t.Errorf("Expected cheaper action to be chosen, got '%s'", plan[1].name) + } +} + +func TestAstar_RespectConditions(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + SetState[bool](&agent, 2, false) + + actions := Actions{} + // This action requires key 2 to be true + actions.AddAction("conditional", 1.0, false, Conditions{ + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + }) + // This action enables the conditional action + actions.AddAction("enabler", 1.0, false, Conditions{}, Effects{ + EffectBool{Key: 2, Value: true, Operator: SET}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + plan := astar(agent.w, goal, actions, 10) + + // Plan includes root + 2 actions + if len(plan) != 3 { + t.Errorf("Expected plan with 3 actions (root + 2), got %d", len(plan)) + } + + // Second action (index 1) should be the enabler + if plan[1].name != "enabler" { + t.Errorf("Expected second action to be 'enabler', got '%s'", plan[1].name) + } + // Third action (index 2) should be the conditional one + if plan[2].name != "conditional" { + t.Errorf("Expected third action to be 'conditional', got '%s'", plan[2].name) + } +} + +func TestAstar_NonRepeatableActions(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + actions.AddAction("increment", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 10, Operator: ADD}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 30, Operator: EQUAL}, + }, + } + + plan := astar(agent.w, goal, actions, 10) + + // Non-repeatable action can only be used once, so goal is unreachable + if len(plan) != 0 { + t.Errorf("Expected empty plan for non-repeatable action, got %d actions", len(plan)) + } +} + +func TestAstar_DataCloning(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + originalData := slices.Clone(agent.w.states) + + actions := Actions{} + actions.AddAction("modify", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 200, Operator: SET}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 200, Operator: EQUAL}, + }, + } + + _ = astar(agent.w, goal, actions, 10) + + // Original state should not be modified + if len(agent.w.states) != len(originalData) { + t.Error("Original state was modified") + } + if agent.w.states[0].(State[int]).Value != originalData[0].(State[int]).Value { + t.Error("Original state values were modified") + } +} diff --git a/benchmark/goapai_test.go b/benchmark/goapai_test.go index 0a00426..a82f369 100644 --- a/benchmark/goapai_test.go +++ b/benchmark/goapai_test.go @@ -1,7 +1,6 @@ package benchmark import ( - "fmt" "goapai" "testing" ) @@ -63,11 +62,11 @@ func BenchmarkGoapAI(b *testing.B) { entity.agent = goapai.CreateAgent(goals, actions) goapai.SetSensor(&entity.agent, "entity", &entity) - goapai.SetState[int](&entity.agent, ATTRIBUTE_2, 0) goapai.SetState[int](&entity.agent, ATTRIBUTE_1, 80) + goapai.SetState[int](&entity.agent, ATTRIBUTE_2, 0) goapai.SetState[int](&entity.agent, ATTRIBUTE_3, 0) - // Write to the trace file. + //Write to the trace file. //f, _ := os.Create("trace.out") //fcpu, _ := os.Create(`cpu.prof`) //fheap, _ := os.Create(`heap.prof`) @@ -76,12 +75,10 @@ func BenchmarkGoapAI(b *testing.B) { //pprof.WriteHeapProfile(fheap) //trace.Start(f) - var lastPlan goapai.Plan for b.Loop() { //goapai.GetPlan(entity.agent, 15) - _, lastPlan = goapai.GetPlan(entity.agent, 15) + goapai.GetPlan(entity.agent, 15) } - fmt.Println(len(lastPlan)) //defer f.Close() //defer fcpu.Close() diff --git a/distance.go b/distance.go new file mode 100644 index 0000000..f562d1a --- /dev/null +++ b/distance.go @@ -0,0 +1,110 @@ +package goapai + +// Distance calculates the heuristic distance between the current state value and a condition's target value. +// +// It returns 0 if the condition is already satisfied, otherwise returns the numeric distance +// needed to satisfy the condition. This method is used by the A* algorithm to compute heuristics +// for pathfinding. +// +// For numeric types (int, int8, uint8, uint64, float64), the distance is calculated based on +// the operator type (EQUAL, UPPER, LOWER, etc.). For bool and string types, the distance is +// either 0 (satisfied) or 1 (not satisfied). +// +// If the condition's key doesn't match the state's key, or if types don't match, returns 0. +func (state State[T]) Distance(condition ConditionInterface) float32 { + // Check if the condition key matches + if state.Key != condition.GetKey() { + return 0 + } + + // Handle different condition types + switch cond := condition.(type) { + case *Condition[int8]: + if v, ok := any(state.Value).(int8); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[int]: + if v, ok := any(state.Value).(int); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[uint8]: + if v, ok := any(state.Value).(uint8); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[uint64]: + if v, ok := any(state.Value).(uint64); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[float64]: + if v, ok := any(state.Value).(float64); ok { + return calculateNumericDistance(v, cond.Value, cond.Operator) + } + case *ConditionBool: + if v, ok := any(state.Value).(bool); ok { + if cond.Operator == EQUAL { + if v == cond.Value { + return 0 + } + return 1 + } else if cond.Operator == NOT_EQUAL { + if v != cond.Value { + return 0 + } + return 1 + } + } + case *ConditionString: + if v, ok := any(state.Value).(string); ok { + if cond.Operator == EQUAL { + if v == cond.Value { + return 0 + } + return 1 + } else if cond.Operator == NOT_EQUAL { + if v != cond.Value { + return 0 + } + return 1 + } + } + } + + return 0 +} + +// calculateNumericDistance computes the distance for numeric conditions based on operator +func calculateNumericDistance(current, target float64, op operator) float32 { + switch op { + case EQUAL: + if current < target { + return float32(target - current) + } + return float32(current - target) + case NOT_EQUAL: + if current == target { + return 1.0 + } + return 0.0 + case UPPER_OR_EQUAL: + if current < target { + return float32(target - current) + } + return 0.0 + case UPPER: + if current <= target { + return float32(target - current + 1) + } + return 0.0 + case LOWER_OR_EQUAL: + if current > target { + return float32(current - target) + } + return 0.0 + case LOWER: + if current >= target { + return float32(current - target + 1) + } + return 0.0 + } + return 0.0 +} diff --git a/distance_test.go b/distance_test.go new file mode 100644 index 0000000..fbcbafb --- /dev/null +++ b/distance_test.go @@ -0,0 +1,216 @@ +package goapai + +import "testing" + +func TestState_Distance(t *testing.T) { + tests := []struct { + name string + stateVal interface{} + condition ConditionInterface + want float32 + }{ + // Numeric types - EQUAL operator + { + name: "int EQUAL satisfied", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + want: 0.0, + }, + { + name: "int EQUAL below target", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + want: 50.0, + }, + { + name: "int EQUAL above target", + stateVal: 150, + condition: &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + want: 50.0, + }, + // UPPER_OR_EQUAL operator + { + name: "int UPPER_OR_EQUAL satisfied", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 80, Operator: UPPER_OR_EQUAL}, + want: 0.0, + }, + { + name: "int UPPER_OR_EQUAL below target", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 80, Operator: UPPER_OR_EQUAL}, + want: 30.0, + }, + // UPPER operator + { + name: "int UPPER satisfied", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 80, Operator: UPPER}, + want: 0.0, + }, + { + name: "int UPPER at target", + stateVal: 80, + condition: &Condition[int]{Key: 1, Value: 80, Operator: UPPER}, + want: 1.0, + }, + // LOWER_OR_EQUAL operator + { + name: "int LOWER_OR_EQUAL satisfied", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 80, Operator: LOWER_OR_EQUAL}, + want: 0.0, + }, + { + name: "int LOWER_OR_EQUAL above target", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 80, Operator: LOWER_OR_EQUAL}, + want: 20.0, + }, + // LOWER operator + { + name: "int LOWER satisfied", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 80, Operator: LOWER}, + want: 0.0, + }, + { + name: "int LOWER at target", + stateVal: 80, + condition: &Condition[int]{Key: 1, Value: 80, Operator: LOWER}, + want: 1.0, + }, + // NOT_EQUAL operator + { + name: "int NOT_EQUAL satisfied", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 100, Operator: NOT_EQUAL}, + want: 0.0, + }, + { + name: "int NOT_EQUAL not satisfied", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 100, Operator: NOT_EQUAL}, + want: 1.0, + }, + // Float64 tests + { + name: "float64 EQUAL", + stateVal: 75.5, + condition: &Condition[float64]{Key: 1, Value: 100.0, Operator: EQUAL}, + want: 24.5, + }, + // uint64 tests + { + name: "uint64 UPPER_OR_EQUAL", + stateVal: uint64(50), + condition: &Condition[uint64]{Key: 1, Value: uint64(100), Operator: UPPER_OR_EQUAL}, + want: 50.0, + }, + // int8 tests + { + name: "int8 EQUAL", + stateVal: int8(50), + condition: &Condition[int8]{Key: 1, Value: int8(100), Operator: EQUAL}, + want: 50.0, + }, + // uint8 tests + { + name: "uint8 LOWER_OR_EQUAL", + stateVal: uint8(100), + condition: &Condition[uint8]{Key: 1, Value: uint8(80), Operator: LOWER_OR_EQUAL}, + want: 20.0, + }, + // Bool tests + { + name: "bool EQUAL satisfied", + stateVal: true, + condition: &ConditionBool{Key: 1, Value: true, Operator: EQUAL}, + want: 0.0, + }, + { + name: "bool EQUAL not satisfied", + stateVal: true, + condition: &ConditionBool{Key: 1, Value: false, Operator: EQUAL}, + want: 1.0, + }, + { + name: "bool NOT_EQUAL satisfied", + stateVal: true, + condition: &ConditionBool{Key: 1, Value: false, Operator: NOT_EQUAL}, + want: 0.0, + }, + { + name: "bool NOT_EQUAL not satisfied", + stateVal: true, + condition: &ConditionBool{Key: 1, Value: true, Operator: NOT_EQUAL}, + want: 1.0, + }, + // String tests + { + name: "string EQUAL satisfied", + stateVal: "test", + condition: &ConditionString{Key: 1, Value: "test", Operator: EQUAL}, + want: 0.0, + }, + { + name: "string EQUAL not satisfied", + stateVal: "test", + condition: &ConditionString{Key: 1, Value: "other", Operator: EQUAL}, + want: 1.0, + }, + { + name: "string NOT_EQUAL satisfied", + stateVal: "test", + condition: &ConditionString{Key: 1, Value: "other", Operator: NOT_EQUAL}, + want: 0.0, + }, + { + name: "string NOT_EQUAL not satisfied", + stateVal: "test", + condition: &ConditionString{Key: 1, Value: "test", Operator: NOT_EQUAL}, + want: 1.0, + }, + // Key mismatch + { + name: "key mismatch", + stateVal: 100, + condition: &Condition[int]{Key: 99, Value: 100, Operator: EQUAL}, + want: 0.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + + switch v := tt.stateVal.(type) { + case int: + SetState[int](&agent, 1, v) + case int8: + SetState[int8](&agent, 1, v) + case uint8: + SetState[uint8](&agent, 1, v) + case uint64: + SetState[uint64](&agent, 1, v) + case float64: + SetState[float64](&agent, 1, v) + case bool: + SetState[bool](&agent, 1, v) + case string: + SetState[string](&agent, 1, v) + } + + if len(agent.w.states) == 0 { + t.Fatal("Failed to set state") + } + + state := agent.w.states[0] + got := state.Distance(tt.condition) + + if got != tt.want { + t.Errorf("Distance() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/effect_test.go b/effect_test.go new file mode 100644 index 0000000..1a9c146 --- /dev/null +++ b/effect_test.go @@ -0,0 +1,345 @@ +package goapai + +import "testing" + +func TestEffectString_GetKey(t *testing.T) { + effect := EffectString{Key: 42, Value: "test"} + if got := effect.GetKey(); got != 42 { + t.Errorf("GetKey() = %v, want 42", got) + } +} + +func TestEffectString_Check(t *testing.T) { + tests := []struct { + name string + stateVal string + effectKey StateKey + effectVal string + operator arithmetic + want bool + }{ + { + name: "SET match", + stateVal: "hello", + effectKey: 1, + effectVal: "hello", + operator: SET, + want: true, + }, + { + name: "SET no match", + stateVal: "hello", + effectKey: 1, + effectVal: "world", + operator: SET, + want: false, + }, + { + name: "key not found", + stateVal: "hello", + effectKey: 99, + effectVal: "test", + operator: SET, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[string](&agent, 1, tt.stateVal) + + effect := EffectString{Key: tt.effectKey, Value: tt.effectVal, Operator: tt.operator} + got := effect.check(agent.w) + + if got != tt.want { + t.Errorf("check() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEffectString_Apply(t *testing.T) { + tests := []struct { + name string + setupState bool + stateVal string + effectKey StateKey + effectVal string + operator arithmetic + wantVal string + wantErr bool + }{ + { + name: "SET new state", + setupState: false, + effectKey: 1, + effectVal: "hello", + operator: SET, + wantVal: "hello", + wantErr: false, + }, + { + name: "SET existing state", + setupState: true, + stateVal: "old", + effectKey: 1, + effectVal: "new", + operator: SET, + wantVal: "new", + wantErr: false, + }, + { + name: "ADD concatenation", + setupState: true, + stateVal: "hello", + effectKey: 1, + effectVal: " world", + operator: ADD, + wantVal: "hello world", + wantErr: false, + }, + { + name: "ADD new state", + setupState: false, + effectKey: 1, + effectVal: "test", + operator: ADD, + wantVal: "test", + wantErr: false, + }, + { + name: "SUBSTRACT not allowed", + setupState: true, + stateVal: "hello", + effectKey: 1, + effectVal: "test", + operator: SUBSTRACT, + wantErr: true, + }, + { + name: "MULTIPLY not allowed", + setupState: true, + stateVal: "hello", + effectKey: 1, + effectVal: "test", + operator: MULTIPLY, + wantErr: true, + }, + { + name: "DIVIDE not allowed", + setupState: true, + stateVal: "hello", + effectKey: 1, + effectVal: "test", + operator: DIVIDE, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + if tt.setupState { + SetState[string](&agent, tt.effectKey, tt.stateVal) + } + + effect := EffectString{Key: tt.effectKey, Value: tt.effectVal, Operator: tt.operator} + err := effect.apply(&agent.w) + + if tt.wantErr { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + idx := agent.w.states.GetIndex(tt.effectKey) + if idx < 0 { + t.Fatal("State not found after apply") + } + + gotState := agent.w.states[idx].(State[string]) + if gotState.Value != tt.wantVal { + t.Errorf("Value = %v, want %v", gotState.Value, tt.wantVal) + } + }) + } +} + +func TestEffectBool_GetKey(t *testing.T) { + effect := EffectBool{Key: 42, Value: true} + if got := effect.GetKey(); got != 42 { + t.Errorf("GetKey() = %v, want 42", got) + } +} + +func TestEffectBool_Apply_Errors(t *testing.T) { + tests := []struct { + name string + operator arithmetic + wantErr bool + }{ + { + name: "SET allowed", + operator: SET, + wantErr: false, + }, + { + name: "ADD not allowed", + operator: ADD, + wantErr: true, + }, + { + name: "SUBSTRACT not allowed", + operator: SUBSTRACT, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[bool](&agent, 1, true) + + effect := EffectBool{Key: 1, Value: false, Operator: tt.operator} + err := effect.apply(&agent.w) + + if tt.wantErr && err == nil { + t.Error("Expected error, got nil") + } + if !tt.wantErr && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} + +func TestEffect_Apply_AllOperators(t *testing.T) { + tests := []struct { + key StateKey + name string + initial int + operator arithmetic + value int + wantVal int + wantErr bool + }{ + { + key: 1, + name: "SET", + initial: 100, + operator: SET, + value: 50, + wantVal: 50, + }, + { + key: 1, + name: "ADD", + initial: 100, + operator: ADD, + value: 50, + wantVal: 150, + }, + { + key: 1, + name: "SUBSTRACT", + initial: 100, + operator: SUBSTRACT, + value: 30, + wantVal: 70, + }, + { + key: 1, + name: "MULTIPLY", + initial: 10, + operator: MULTIPLY, + value: 5, + wantVal: 50, + }, + { + key: 1, + name: "DIVIDE", + initial: 100, + operator: DIVIDE, + value: 4, + wantVal: 25, + }, + { + key: 1, + name: "ADD on non-existing key", + initial: 0, // Not set + operator: ADD, + value: 50, + wantVal: 50, + }, + { + key: 1, + name: "SUBSTRACT on non-existing key", + initial: 0, // Not set + operator: SUBSTRACT, + value: 50, + wantVal: -50, + }, + { + key: 99, + name: "MULTIPLY on non-existing key error", + initial: 0, // Not set + operator: MULTIPLY, + value: 50, + wantErr: true, + }, + { + key: 99, + name: "DIVIDE on non-existing key error", + initial: 0, // Not set + operator: DIVIDE, + value: 50, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, tt.initial) + + effect := Effect[int]{Key: tt.key, Value: tt.value, Operator: tt.operator} + err := effect.apply(&agent.w) + + if tt.wantErr { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + idx := agent.w.states.GetIndex(1) + if idx < 0 { + t.Fatal("State not found after apply") + } + + gotState := agent.w.states[idx].(State[int]) + if gotState.Value != tt.wantVal { + t.Errorf("Value = %v, want %v", gotState.Value, tt.wantVal) + } + }) + } +} + +func TestEffect_GetKey(t *testing.T) { + effect := Effect[int]{Key: 42, Value: 100} + if got := effect.GetKey(); got != 42 { + t.Errorf("GetKey() = %v, want 42", got) + } +} diff --git a/heap.go b/heap.go new file mode 100644 index 0000000..012805c --- /dev/null +++ b/heap.go @@ -0,0 +1,32 @@ +package goapai + +// nodeHeap implements heap.Interface for a min-heap of nodes based on totalCost +type nodeHeap []*node + +func (h nodeHeap) Len() int { return len(h) } + +func (h nodeHeap) Less(i, j int) bool { + return h[i].totalCost < h[j].totalCost +} + +func (h nodeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].heapIndex = i + h[j].heapIndex = j +} + +func (h *nodeHeap) Push(x interface{}) { + n := x.(*node) + n.heapIndex = len(*h) + *h = append(*h, n) +} + +func (h *nodeHeap) Pop() interface{} { + old := *h + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.heapIndex = -1 // mark as removed + *h = old[0 : n-1] + return item +} diff --git a/heap_test.go b/heap_test.go new file mode 100644 index 0000000..9403c81 --- /dev/null +++ b/heap_test.go @@ -0,0 +1,213 @@ +package goapai + +import "testing" + +func TestNodeHeap_Less(t *testing.T) { + tests := []struct { + name string + nodes nodeHeap + i, j int + wantLess bool + }{ + { + name: "first node has lower cost", + nodes: nodeHeap{ + {totalCost: 5.0}, + {totalCost: 10.0}, + }, + i: 0, + j: 1, + wantLess: true, + }, + { + name: "first node has higher cost", + nodes: nodeHeap{ + {totalCost: 15.0}, + {totalCost: 10.0}, + }, + i: 0, + j: 1, + wantLess: false, + }, + { + name: "equal costs", + nodes: nodeHeap{ + {totalCost: 10.0}, + {totalCost: 10.0}, + }, + i: 0, + j: 1, + wantLess: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.nodes.Less(tt.i, tt.j) + if got != tt.wantLess { + t.Errorf("Less(%d, %d) = %v, want %v", tt.i, tt.j, got, tt.wantLess) + } + }) + } +} + +func TestNodeHeap_Swap(t *testing.T) { + tests := []struct { + name string + initialNodes []*node + i, j int + wantI, wantJ int + wantIIndex int + wantJIndex int + }{ + { + name: "swap first and second", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + {totalCost: 10.0, heapIndex: 1}, + }, + i: 0, + j: 1, + wantI: 1, + wantJ: 0, + wantIIndex: 0, + wantJIndex: 1, + }, + { + name: "swap first and third", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + {totalCost: 10.0, heapIndex: 1}, + {totalCost: 15.0, heapIndex: 2}, + }, + i: 0, + j: 2, + wantI: 2, + wantJ: 0, + wantIIndex: 0, + wantJIndex: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := nodeHeap(tt.initialNodes) + originalI := h[tt.i] + originalJ := h[tt.j] + + h.Swap(tt.i, tt.j) + + // Verify nodes were swapped + if h[tt.i] != originalJ { + t.Errorf("After swap, h[%d] should be originalJ", tt.i) + } + if h[tt.j] != originalI { + t.Errorf("After swap, h[%d] should be originalI", tt.j) + } + + // Verify heapIndex was updated + if h[tt.i].heapIndex != tt.wantIIndex { + t.Errorf("h[%d].heapIndex = %d, want %d", tt.i, h[tt.i].heapIndex, tt.wantIIndex) + } + if h[tt.j].heapIndex != tt.wantJIndex { + t.Errorf("h[%d].heapIndex = %d, want %d", tt.j, h[tt.j].heapIndex, tt.wantJIndex) + } + }) + } +} + +func TestNodeHeap_Push(t *testing.T) { + tests := []struct { + name string + initialNodes []*node + pushNode *node + wantLen int + wantIndex int + }{ + { + name: "push to empty heap", + initialNodes: []*node{}, + pushNode: &node{totalCost: 5.0}, + wantLen: 1, + wantIndex: 0, + }, + { + name: "push to non-empty heap", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + {totalCost: 10.0, heapIndex: 1}, + }, + pushNode: &node{totalCost: 15.0}, + wantLen: 3, + wantIndex: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := nodeHeap(tt.initialNodes) + h.Push(tt.pushNode) + + if len(h) != tt.wantLen { + t.Errorf("Len() = %d, want %d", len(h), tt.wantLen) + } + + if tt.pushNode.heapIndex != tt.wantIndex { + t.Errorf("heapIndex = %d, want %d", tt.pushNode.heapIndex, tt.wantIndex) + } + + if h[len(h)-1] != tt.pushNode { + t.Error("Pushed node should be at the end of heap") + } + }) + } +} + +func TestNodeHeap_Pop(t *testing.T) { + tests := []struct { + name string + initialNodes []*node + wantNode *node + wantLen int + }{ + { + name: "pop from heap with one element", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + }, + wantNode: &node{totalCost: 5.0, heapIndex: -1}, + wantLen: 0, + }, + { + name: "pop from heap with multiple elements", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + {totalCost: 10.0, heapIndex: 1}, + {totalCost: 15.0, heapIndex: 2}, + }, + wantNode: &node{totalCost: 15.0, heapIndex: -1}, + wantLen: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := nodeHeap(tt.initialNodes) + originalLast := h[len(h)-1] + + popped := h.Pop().(*node) + + if len(h) != tt.wantLen { + t.Errorf("Len() = %d, want %d", len(h), tt.wantLen) + } + + if popped.heapIndex != -1 { + t.Errorf("Popped node heapIndex = %d, want -1", popped.heapIndex) + } + + if popped != originalLast { + t.Error("Pop should return the last element") + } + }) + } +} diff --git a/planer.go b/planer.go index 8f8e1bf..d88c6a9 100644 --- a/planer.go +++ b/planer.go @@ -32,7 +32,11 @@ func GetPlan(agent Agent, maxDepth int) (GoalName, Plan) { return "", Plan{} } - return goalName, astar(agent.states, agent.goals[goalName], agent.actions, maxDepth) + for _, state := range agent.w.states { + state.Store(&agent.w) + } + + return goalName, astar(agent.w, agent.goals[goalName], agent.actions, maxDepth) } func (agent *Agent) getPrioritizedGoalName() (GoalName, error) { @@ -54,14 +58,3 @@ func (agent *Agent) getPrioritizedGoalName() (GoalName, error) { return prioritizedGoalName, fmt.Errorf("no goal available") } } - -// GetNextAction returns the first Action required to achieve the Plan. -// -// An error is returned if no action is available, meaning the Plan is empty. -func (plan Plan) GetNextAction() (Action, error) { - if len(plan) > 0 { - return *plan[0], nil - } - - return Action{}, fmt.Errorf("no action available") -} diff --git a/planer_test.go b/planer_test.go index daf25d7..c9aa701 100644 --- a/planer_test.go +++ b/planer_test.go @@ -29,3 +29,334 @@ func TestPlan_GetTotalCost(t *testing.T) { }) } } + +// Test GetPlan +func TestGetPlan_SimpleGoal(t *testing.T) { + actions := Actions{} + actions.AddAction("increment", 1.0, true, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 10, Operator: ADD}, + }) + + goals := Goals{ + "reach_30": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 30, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 0) + + goalName, plan := GetPlan(agent, 10) + + if goalName != "reach_30" { + t.Errorf("Expected goal 'reach_30', got '%s'", goalName) + } + + // Plan includes root node + if len(plan) != 4 { + t.Errorf("Expected plan with 4 actions (including root), got %d", len(plan)) + } +} + +func TestGetPlan_AlreadyAtGoal(t *testing.T) { + actions := Actions{} + + goals := Goals{ + "be_at_100": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 100) + + goalName, plan := GetPlan(agent, 10) + + if goalName != "be_at_100" { + t.Errorf("Expected goal 'be_at_100', got '%s'", goalName) + } + + // Plan includes root node even when already at goal + if len(plan) != 1 { + t.Errorf("Expected plan with 1 action (root) when already at goal, got %d actions", len(plan)) + } +} + +func TestGetPlan_NoGoalsAvailable(t *testing.T) { + actions := Actions{} + + goals := Goals{ + "impossible": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 0.0 // Zero priority means goal is not active + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 0) + + goalName, plan := GetPlan(agent, 10) + + if goalName != "" { + t.Errorf("Expected empty goal name, got '%s'", goalName) + } + + if len(plan) != 0 { + t.Errorf("Expected empty plan when no goals available, got %d actions", len(plan)) + } +} + +func TestGetPlan_UnreachableGoal(t *testing.T) { + actions := Actions{} + // No actions to reach the goal + + goals := Goals{ + "unreachable": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 0) + + goalName, plan := GetPlan(agent, 10) + + if goalName != "unreachable" { + t.Errorf("Expected goal 'unreachable', got '%s'", goalName) + } + + if len(plan) != 0 { + t.Errorf("Expected empty plan for unreachable goal, got %d actions", len(plan)) + } +} + +func TestGetPlan_MaxDepthExceeded(t *testing.T) { + actions := Actions{} + actions.AddAction("increment", 1.0, true, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 1, Operator: ADD}, + }) + + goals := Goals{ + "reach_100": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 0) + + // Max depth of 5 should prevent reaching 100 + _, plan := GetPlan(agent, 5) + + if len(plan) > 5 { + t.Errorf("Expected plan to respect maxDepth of 5, got %d actions", len(plan)) + } +} + +// Test getPrioritizedGoalName +func TestGetPrioritizedGoalName_SingleGoal(t *testing.T) { + goals := Goals{ + "goal1": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, Actions{}) + + goalName, err := agent.getPrioritizedGoalName() + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if goalName != "goal1" { + t.Errorf("Expected 'goal1', got '%s'", goalName) + } +} + +func TestGetPrioritizedGoalName_MultipleGoals(t *testing.T) { + goals := Goals{ + "low_priority": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 0.5 + }, + }, + "high_priority": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 2.0 + }, + }, + "medium_priority": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, Actions{}) + + goalName, err := agent.getPrioritizedGoalName() + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if goalName != "high_priority" { + t.Errorf("Expected 'high_priority', got '%s'", goalName) + } +} + +func TestGetPrioritizedGoalName_ZeroPriority(t *testing.T) { + goals := Goals{ + "inactive": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 0.0 + }, + }, + } + + agent := CreateAgent(goals, Actions{}) + + _, err := agent.getPrioritizedGoalName() + + if err == nil { + t.Error("Expected error when all goals have zero priority") + } +} + +func TestGetPrioritizedGoalName_UsingSensors(t *testing.T) { + type Entity struct { + health int + } + + goals := Goals{ + "heal": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + entity := sensors.GetSensor("entity").(*Entity) + if entity.health < 50 { + return 2.0 + } + return 0.1 + }, + }, + "explore": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, Actions{}) + entity := &Entity{health: 20} + SetSensor(&agent, "entity", entity) + + goalName, err := agent.getPrioritizedGoalName() + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if goalName != "heal" { + t.Errorf("Expected 'heal' for low health, got '%s'", goalName) + } +} + +// Integration test with complex scenario +func TestGetPlan_ComplexScenario(t *testing.T) { + actions := Actions{} + + // Need wood to make fire + actions.AddAction("get_wood", 2.0, false, Conditions{}, Effects{ + EffectBool{Key: 1, Value: true, Operator: SET}, // has_wood + }) + + // Need matches to make fire + actions.AddAction("get_matches", 1.0, false, Conditions{}, Effects{ + EffectBool{Key: 2, Value: true, Operator: SET}, // has_matches + }) + + // Make fire requires both wood and matches + actions.AddAction("make_fire", 1.0, false, Conditions{ + &ConditionBool{Key: 1, Value: true, Operator: EQUAL}, // has_wood + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, // has_matches + }, Effects{ + EffectBool{Key: 3, Value: true, Operator: SET}, // has_fire + }) + + goals := Goals{ + "stay_warm": { + Conditions: Conditions{ + &ConditionBool{Key: 3, Value: true, Operator: EQUAL}, // has_fire + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[bool](&agent, 1, false) // has_wood + SetState[bool](&agent, 2, false) // has_matches + SetState[bool](&agent, 3, false) // has_fire + + goalName, plan := GetPlan(agent, 10) + + if goalName != "stay_warm" { + t.Errorf("Expected goal 'stay_warm', got '%s'", goalName) + } + + // Plan includes root + 3 actions + if len(plan) != 4 { + t.Errorf("Expected plan with 4 actions (root + 3), got %d", len(plan)) + } + + // Verify the plan makes sense (should get matches first as it's cheaper) + // Index 1 should be a resource gathering action + if plan[1].name != "get_matches" && plan[1].name != "get_wood" { + t.Errorf("Expected second action to gather resources, got '%s'", plan[1].name) + } + + // Last action should be make_fire + if plan[3].name != "make_fire" { + t.Errorf("Expected last action to be 'make_fire', got '%s'", plan[3].name) + } + + // Verify total cost (root has cost 0) + totalCost := plan.GetTotalCost() + if totalCost != 4.0 { + t.Errorf("Expected total cost 4.0, got %f", totalCost) + } +} diff --git a/state.go b/state.go index 59bc79f..c05576b 100644 --- a/state.go +++ b/state.go @@ -1,10 +1,7 @@ package goapai import ( - "encoding/binary" - "hash/fnv" - "slices" - "strconv" + "math" ) type operator uint8 @@ -18,42 +15,63 @@ const ( UPPER ) +// Numeric is a constraint that defines the numeric types supported by generic State and Condition. +// Supported types are: int8, int, uint8, uint64, and float64. type Numeric interface { ~int8 | ~int | - ~uint8 | ~uint64 | - ~float64 + ~uint8 | ~uint64 | + ~float64 } +// StateInterface defines the interface that all state types must implement. +// States represent key-value pairs in the world state, with support for hashing and distance calculation. type StateInterface interface { - Check(states states, key StateKey) bool + Check(w world, key StateKey) bool GetKey() StateKey GetValue() any + Store(w *world) + GetHash() uint64 + Hash() uint64 + Distance(condition ConditionInterface) float32 } + +// State represents a single key-value pair in the world state. +// +// States can hold numeric types (constrained by Numeric), bool, or string values. +// Each state is identified by a unique StateKey and includes a cached hash for performance. type State[T Numeric | bool | string] struct { - Key StateKey - Value T + Key StateKey // Unique identifier for this state + Value T // The state's value + hash uint64 // Cached hash value for fast comparison } +// StateKey is a compact 16-bit unsigned integer used to identify states. +// Using uint16 instead of strings reduces memory usage and improves performance. type StateKey uint16 -type statesData []StateInterface +type states []StateInterface + +type world struct { + Agent *Agent + states states + hash uint64 +} -type states struct { - Agent *Agent - data statesData - hash uint64 +// Check compares world and states2 by their hash. +func (world world) Check(world2 world) bool { + return world.hash == world2.hash } func (state State[T]) GetKey() StateKey { return state.Key } -func (state State[T]) Check(states states, key StateKey) bool { - k := states.data.GetIndex(key) +func (state State[T]) Check(w world, key StateKey) bool { + k := w.states.GetIndex(key) if k < 0 { return false } - s := states.data[k] + s := w.states[k] if agentState, ok := s.(State[T]); ok { if agentState.Value == state.Value { return true @@ -67,118 +85,162 @@ func (state State[T]) GetValue() any { return state.Value } -// Check compares states and states2 by their hash. -func (states states) Check(states2 states) bool { - return states.hash == states2.hash +func (state State[T]) Store(w *world) { + oldHash := state.hash + state.hash = state.Hash() + w.hash = updateHashIncremental(w.hash, oldHash, state.hash) + k := w.states.GetIndex(state.Key) + if k < 0 { + w.states = append(w.states, state) + } else { + w.states[k] = state + } +} + +func (state State[T]) GetHash() uint64 { + return state.hash } -func (statesData statesData) GetIndex(stateKey StateKey) int { - for k, stateData := range statesData { - if stateData.GetKey() == stateKey { - return k +// Hash returns a unique hash for this state using a fast multiplicative hash +// It implements a fast inline multiplicative hash +// Uses prime multipliers for good distribution without allocations +func (state State[T]) Hash() uint64 { + const ( + prime1 uint64 = 11400714819323198485 // Large prime for key + prime2 uint64 = 14029467366897019727 // Second prime for value + ) + + // Start with key + hash := uint64(state.Key) * prime1 + + // Mix in value based on type + switch v := any(state.Value).(type) { + case int8: + hash ^= uint64(v) * prime2 + case int: + hash ^= uint64(v) * prime2 + case uint8: + hash ^= uint64(v) * prime2 + case uint64: + hash ^= v * prime2 + case float64: + hash ^= math.Float64bits(v) * prime2 + case bool: + if v { + hash ^= prime2 + } + case string: + // For strings, hash each byte + for i := 0; i < len(v); i++ { + hash = hash*prime2 ^ uint64(v[i]) } } - return -1 + return hash } -func (statesData statesData) sort() { - slices.SortFunc(statesData, func(a, b StateInterface) int { - if a.GetKey() > b.GetKey() { - return 1 - } else if a.GetKey() < b.GetKey() { - return -1 - } +// updateHashIncremental updates a hash by removing old state and adding new state +func updateHashIncremental(currentHash uint64, oldStateHash, newStateHash uint64) uint64 { + currentHash ^= oldStateHash // Remove old + currentHash ^= newStateHash // Add new - return 0 - }) + return currentHash } -func (statesData statesData) hashStates() uint64 { - hash := fnv.New64() - - buf := make([]byte, binary.MaxVarintLen64) - for _, data := range statesData { - n := binary.PutVarint(buf, int64(data.GetKey())) - hash.Write(buf[:n]) - hash.Write([]byte(":")) - - switch v := data.GetValue().(type) { - case int8: - n = binary.PutVarint(buf, int64(v)) - hash.Write(buf[:n]) - case int: - n = binary.PutVarint(buf, int64(v)) - hash.Write(buf[:n]) - case uint8: - n = binary.PutUvarint(buf, uint64(v)) - hash.Write(buf[:n]) - case uint64: - n = binary.PutUvarint(buf, v) - hash.Write(buf[:n]) - case float64: - hash.Write([]byte(strconv.FormatFloat(v, 'f', -1, 64))) - case string: - hash.Write([]byte(v)) - case []byte: - hash.Write(v) - default: - binary.Write(hash, binary.LittleEndian, data.GetValue()) +func (statesData states) GetIndex(stateKey StateKey) int { + for k, stateData := range statesData { + if stateData.GetKey() == stateKey { + return k } - hash.Write([]byte(";")) } - return hash.Sum64() + return -1 } +// Sensor is an alias for any type, used to store external data accessed by goal priority +// functions and procedural conditions. type Sensor any + +// Sensors is a map of sensor names to their values, providing external data to the agent +// without duplicating it in the world state during planning. type Sensors map[string]Sensor +// GetSensor retrieves a sensor value by name. +// Returns nil if the sensor doesn't exist. func (sensors Sensors) GetSensor(name string) Sensor { return sensors[name] } +// ConditionInterface defines the interface that all condition types must implement. +// Conditions are preconditions that must be satisfied for actions or goals. type ConditionInterface interface { GetKey() StateKey - Check(states states) bool + Check(w world) bool } +// ConditionFn represents a procedural condition that evaluates against sensor data. +// +// Unlike state-based conditions, ConditionFn uses a custom function to check sensors. +// The result is cached after the first evaluation to avoid redundant computation during planning. +// +// Example: +// +// condition := &ConditionFn{ +// Key: 100, +// CheckFn: func(sensors Sensors) bool { +// health := sensors["health"].(int) +// return health < 50 +// }, +// } type ConditionFn struct { - Key StateKey - CheckFn func(sensors Sensors) bool - resolved bool - valid bool + Key StateKey // Unique identifier for this condition + CheckFn func(sensors Sensors) bool // Function that evaluates the condition + resolved bool // Whether the condition has been evaluated + valid bool // Cached result of the evaluation } func (conditionFn *ConditionFn) GetKey() StateKey { return conditionFn.Key } -func (conditionFn *ConditionFn) Check(states states) bool { +func (conditionFn *ConditionFn) Check(w world) bool { if !conditionFn.resolved { - conditionFn.valid = conditionFn.CheckFn(states.Agent.sensors) + conditionFn.valid = conditionFn.CheckFn(w.Agent.sensors) conditionFn.resolved = true } return conditionFn.valid } +// Condition represents a numeric state-based condition with comparison operators. +// +// Conditions check if a state value satisfies a comparison (EQUAL, UPPER, LOWER, etc.) +// against a target value. Supported types are constrained by the Numeric interface. +// +// Example: +// +// // Check if state key 1 is greater than or equal to 100 +// condition := &Condition[int]{ +// Key: 1, +// Value: 100, +// Operator: UPPER_OR_EQUAL, +// } type Condition[T Numeric] struct { - Key StateKey - Value T - Operator operator + Key StateKey // State key to check + Value T // Target value to compare against + Operator operator // Comparison operator (EQUAL, UPPER, LOWER, etc.) } func (condition *Condition[T]) GetKey() StateKey { return condition.Key } -func (condition *Condition[T]) Check(states states) bool { - k := states.data.GetIndex(condition.Key) +func (condition *Condition[T]) Check(w world) bool { + k := w.states.GetIndex(condition.Key) if k < 0 { return false } - s := states.data[k] + s := w.states[k] if state, ok := s.(State[T]); ok { switch condition.Operator { case EQUAL: @@ -211,22 +273,35 @@ func (condition *Condition[T]) Check(states states) bool { return false } +// ConditionBool represents a boolean state-based condition. +// +// Only EQUAL and NOT_EQUAL operators are supported for boolean conditions. +// Other operators will cause Check to return false. +// +// Example: +// +// // Check if state key 2 is true +// condition := &ConditionBool{ +// Key: 2, +// Value: true, +// Operator: EQUAL, +// } type ConditionBool struct { - Key StateKey - Value bool - Operator operator + Key StateKey // State key to check + Value bool // Target boolean value + Operator operator // Allowed: EQUAL, NOT_EQUAL } func (conditionBool *ConditionBool) GetKey() StateKey { return conditionBool.Key } -func (conditionBool *ConditionBool) Check(states states) bool { - k := states.data.GetIndex(conditionBool.Key) +func (conditionBool *ConditionBool) Check(w world) bool { + k := w.states.GetIndex(conditionBool.Key) if k < 0 { return false } - s := states.data[k] + s := w.states[k] if state, ok := s.(State[bool]); ok { switch conditionBool.Operator { case EQUAL: @@ -245,22 +320,36 @@ func (conditionBool *ConditionBool) Check(states states) bool { return false } +// ConditionString represents a string state-based condition. +// +// Only EQUAL and NOT_EQUAL operators are supported for string conditions. +// Other operators will cause Check to return false. +// +// Example: +// +// // Check if state key 3 equals "ready" +// condition := &ConditionString{ +// Key: 3, +// Value: "ready", +// Operator: EQUAL, +// } type ConditionString struct { - Key StateKey - Value string - Operator operator + Key StateKey // State key to check + Value string // Target string value + Operator operator // Allowed: EQUAL, NOT_EQUAL } +// GetKey returns the state key that this condition checks. func (conditionString *ConditionString) GetKey() StateKey { return conditionString.Key } -func (conditionString *ConditionString) Check(states states) bool { - k := states.data.GetIndex(conditionString.Key) +func (conditionString *ConditionString) Check(w world) bool { + k := w.states.GetIndex(conditionString.Key) if k < 0 { return false } - s := states.data[k] + s := w.states[k] if state, ok := s.(State[string]); ok { switch conditionString.Operator { case EQUAL: @@ -279,11 +368,14 @@ func (conditionString *ConditionString) Check(states states) bool { return false } +// Conditions is a collection of ConditionInterface implementations that must all be satisfied. type Conditions []ConditionInterface -func (conditions Conditions) Check(states states) bool { +// Check returns true if all conditions in the collection are satisfied in the given world state. +// Returns true for an empty condition list (vacuous truth). +func (conditions Conditions) Check(w world) bool { for _, condition := range conditions { - if !condition.Check(states) { + if !condition.Check(w) { return false } } diff --git a/state_test.go b/state_test.go new file mode 100644 index 0000000..c024f8f --- /dev/null +++ b/state_test.go @@ -0,0 +1,376 @@ +package goapai + +import "testing" + +func TestState_Operations(t *testing.T) { + tests := []struct { + name string + testFunc func(*testing.T) + }{ + { + name: "GetKey", + testFunc: func(t *testing.T) { + state := State[int]{Key: 42, Value: 100} + if state.GetKey() != 42 { + t.Errorf("Expected key 42, got %d", state.GetKey()) + } + }, + }, + { + name: "GetValue", + testFunc: func(t *testing.T) { + state := State[int]{Key: 1, Value: 100} + if state.GetValue().(int) != 100 { + t.Errorf("Expected value 100, got %v", state.GetValue()) + } + }, + }, + { + name: "Check match", + testFunc: func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + state := State[int]{Key: 1, Value: 100} + if !state.Check(agent.w, 1) { + t.Error("Expected state to match") + } + }, + }, + { + name: "Check no match", + testFunc: func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + wrongState := State[int]{Key: 1, Value: 200} + if wrongState.Check(agent.w, 1) { + t.Error("Expected state not to match") + } + }, + }, + { + name: "Check key not found", + testFunc: func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + state := State[int]{Key: 99, Value: 100} + + if state.Check(agent.w, 99) { + t.Error("Expected false for non-existent key") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.testFunc(t) + }) + } +} + +func TestCondition_Operators(t *testing.T) { + tests := []struct { + name string + stateVal int + condVal int + operator operator + wantMatch bool + }{ + {"equal match", 100, 100, EQUAL, true}, + {"equal no match", 100, 99, EQUAL, false}, + {"not equal match", 100, 99, NOT_EQUAL, true}, + {"not equal no match", 100, 100, NOT_EQUAL, false}, + {"lower true", 50, 100, LOWER, true}, + {"lower false equal", 50, 50, LOWER, false}, + {"lower false greater", 100, 50, LOWER, false}, + {"lower or equal true less", 50, 100, LOWER_OR_EQUAL, true}, + {"lower or equal true equal", 50, 50, LOWER_OR_EQUAL, true}, + {"lower or equal false", 100, 50, LOWER_OR_EQUAL, false}, + {"upper true", 100, 50, UPPER, true}, + {"upper false equal", 100, 100, UPPER, false}, + {"upper false less", 50, 100, UPPER, false}, + {"upper or equal true greater", 100, 50, UPPER_OR_EQUAL, true}, + {"upper or equal true equal", 100, 100, UPPER_OR_EQUAL, true}, + {"upper or equal false", 50, 100, UPPER_OR_EQUAL, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, tt.stateVal) + + condition := Condition[int]{Key: 1, Value: tt.condVal, Operator: tt.operator} + if got := condition.Check(agent.w); got != tt.wantMatch { + t.Errorf("Check() = %v, want %v for state=%d, cond=%d, op=%v", + got, tt.wantMatch, tt.stateVal, tt.condVal, tt.operator) + } + }) + } +} + +func TestCondition_KeyNotFound(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + condition := Condition[int]{Key: 99, Value: 100, Operator: EQUAL} + + if condition.Check(agent.w) { + t.Error("Expected condition to fail when key not found") + } +} + +func TestConditionBool(t *testing.T) { + tests := []struct { + name string + stateVal bool + condVal bool + operator operator + wantMatch bool + }{ + {"equal true match", true, true, EQUAL, true}, + {"equal false match", false, false, EQUAL, true}, + {"equal no match", true, false, EQUAL, false}, + {"not equal match", true, false, NOT_EQUAL, true}, + {"not equal no match", true, true, NOT_EQUAL, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[bool](&agent, 1, tt.stateVal) + + condition := ConditionBool{Key: 1, Value: tt.condVal, Operator: tt.operator} + if got := condition.Check(agent.w); got != tt.wantMatch { + t.Errorf("Check() = %v, want %v", got, tt.wantMatch) + } + }) + } + + t.Run("key not found", func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + condition := ConditionBool{Key: 99, Value: true, Operator: EQUAL} + + if condition.Check(agent.w) { + t.Error("Expected condition to fail when key not found") + } + }) +} + +func TestConditionString(t *testing.T) { + tests := []struct { + name string + stateVal string + condVal string + operator operator + wantMatch bool + }{ + {"equal match", "test", "test", EQUAL, true}, + {"equal no match", "test", "other", EQUAL, false}, + {"not equal match", "test", "other", NOT_EQUAL, true}, + {"not equal no match", "test", "test", NOT_EQUAL, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[string](&agent, 1, tt.stateVal) + + condition := ConditionString{Key: 1, Value: tt.condVal, Operator: tt.operator} + if got := condition.Check(agent.w); got != tt.wantMatch { + t.Errorf("Check() = %v, want %v", got, tt.wantMatch) + } + }) + } +} + +func TestConditionFn(t *testing.T) { + tests := []struct { + name string + sensorVal int + threshold int + wantResult bool + }{ + {"above threshold", 100, 50, true}, + {"below threshold", 30, 50, false}, + {"at threshold", 50, 50, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetSensor(&agent, "value", tt.sensorVal) + + condition := &ConditionFn{ + Key: 1, + CheckFn: func(sensors Sensors) bool { + return sensors.GetSensor("value").(int) > tt.threshold + }, + } + + if got := condition.Check(agent.w); got != tt.wantResult { + t.Errorf("Check() = %v, want %v", got, tt.wantResult) + } + + // Test caching + if !condition.resolved { + t.Error("Expected condition to be marked as resolved") + } + + // Call again to test cache + if got := condition.Check(agent.w); got != tt.wantResult { + t.Error("Expected cached result to match") + } + }) + } +} + +func TestConditions_Check(t *testing.T) { + tests := []struct { + name string + setup func(*Agent) + conditions Conditions + wantMatch bool + }{ + { + name: "all match", + setup: func(a *Agent) { + SetState[int](a, 1, 100) + SetState[bool](a, 2, true) + }, + conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, + wantMatch: true, + }, + { + name: "one fails", + setup: func(a *Agent) { + SetState[int](a, 1, 100) + SetState[bool](a, 2, true) + }, + conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: false, Operator: EQUAL}, + }, + wantMatch: false, + }, + { + name: "empty conditions", + setup: func(a *Agent) {}, + conditions: Conditions{}, + wantMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + tt.setup(&agent) + + if got := tt.conditions.Check(agent.w); got != tt.wantMatch { + t.Errorf("Check() = %v, want %v", got, tt.wantMatch) + } + }) + } +} + +func TestSensors_GetSensor(t *testing.T) { + sensors := Sensors{ + "test": 42, + } + + value := sensors.GetSensor("test") + if value.(int) != 42 { + t.Errorf("Expected 42, got %v", value) + } +} + +func TestState_Hash(t *testing.T) { + tests := []struct { + name string + state1 StateInterface + state2 StateInterface + wantSame bool + }{ + { + name: "same int world", + state1: State[int]{Key: 1, Value: 100}, + state2: State[int]{Key: 1, Value: 100}, + wantSame: true, + }, + { + name: "different int values", + state1: State[int]{Key: 1, Value: 100}, + state2: State[int]{Key: 1, Value: 200}, + wantSame: false, + }, + { + name: "different keys", + state1: State[int]{Key: 1, Value: 100}, + state2: State[int]{Key: 2, Value: 100}, + wantSame: false, + }, + { + name: "same bool world", + state1: State[bool]{Key: 1, Value: true}, + state2: State[bool]{Key: 1, Value: true}, + wantSame: true, + }, + { + name: "different bool values", + state1: State[bool]{Key: 1, Value: true}, + state2: State[bool]{Key: 1, Value: false}, + wantSame: false, + }, + { + name: "same string world", + state1: State[string]{Key: 1, Value: "test"}, + state2: State[string]{Key: 1, Value: "test"}, + wantSame: true, + }, + { + name: "different string values", + state1: State[string]{Key: 1, Value: "test"}, + state2: State[string]{Key: 1, Value: "other"}, + wantSame: false, + }, + { + name: "same float64 world", + state1: State[float64]{Key: 1, Value: 3.14}, + state2: State[float64]{Key: 1, Value: 3.14}, + wantSame: true, + }, + { + name: "different float64 values", + state1: State[float64]{Key: 1, Value: 3.14}, + state2: State[float64]{Key: 1, Value: 2.71}, + wantSame: false, + }, + { + name: "same uint64 world", + state1: State[uint64]{Key: 1, Value: 12345}, + state2: State[uint64]{Key: 1, Value: 12345}, + wantSame: true, + }, + { + name: "different uint64 values", + state1: State[uint64]{Key: 1, Value: 12345}, + state2: State[uint64]{Key: 1, Value: 54321}, + wantSame: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash1 := tt.state1.Hash() + hash2 := tt.state2.Hash() + + if tt.wantSame && hash1 != hash2 { + t.Errorf("Expected same hash, got %d and %d", hash1, hash2) + } + if !tt.wantSame && hash1 == hash2 { + t.Errorf("Expected different hashes, both got %d", hash1) + } + }) + } +}