From 1ebe20ec234d50ef477f082cf79b69bdd711691a Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Wed, 29 Oct 2025 07:32:30 +0100 Subject: [PATCH 01/18] feat: unit tests coverage 85% --- action_test.go | 422 ++++++++++++++++++++++++++++++++++++++ agent_test.go | 171 ++++++++++++++++ astar_test.go | 538 +++++++++++++++++++++++++++++++++++++++++++++++++ planer_test.go | 375 ++++++++++++++++++++++++++++++++++ state_test.go | 438 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 1944 insertions(+) create mode 100644 action_test.go create mode 100644 agent_test.go create mode 100644 astar_test.go create mode 100644 state_test.go diff --git a/action_test.go b/action_test.go new file mode 100644 index 0000000..51c7c22 --- /dev/null +++ b/action_test.go @@ -0,0 +1,422 @@ +package goapai + +import "testing" + +// Test Actions.AddAction +func TestActions_AddAction(t *testing.T) { + actions := Actions{} + + actions.AddAction("test", 1.5, false, Conditions{}, Effects{}) + + if len(actions) != 1 { + t.Errorf("Expected 1 action, got %d", len(actions)) + } + + action := actions[0] + if action.name != "test" { + t.Errorf("Expected name 'test', got '%s'", action.name) + } + if action.cost != 1.5 { + t.Errorf("Expected cost 1.5, got %f", action.cost) + } + if action.repeatable { + t.Error("Expected repeatable to be false") + } +} + +func TestActions_AddAction_Multiple(t *testing.T) { + actions := Actions{} + + actions.AddAction("action1", 1.0, true, Conditions{}, Effects{}) + actions.AddAction("action2", 2.0, false, Conditions{}, Effects{}) + + if len(actions) != 2 { + t.Errorf("Expected 2 actions, got %d", len(actions)) + } +} + +// Test Action.GetName +func TestAction_GetName(t *testing.T) { + actions := Actions{} + actions.AddAction("my_action", 1.0, false, Conditions{}, Effects{}) + + name := actions[0].GetName() + if name != "my_action" { + t.Errorf("Expected 'my_action', got '%s'", name) + } +} + +// Test Action.GetEffects +func TestAction_GetEffects(t *testing.T) { + effects := Effects{ + Effect[int]{Key: 1, Value: 10, Operator: SET}, + } + + actions := Actions{} + actions.AddAction("test", 1.0, false, Conditions{}, effects) + + retrieved := actions[0].GetEffects() + if len(retrieved) != 1 { + t.Errorf("Expected 1 effect, got %d", len(retrieved)) + } +} + +// Test Effect[T Numeric] operations +func TestEffect_Check_Match(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + effect := Effect[int]{Key: 1, Value: 100, Operator: SET} + if !effect.check(agent.states) { + t.Error("Expected effect to match state") + } +} + +func TestEffect_Check_NoMatch(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + effect := Effect[int]{Key: 1, Value: 200, Operator: SET} + if effect.check(agent.states) { + t.Error("Expected effect not to match state") + } +} + +func TestEffect_Check_NonSetOperator(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + effect := Effect[int]{Key: 1, Value: 100, Operator: ADD} + if effect.check(agent.states) { + t.Error("Expected non-SET operator to always return false") + } +} + +func TestEffect_Apply_Set(t *testing.T) { + data := statesData{ + State[int]{Key: 1, Value: 100}, + } + + effect := Effect[int]{Key: 1, Value: 200, Operator: SET} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if data[0].(State[int]).Value != 200 { + t.Errorf("Expected value 200, got %d", data[0].(State[int]).Value) + } +} + +func TestEffect_Apply_Add(t *testing.T) { + data := statesData{ + State[int]{Key: 1, Value: 100}, + } + + effect := Effect[int]{Key: 1, Value: 50, Operator: ADD} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if data[0].(State[int]).Value != 150 { + t.Errorf("Expected value 150, got %d", data[0].(State[int]).Value) + } +} + +func TestEffect_Apply_Subtract(t *testing.T) { + data := statesData{ + State[int]{Key: 1, Value: 100}, + } + + effect := Effect[int]{Key: 1, Value: 30, Operator: SUBSTRACT} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if data[0].(State[int]).Value != 70 { + t.Errorf("Expected value 70, got %d", data[0].(State[int]).Value) + } +} + +func TestEffect_Apply_Multiply(t *testing.T) { + data := statesData{ + State[int]{Key: 1, Value: 10}, + } + + effect := Effect[int]{Key: 1, Value: 5, Operator: MULTIPLY} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if data[0].(State[int]).Value != 50 { + t.Errorf("Expected value 50, got %d", data[0].(State[int]).Value) + } +} + +func TestEffect_Apply_Divide(t *testing.T) { + data := statesData{ + State[int]{Key: 1, Value: 100}, + } + + effect := Effect[int]{Key: 1, Value: 4, Operator: DIVIDE} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if data[0].(State[int]).Value != 25 { + t.Errorf("Expected value 25, got %d", data[0].(State[int]).Value) + } +} + +func TestEffect_Apply_NewKey(t *testing.T) { + data := statesData{} + + effect := Effect[int]{Key: 1, Value: 42, Operator: SET} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Note: apply modifies the slice in place but the original reference doesn't change + // This is actually expected behavior - effects.apply() returns a new slice +} + +// Test EffectBool +func TestEffectBool_Check_Match(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[bool](&agent, 1, true) + + effect := EffectBool{Key: 1, Value: true, Operator: SET} + if !effect.check(agent.states) { + t.Error("Expected effect to match state") + } +} + +func TestEffectBool_Check_NoMatch(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[bool](&agent, 1, true) + + effect := EffectBool{Key: 1, Value: false, Operator: SET} + if effect.check(agent.states) { + t.Error("Expected effect not to match state") + } +} + +func TestEffectBool_Apply_Set(t *testing.T) { + data := statesData{ + State[bool]{Key: 1, Value: false}, + } + + effect := EffectBool{Key: 1, Value: true, Operator: SET} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if data[0].(State[bool]).Value != true { + t.Error("Expected value to be true") + } +} + +func TestEffectBool_Apply_InvalidOperator(t *testing.T) { + data := statesData{ + State[bool]{Key: 1, Value: true}, + } + + effect := EffectBool{Key: 1, Value: false, Operator: ADD} + err := effect.apply(data) + + if err == nil { + t.Error("Expected error for invalid operator on bool") + } +} + +func TestEffectBool_Apply_NewKey(t *testing.T) { + data := statesData{} + + effect := EffectBool{Key: 1, Value: true, Operator: SET} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Note: apply modifies the slice in place but the original reference doesn't change + // This is actually expected behavior - effects.apply() returns a new slice +} + +// Test EffectString +func TestEffectString_Check_Match(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[string](&agent, 1, "test") + + effect := EffectString{Key: 1, Value: "test", Operator: SET} + if !effect.check(agent.states) { + t.Error("Expected effect to match state") + } +} + +func TestEffectString_Check_NoMatch(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[string](&agent, 1, "test") + + effect := EffectString{Key: 1, Value: "other", Operator: SET} + if effect.check(agent.states) { + t.Error("Expected effect not to match state") + } +} + +func TestEffectString_Apply_Set(t *testing.T) { + data := statesData{ + State[string]{Key: 1, Value: "old"}, + } + + effect := EffectString{Key: 1, Value: "new", Operator: SET} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if data[0].(State[string]).Value != "new" { + t.Errorf("Expected value 'new', got '%s'", data[0].(State[string]).Value) + } +} + +func TestEffectString_Apply_Add_Concatenate(t *testing.T) { + data := statesData{ + State[string]{Key: 1, Value: "hello"}, + } + + effect := EffectString{Key: 1, Value: " world", Operator: ADD} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if data[0].(State[string]).Value != "hello world" { + t.Errorf("Expected 'hello world', got '%s'", data[0].(State[string]).Value) + } +} + +func TestEffectString_Apply_InvalidOperator(t *testing.T) { + data := statesData{ + State[string]{Key: 1, Value: "test"}, + } + + effect := EffectString{Key: 1, Value: "x", Operator: MULTIPLY} + err := effect.apply(data) + + if err == nil { + t.Error("Expected error for invalid operator on string") + } +} + +func TestEffectString_Apply_NewKey(t *testing.T) { + data := statesData{} + + effect := EffectString{Key: 1, Value: "new", Operator: SET} + err := effect.apply(data) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Note: apply modifies the slice in place but the original reference doesn't change + // This is actually expected behavior - effects.apply() returns a new slice +} + +// Test Effects (slice) operations +func TestEffects_SatisfyStates_AllMatch(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[bool](&agent, 2, true) + + effects := Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + EffectBool{Key: 2, Value: true, Operator: SET}, + } + + if !effects.satisfyStates(agent.states) { + t.Error("Expected effects to satisfy states") + } +} + +func TestEffects_SatisfyStates_OneFails(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[bool](&agent, 2, true) + + effects := Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + EffectBool{Key: 2, Value: false, Operator: SET}, + } + + if effects.satisfyStates(agent.states) { + t.Error("Expected effects not to satisfy states when one doesn't match") + } +} + +func TestEffects_Apply(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[bool](&agent, 2, false) + + effects := Effects{ + Effect[int]{Key: 1, Value: 50, Operator: ADD}, + EffectBool{Key: 2, Value: true, Operator: SET}, + } + + newData, err := effects.apply(agent.states) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(newData) != 2 { + t.Errorf("Expected 2 states, got %d", len(newData)) + } + + // Check int value + intIdx := newData.GetIndex(1) + if intIdx < 0 { + t.Error("Expected to find key 1") + } + if newData[intIdx].(State[int]).Value != 150 { + t.Errorf("Expected int value 150, got %d", newData[intIdx].(State[int]).Value) + } + + // Check bool value + boolIdx := newData.GetIndex(2) + if boolIdx < 0 { + t.Error("Expected to find key 2") + } + if newData[boolIdx].(State[bool]).Value != true { + t.Error("Expected bool value true") + } +} + +func TestEffects_Apply_Error(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + effects := Effects{ + Effect[float64]{Key: 1, Value: 50.0, Operator: SET}, // Type mismatch + } + + _, err := effects.apply(agent.states) + if err == nil { + t.Error("Expected error for type mismatch") + } +} diff --git a/agent_test.go b/agent_test.go new file mode 100644 index 0000000..e8905ee --- /dev/null +++ b/agent_test.go @@ -0,0 +1,171 @@ +package goapai + +import "testing" + +func TestCreateAgent(t *testing.T) { + goals := Goals{ + "test_goal": { + Conditions: Conditions{ + &ConditionBool{Key: 1, Value: true}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + actions := Actions{} + actions.AddAction("test_action", 1.0, false, Conditions{}, Effects{}) + + agent := CreateAgent(goals, actions) + + if len(agent.actions) != 1 { + t.Errorf("Expected 1 action, got %d", len(agent.actions)) + } + + if len(agent.goals) != 1 { + t.Errorf("Expected 1 goal, got %d", len(agent.goals)) + } + + if agent.sensors == nil { + t.Error("Expected sensors to be initialized") + } + + if agent.states.Agent == nil { + t.Error("Expected states.Agent to be non-nil") + } +} + +func TestSetState(t *testing.T) { + tests := []struct { + name string + setupFunc func(*Agent) + checkFunc func(*testing.T, Agent) + }{ + { + name: "int state", + setupFunc: func(a *Agent) { + SetState[int](a, 1, 42) + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.states.data) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.states.data)) + } + state := a.states.data[0].(State[int]) + if state.Key != 1 || state.Value != 42 { + t.Errorf("Expected key=1, value=42, got key=%d, value=%d", state.Key, state.Value) + } + }, + }, + { + name: "bool state", + setupFunc: func(a *Agent) { + SetState[bool](a, 2, true) + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.states.data) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.states.data)) + } + state := a.states.data[0].(State[bool]) + if state.Key != 2 || state.Value != true { + t.Errorf("Expected key=2, value=true, got key=%d, value=%v", state.Key, state.Value) + } + }, + }, + { + name: "string state", + setupFunc: func(a *Agent) { + SetState[string](a, 3, "test") + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.states.data) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.states.data)) + } + state := a.states.data[0].(State[string]) + if state.Key != 3 || state.Value != "test" { + t.Errorf("Expected key=3, value='test', got key=%d, value='%s'", state.Key, state.Value) + } + }, + }, + { + name: "multiple states", + setupFunc: func(a *Agent) { + SetState[int](a, 1, 100) + SetState[bool](a, 2, false) + SetState[string](a, 3, "hello") + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.states.data) != 3 { + t.Errorf("Expected 3 states, got %d", len(a.states.data)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + tt.setupFunc(&agent) + tt.checkFunc(t, agent) + }) + } +} + +func TestSetSensor(t *testing.T) { + type TestEntity struct { + health int + } + + tests := []struct { + name string + setupFunc func(*Agent) + checkFunc func(*testing.T, Agent) + }{ + { + name: "single sensor", + setupFunc: func(a *Agent) { + entity := &TestEntity{health: 100} + SetSensor(a, "entity", entity) + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.sensors) != 1 { + t.Errorf("Expected 1 sensor, got %d", len(a.sensors)) + } + retrieved := a.sensors.GetSensor("entity").(*TestEntity) + if retrieved.health != 100 { + t.Errorf("Expected health 100, got %d", retrieved.health) + } + }, + }, + { + name: "multiple sensors", + setupFunc: func(a *Agent) { + SetSensor(a, "sensor1", "value1") + SetSensor(a, "sensor2", 42) + SetSensor(a, "sensor3", true) + }, + checkFunc: func(t *testing.T, a Agent) { + if len(a.sensors) != 3 { + t.Errorf("Expected 3 sensors, got %d", len(a.sensors)) + } + if a.sensors.GetSensor("sensor1").(string) != "value1" { + t.Error("sensor1 value mismatch") + } + if a.sensors.GetSensor("sensor2").(int) != 42 { + t.Error("sensor2 value mismatch") + } + if a.sensors.GetSensor("sensor3").(bool) != true { + t.Error("sensor3 value mismatch") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + tt.setupFunc(&agent) + tt.checkFunc(t, agent) + }) + } +} diff --git a/astar_test.go b/astar_test.go new file mode 100644 index 0000000..642c438 --- /dev/null +++ b/astar_test.go @@ -0,0 +1,538 @@ +package goapai + +import ( + "slices" + "testing" +) + +// Test getImpactingActions +func TestGetImpactingActions(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[bool](&agent, 2, false) + + actions := Actions{} + // Action that changes state - should be included + actions.AddAction("change", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 200, Operator: SET}, + }) + // Action with effects matching current state - should be excluded + actions.AddAction("no_change", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + EffectBool{Key: 2, Value: false, Operator: SET}, + }) + + impacting := getImpactingActions(agent.states, actions) + + if len(impacting) != 1 { + t.Errorf("Expected 1 impacting action, got %d", len(impacting)) + } + + if impacting[0].name != "change" { + t.Errorf("Expected 'change' action, got '%s'", impacting[0].name) + } +} + +func TestGetImpactingActions_AllImpacting(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + actions := Actions{} + actions.AddAction("action1", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 200, Operator: SET}, + }) + actions.AddAction("action2", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 50, Operator: SET}, + }) + + impacting := getImpactingActions(agent.states, actions) + + if len(impacting) != 2 { + t.Errorf("Expected 2 impacting actions, got %d", len(impacting)) + } +} + +// Test getLessCostlyNodeKey +func TestGetLessCostlyNodeKey(t *testing.T) { + nodes := []*node{ + {totalCost: 10.0}, + {totalCost: 5.0}, + {totalCost: 15.0}, + } + + key := getLessCostlyNodeKey(nodes) + if key != 1 { + t.Errorf("Expected key 1 (lowest cost), got %d", key) + } +} + +func TestGetLessCostlyNodeKey_Empty(t *testing.T) { + nodes := []*node{} + + key := getLessCostlyNodeKey(nodes) + if key != -1 { + t.Errorf("Expected -1 for empty list, got %d", key) + } +} + +// Test fetchNode +func TestFetchNode_Found(t *testing.T) { + agent1 := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent1, 1, 100) + agent1.states.data.sort() + agent1.states.hash = agent1.states.data.hashStates() + + agent2 := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent2, 1, 100) + agent2.states.data.sort() + agent2.states.hash = agent2.states.data.hashStates() + + nodes := []*node{ + {states: agent1.states}, + {states: agent2.states}, + } + + key, found := fetchNode(nodes, agent1.states) + if !found { + t.Error("Expected to find node") + } + if key != 0 { + t.Errorf("Expected key 0, got %d", key) + } +} + +func TestFetchNode_NotFound(t *testing.T) { + agent1 := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent1, 1, 100) + agent1.states.data.sort() + agent1.states.hash = agent1.states.data.hashStates() + + agent2 := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent2, 1, 200) + agent2.states.data.sort() + agent2.states.hash = agent2.states.data.hashStates() + + nodes := []*node{ + {states: agent1.states}, + } + + _, found := fetchNode(nodes, agent2.states) + if found { + t.Error("Expected not to find node") + } +} + +// Test buildPlanFromNode +func TestBuildPlanFromNode(t *testing.T) { + action1 := &Action{name: "action1", cost: 1.0} + action2 := &Action{name: "action2", cost: 1.0} + action3 := &Action{name: "action3", cost: 1.0} + + // Build a chain: nil -> node1 -> node2 -> node3 + node1 := &node{Action: action1, parentNode: nil, depth: 1} + node2 := &node{Action: action2, parentNode: node1, depth: 2} + node3 := &node{Action: action3, parentNode: node2, depth: 3} + + plan := buildPlanFromNode(node3) + + if len(plan) != 3 { + t.Errorf("Expected plan length 3, got %d", len(plan)) + } + + if plan[0].name != "action1" { + t.Errorf("Expected first action to be 'action1', got '%s'", plan[0].name) + } + if plan[1].name != "action2" { + t.Errorf("Expected second action to be 'action2', got '%s'", plan[1].name) + } + if plan[2].name != "action3" { + t.Errorf("Expected third action to be 'action3', got '%s'", plan[2].name) + } +} + +func TestBuildPlanFromNode_SingleNode(t *testing.T) { + action1 := &Action{name: "action1", cost: 1.0} + node1 := &node{Action: action1, parentNode: nil, depth: 1} + + plan := buildPlanFromNode(node1) + + if len(plan) != 1 { + t.Errorf("Expected plan length 1, got %d", len(plan)) + } +} + +// Test simulateActionState +func TestSimulateActionState(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + action := &Action{ + effects: Effects{ + Effect[int]{Key: 1, Value: 50, Operator: ADD}, + }, + } + + newStates, ok := simulateActionState(action, agent.states) + if !ok { + t.Error("Expected simulation to succeed") + } + + idx := newStates.data.GetIndex(1) + if idx < 0 { + t.Error("Expected to find key 1 in new states") + } + + if newStates.data[idx].(State[int]).Value != 150 { + t.Errorf("Expected value 150, got %d", newStates.data[idx].(State[int]).Value) + } +} + +func TestSimulateActionState_NoChange(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + action := &Action{ + effects: Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + }, + } + + _, ok := simulateActionState(action, agent.states) + if ok { + t.Error("Expected simulation to fail when effects match current state") + } +} + +// Test allowedRepetition +func TestAllowedRepetition_Repeatable(t *testing.T) { + action := &Action{name: "test", repeatable: true} + parentNode := &node{Action: action} + + if !allowedRepetition(action, parentNode) { + t.Error("Expected repeatable action to be allowed") + } +} + +func TestAllowedRepetition_NonRepeatable_NotUsed(t *testing.T) { + action1 := &Action{name: "action1", repeatable: false} + action2 := &Action{name: "action2", repeatable: false} + + node1 := &node{Action: action2, parentNode: nil} + + if !allowedRepetition(action1, node1) { + t.Error("Expected non-repeated action to be allowed") + } +} + +func TestAllowedRepetition_NonRepeatable_AlreadyUsed(t *testing.T) { + action := &Action{name: "test", repeatable: false} + + node1 := &node{Action: action, parentNode: nil} + node2 := &node{Action: &Action{name: "other"}, parentNode: node1} + + if allowedRepetition(action, node2) { + t.Error("Expected repeated non-repeatable action to be disallowed") + } +} + +// Test countMissingGoal +func TestCountMissingGoal_AllMet(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[bool](&agent, 2, true) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, + } + + count := countMissingGoal(goal, agent.states) + if count != 0 { + t.Errorf("Expected 0 missing goals, got %d", count) + } +} + +func TestCountMissingGoal_OneMissing(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[bool](&agent, 2, false) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, + } + + count := countMissingGoal(goal, agent.states) + if count != 1 { + t.Errorf("Expected 1 missing goal, got %d", count) + } +} + +func TestCountMissingGoal_AllMissing(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + SetState[bool](&agent, 2, false) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, + } + + count := countMissingGoal(goal, agent.states) + if count != 2 { + t.Errorf("Expected 2 missing goals, got %d", count) + } +} + +// Test computeHeuristic +func TestComputeHeuristic(t *testing.T) { + fromAgent := CreateAgent(Goals{}, Actions{}) + SetState[int](&fromAgent, 1, 0) + + currentAgent := CreateAgent(Goals{}, Actions{}) + SetState[int](¤tAgent, 1, 50) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + heuristic := computeHeuristic(fromAgent.states, goal, currentAgent.states) + if heuristic <= 0 { + t.Error("Expected positive heuristic for unmet goal") + } +} + +func TestComputeHeuristic_GoalMet(t *testing.T) { + fromAgent := CreateAgent(Goals{}, Actions{}) + SetState[int](&fromAgent, 1, 0) + + currentAgent := CreateAgent(Goals{}, Actions{}) + SetState[int](¤tAgent, 1, 100) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + heuristic := computeHeuristic(fromAgent.states, goal, currentAgent.states) + if heuristic != 0 { + t.Errorf("Expected 0 heuristic for met goal, got %f", heuristic) + } +} + +// Test astar integration +func TestAstar_SimpleGoal(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + actions.AddAction("increment", 1.0, true, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 10, Operator: ADD}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 30, Operator: EQUAL}, + }, + } + + plan := astar(agent.states, goal, actions, 10) + + // Plan includes the root node with empty action + if len(plan) != 4 { + t.Errorf("Expected plan with 4 actions (including root), got %d", len(plan)) + } +} + +func TestAstar_UnreachableGoal(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + // No actions available + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + plan := astar(agent.states, goal, actions, 10) + + if len(plan) != 0 { + t.Errorf("Expected empty plan for unreachable goal, got %d actions", len(plan)) + } +} + +func TestAstar_MaxDepth(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + actions.AddAction("increment", 1.0, true, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 1, Operator: ADD}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + // Max depth of 5 should prevent reaching goal of 100 + plan := astar(agent.states, goal, actions, 5) + + if len(plan) > 5 { + t.Errorf("Expected plan to respect maxDepth of 5, got %d actions", len(plan)) + } +} + +func TestAstar_AlreadyAtGoal(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + actions := Actions{} + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + plan := astar(agent.states, goal, actions, 10) + + // Plan includes root node with empty action when already at goal + if len(plan) != 1 { + t.Errorf("Expected plan with 1 action (root) when already at goal, got %d actions", len(plan)) + } +} + +func TestAstar_PreferLowerCost(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + actions.AddAction("expensive", 10.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + }) + actions.AddAction("cheap", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + plan := astar(agent.states, goal, actions, 10) + + // Plan includes root + action + if len(plan) != 2 { + t.Errorf("Expected plan with 2 actions (root + 1), got %d", len(plan)) + } + + if plan[1].name != "cheap" { + t.Errorf("Expected cheaper action to be chosen, got '%s'", plan[1].name) + } +} + +func TestAstar_RespectConditions(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + SetState[bool](&agent, 2, false) + + actions := Actions{} + // This action requires key 2 to be true + actions.AddAction("conditional", 1.0, false, Conditions{ + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, Effects{ + Effect[int]{Key: 1, Value: 100, Operator: SET}, + }) + // This action enables the conditional action + actions.AddAction("enabler", 1.0, false, Conditions{}, Effects{ + EffectBool{Key: 2, Value: true, Operator: SET}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + } + + plan := astar(agent.states, goal, actions, 10) + + // Plan includes root + 2 actions + if len(plan) != 3 { + t.Errorf("Expected plan with 3 actions (root + 2), got %d", len(plan)) + } + + // Second action (index 1) should be the enabler + if plan[1].name != "enabler" { + t.Errorf("Expected second action to be 'enabler', got '%s'", plan[1].name) + } + // Third action (index 2) should be the conditional one + if plan[2].name != "conditional" { + t.Errorf("Expected third action to be 'conditional', got '%s'", plan[2].name) + } +} + +func TestAstar_NonRepeatableActions(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 0) + + actions := Actions{} + actions.AddAction("increment", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 10, Operator: ADD}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 30, Operator: EQUAL}, + }, + } + + plan := astar(agent.states, goal, actions, 10) + + // Non-repeatable action can only be used once, so goal is unreachable + if len(plan) != 0 { + t.Errorf("Expected empty plan for non-repeatable action, got %d actions", len(plan)) + } +} + +func TestAstar_DataCloning(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + originalData := slices.Clone(agent.states.data) + + actions := Actions{} + actions.AddAction("modify", 1.0, false, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 200, Operator: SET}, + }) + + goal := goalInterface{ + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 200, Operator: EQUAL}, + }, + } + + _ = astar(agent.states, goal, actions, 10) + + // Original state should not be modified + if len(agent.states.data) != len(originalData) { + t.Error("Original state was modified") + } + if agent.states.data[0].(State[int]).Value != originalData[0].(State[int]).Value { + t.Error("Original state values were modified") + } +} diff --git a/planer_test.go b/planer_test.go index daf25d7..6163304 100644 --- a/planer_test.go +++ b/planer_test.go @@ -29,3 +29,378 @@ func TestPlan_GetTotalCost(t *testing.T) { }) } } + +// Test GetPlan +func TestGetPlan_SimpleGoal(t *testing.T) { + actions := Actions{} + actions.AddAction("increment", 1.0, true, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 10, Operator: ADD}, + }) + + goals := Goals{ + "reach_30": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 30, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 0) + + goalName, plan := GetPlan(agent, 10) + + if goalName != "reach_30" { + t.Errorf("Expected goal 'reach_30', got '%s'", goalName) + } + + // Plan includes root node + if len(plan) != 4 { + t.Errorf("Expected plan with 4 actions (including root), got %d", len(plan)) + } +} + +func TestGetPlan_AlreadyAtGoal(t *testing.T) { + actions := Actions{} + + goals := Goals{ + "be_at_100": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 100) + + goalName, plan := GetPlan(agent, 10) + + if goalName != "be_at_100" { + t.Errorf("Expected goal 'be_at_100', got '%s'", goalName) + } + + // Plan includes root node even when already at goal + if len(plan) != 1 { + t.Errorf("Expected plan with 1 action (root) when already at goal, got %d actions", len(plan)) + } +} + +func TestGetPlan_NoGoalsAvailable(t *testing.T) { + actions := Actions{} + + goals := Goals{ + "impossible": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 0.0 // Zero priority means goal is not active + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 0) + + goalName, plan := GetPlan(agent, 10) + + if goalName != "" { + t.Errorf("Expected empty goal name, got '%s'", goalName) + } + + if len(plan) != 0 { + t.Errorf("Expected empty plan when no goals available, got %d actions", len(plan)) + } +} + +func TestGetPlan_UnreachableGoal(t *testing.T) { + actions := Actions{} + // No actions to reach the goal + + goals := Goals{ + "unreachable": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 0) + + goalName, plan := GetPlan(agent, 10) + + if goalName != "unreachable" { + t.Errorf("Expected goal 'unreachable', got '%s'", goalName) + } + + if len(plan) != 0 { + t.Errorf("Expected empty plan for unreachable goal, got %d actions", len(plan)) + } +} + +func TestGetPlan_MaxDepthExceeded(t *testing.T) { + actions := Actions{} + actions.AddAction("increment", 1.0, true, Conditions{}, Effects{ + Effect[int]{Key: 1, Value: 1, Operator: ADD}, + }) + + goals := Goals{ + "reach_100": { + Conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[int](&agent, 1, 0) + + // Max depth of 5 should prevent reaching 100 + _, plan := GetPlan(agent, 5) + + if len(plan) > 5 { + t.Errorf("Expected plan to respect maxDepth of 5, got %d actions", len(plan)) + } +} + +// Test getPrioritizedGoalName +func TestGetPrioritizedGoalName_SingleGoal(t *testing.T) { + goals := Goals{ + "goal1": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, Actions{}) + + goalName, err := agent.getPrioritizedGoalName() + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if goalName != "goal1" { + t.Errorf("Expected 'goal1', got '%s'", goalName) + } +} + +func TestGetPrioritizedGoalName_MultipleGoals(t *testing.T) { + goals := Goals{ + "low_priority": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 0.5 + }, + }, + "high_priority": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 2.0 + }, + }, + "medium_priority": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, Actions{}) + + goalName, err := agent.getPrioritizedGoalName() + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if goalName != "high_priority" { + t.Errorf("Expected 'high_priority', got '%s'", goalName) + } +} + +func TestGetPrioritizedGoalName_ZeroPriority(t *testing.T) { + goals := Goals{ + "inactive": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 0.0 + }, + }, + } + + agent := CreateAgent(goals, Actions{}) + + _, err := agent.getPrioritizedGoalName() + + if err == nil { + t.Error("Expected error when all goals have zero priority") + } +} + +func TestGetPrioritizedGoalName_UsingSensors(t *testing.T) { + type Entity struct { + health int + } + + goals := Goals{ + "heal": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + entity := sensors.GetSensor("entity").(*Entity) + if entity.health < 50 { + return 2.0 + } + return 0.1 + }, + }, + "explore": { + Conditions: Conditions{}, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, Actions{}) + entity := &Entity{health: 20} + SetSensor(&agent, "entity", entity) + + goalName, err := agent.getPrioritizedGoalName() + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if goalName != "heal" { + t.Errorf("Expected 'heal' for low health, got '%s'", goalName) + } +} + +// Test GetNextAction +func TestGetNextAction_WithActions(t *testing.T) { + plan := Plan{ + {name: "action1", cost: 1.0}, + {name: "action2", cost: 2.0}, + } + + action, err := plan.GetNextAction() + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if action.name != "action1" { + t.Errorf("Expected 'action1', got '%s'", action.name) + } +} + +func TestGetNextAction_EmptyPlan(t *testing.T) { + plan := Plan{} + + _, err := plan.GetNextAction() + + if err == nil { + t.Error("Expected error for empty plan") + } +} + +func TestGetNextAction_SingleAction(t *testing.T) { + plan := Plan{ + {name: "only_action", cost: 1.0}, + } + + action, err := plan.GetNextAction() + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if action.name != "only_action" { + t.Errorf("Expected 'only_action', got '%s'", action.name) + } +} + +// Integration test with complex scenario +func TestGetPlan_ComplexScenario(t *testing.T) { + actions := Actions{} + + // Need wood to make fire + actions.AddAction("get_wood", 2.0, false, Conditions{}, Effects{ + EffectBool{Key: 1, Value: true, Operator: SET}, // has_wood + }) + + // Need matches to make fire + actions.AddAction("get_matches", 1.0, false, Conditions{}, Effects{ + EffectBool{Key: 2, Value: true, Operator: SET}, // has_matches + }) + + // Make fire requires both wood and matches + actions.AddAction("make_fire", 1.0, false, Conditions{ + &ConditionBool{Key: 1, Value: true, Operator: EQUAL}, // has_wood + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, // has_matches + }, Effects{ + EffectBool{Key: 3, Value: true, Operator: SET}, // has_fire + }) + + goals := Goals{ + "stay_warm": { + Conditions: Conditions{ + &ConditionBool{Key: 3, Value: true, Operator: EQUAL}, // has_fire + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, + } + + agent := CreateAgent(goals, actions) + SetState[bool](&agent, 1, false) // has_wood + SetState[bool](&agent, 2, false) // has_matches + SetState[bool](&agent, 3, false) // has_fire + + goalName, plan := GetPlan(agent, 10) + + if goalName != "stay_warm" { + t.Errorf("Expected goal 'stay_warm', got '%s'", goalName) + } + + // Plan includes root + 3 actions + if len(plan) != 4 { + t.Errorf("Expected plan with 4 actions (root + 3), got %d", len(plan)) + } + + // Verify the plan makes sense (should get matches first as it's cheaper) + // Index 1 should be a resource gathering action + if plan[1].name != "get_matches" && plan[1].name != "get_wood" { + t.Errorf("Expected second action to gather resources, got '%s'", plan[1].name) + } + + // Last action should be make_fire + if plan[3].name != "make_fire" { + t.Errorf("Expected last action to be 'make_fire', got '%s'", plan[3].name) + } + + // Verify total cost (root has cost 0) + totalCost := plan.GetTotalCost() + if totalCost != 4.0 { + t.Errorf("Expected total cost 4.0, got %f", totalCost) + } +} diff --git a/state_test.go b/state_test.go new file mode 100644 index 0000000..9a19c5d --- /dev/null +++ b/state_test.go @@ -0,0 +1,438 @@ +package goapai + +import "testing" + +func TestState_Operations(t *testing.T) { + tests := []struct { + name string + testFunc func(*testing.T) + }{ + { + name: "GetKey", + testFunc: func(t *testing.T) { + state := State[int]{Key: 42, Value: 100} + if state.GetKey() != 42 { + t.Errorf("Expected key 42, got %d", state.GetKey()) + } + }, + }, + { + name: "GetValue", + testFunc: func(t *testing.T) { + state := State[int]{Key: 1, Value: 100} + if state.GetValue().(int) != 100 { + t.Errorf("Expected value 100, got %v", state.GetValue()) + } + }, + }, + { + name: "Check match", + testFunc: func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + state := State[int]{Key: 1, Value: 100} + if !state.Check(agent.states, 1) { + t.Error("Expected state to match") + } + }, + }, + { + name: "Check no match", + testFunc: func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + + wrongState := State[int]{Key: 1, Value: 200} + if wrongState.Check(agent.states, 1) { + t.Error("Expected state not to match") + } + }, + }, + { + name: "Check key not found", + testFunc: func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + state := State[int]{Key: 99, Value: 100} + + if state.Check(agent.states, 99) { + t.Error("Expected false for non-existent key") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.testFunc(t) + }) + } +} + +func TestStates_Check(t *testing.T) { + tests := []struct { + name string + setup1 func(*Agent) + setup2 func(*Agent) + wantMatch bool + }{ + { + name: "matching states", + setup1: func(a *Agent) { + SetState[int](a, 1, 100) + }, + setup2: func(a *Agent) { + SetState[int](a, 1, 100) + }, + wantMatch: true, + }, + { + name: "different states", + setup1: func(a *Agent) { + SetState[int](a, 1, 100) + }, + setup2: func(a *Agent) { + SetState[int](a, 1, 100) + SetState[int](a, 2, 200) + }, + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent1 := CreateAgent(Goals{}, Actions{}) + tt.setup1(&agent1) + agent1.states.data.sort() + agent1.states.hash = agent1.states.data.hashStates() + + agent2 := CreateAgent(Goals{}, Actions{}) + tt.setup2(&agent2) + agent2.states.data.sort() + agent2.states.hash = agent2.states.data.hashStates() + + if got := agent1.states.Check(agent2.states); got != tt.wantMatch { + t.Errorf("Check() = %v, want %v", got, tt.wantMatch) + } + }) + } +} + +func TestStatesData_Operations(t *testing.T) { + tests := []struct { + name string + testFunc func(*testing.T) + }{ + { + name: "GetIndex found", + testFunc: func(t *testing.T) { + data := statesData{ + State[int]{Key: 1, Value: 100}, + State[int]{Key: 2, Value: 200}, + State[bool]{Key: 3, Value: true}, + } + + idx := data.GetIndex(2) + if idx != 1 { + t.Errorf("Expected index 1, got %d", idx) + } + }, + }, + { + name: "GetIndex not found", + testFunc: func(t *testing.T) { + data := statesData{ + State[int]{Key: 1, Value: 100}, + } + + idx := data.GetIndex(99) + if idx != -1 { + t.Errorf("Expected index -1 for missing key, got %d", idx) + } + }, + }, + { + name: "sort", + testFunc: func(t *testing.T) { + data := statesData{ + State[int]{Key: 3, Value: 300}, + State[int]{Key: 1, Value: 100}, + State[int]{Key: 2, Value: 200}, + } + + data.sort() + + keys := []StateKey{1, 2, 3} + for i, expected := range keys { + if data[i].GetKey() != expected { + t.Errorf("Expected key %d at position %d, got %d", expected, i, data[i].GetKey()) + } + } + }, + }, + { + name: "hashStates same data", + testFunc: func(t *testing.T) { + data1 := statesData{ + State[int]{Key: 1, Value: 100}, + State[int]{Key: 2, Value: 200}, + } + + data2 := statesData{ + State[int]{Key: 1, Value: 100}, + State[int]{Key: 2, Value: 200}, + } + + hash1 := data1.hashStates() + hash2 := data2.hashStates() + + if hash1 != hash2 { + t.Error("Expected identical data to produce same hash") + } + }, + }, + { + name: "hashStates different data", + testFunc: func(t *testing.T) { + data1 := statesData{ + State[int]{Key: 1, Value: 100}, + State[int]{Key: 2, Value: 200}, + } + + data2 := statesData{ + State[int]{Key: 1, Value: 100}, + State[int]{Key: 2, Value: 999}, + } + + hash1 := data1.hashStates() + hash2 := data2.hashStates() + + if hash1 == hash2 { + t.Error("Expected different data to produce different hash") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.testFunc(t) + }) + } +} + +func TestCondition_Operators(t *testing.T) { + tests := []struct { + name string + stateVal int + condVal int + operator operator + wantMatch bool + }{ + {"equal match", 100, 100, EQUAL, true}, + {"equal no match", 100, 99, EQUAL, false}, + {"not equal match", 100, 99, NOT_EQUAL, true}, + {"not equal no match", 100, 100, NOT_EQUAL, false}, + {"lower true", 50, 100, LOWER, true}, + {"lower false equal", 50, 50, LOWER, false}, + {"lower false greater", 100, 50, LOWER, false}, + {"lower or equal true less", 50, 100, LOWER_OR_EQUAL, true}, + {"lower or equal true equal", 50, 50, LOWER_OR_EQUAL, true}, + {"lower or equal false", 100, 50, LOWER_OR_EQUAL, false}, + {"upper true", 100, 50, UPPER, true}, + {"upper false equal", 100, 100, UPPER, false}, + {"upper false less", 50, 100, UPPER, false}, + {"upper or equal true greater", 100, 50, UPPER_OR_EQUAL, true}, + {"upper or equal true equal", 100, 100, UPPER_OR_EQUAL, true}, + {"upper or equal false", 50, 100, UPPER_OR_EQUAL, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, tt.stateVal) + + condition := Condition[int]{Key: 1, Value: tt.condVal, Operator: tt.operator} + if got := condition.Check(agent.states); got != tt.wantMatch { + t.Errorf("Check() = %v, want %v for state=%d, cond=%d, op=%v", + got, tt.wantMatch, tt.stateVal, tt.condVal, tt.operator) + } + }) + } +} + +func TestCondition_KeyNotFound(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + condition := Condition[int]{Key: 99, Value: 100, Operator: EQUAL} + + if condition.Check(agent.states) { + t.Error("Expected condition to fail when key not found") + } +} + +func TestConditionBool(t *testing.T) { + tests := []struct { + name string + stateVal bool + condVal bool + operator operator + wantMatch bool + }{ + {"equal true match", true, true, EQUAL, true}, + {"equal false match", false, false, EQUAL, true}, + {"equal no match", true, false, EQUAL, false}, + {"not equal match", true, false, NOT_EQUAL, true}, + {"not equal no match", true, true, NOT_EQUAL, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[bool](&agent, 1, tt.stateVal) + + condition := ConditionBool{Key: 1, Value: tt.condVal, Operator: tt.operator} + if got := condition.Check(agent.states); got != tt.wantMatch { + t.Errorf("Check() = %v, want %v", got, tt.wantMatch) + } + }) + } + + t.Run("key not found", func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + condition := ConditionBool{Key: 99, Value: true, Operator: EQUAL} + + if condition.Check(agent.states) { + t.Error("Expected condition to fail when key not found") + } + }) +} + +func TestConditionString(t *testing.T) { + tests := []struct { + name string + stateVal string + condVal string + operator operator + wantMatch bool + }{ + {"equal match", "test", "test", EQUAL, true}, + {"equal no match", "test", "other", EQUAL, false}, + {"not equal match", "test", "other", NOT_EQUAL, true}, + {"not equal no match", "test", "test", NOT_EQUAL, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[string](&agent, 1, tt.stateVal) + + condition := ConditionString{Key: 1, Value: tt.condVal, Operator: tt.operator} + if got := condition.Check(agent.states); got != tt.wantMatch { + t.Errorf("Check() = %v, want %v", got, tt.wantMatch) + } + }) + } +} + +func TestConditionFn(t *testing.T) { + tests := []struct { + name string + sensorVal int + threshold int + wantResult bool + }{ + {"above threshold", 100, 50, true}, + {"below threshold", 30, 50, false}, + {"at threshold", 50, 50, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetSensor(&agent, "value", tt.sensorVal) + + condition := &ConditionFn{ + Key: 1, + CheckFn: func(sensors Sensors) bool { + return sensors.GetSensor("value").(int) > tt.threshold + }, + } + + if got := condition.Check(agent.states); got != tt.wantResult { + t.Errorf("Check() = %v, want %v", got, tt.wantResult) + } + + // Test caching + if !condition.resolved { + t.Error("Expected condition to be marked as resolved") + } + + // Call again to test cache + if got := condition.Check(agent.states); got != tt.wantResult { + t.Error("Expected cached result to match") + } + }) + } +} + +func TestConditions_Check(t *testing.T) { + tests := []struct { + name string + setup func(*Agent) + conditions Conditions + wantMatch bool + }{ + { + name: "all match", + setup: func(a *Agent) { + SetState[int](a, 1, 100) + SetState[bool](a, 2, true) + }, + conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: true, Operator: EQUAL}, + }, + wantMatch: true, + }, + { + name: "one fails", + setup: func(a *Agent) { + SetState[int](a, 1, 100) + SetState[bool](a, 2, true) + }, + conditions: Conditions{ + &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + &ConditionBool{Key: 2, Value: false, Operator: EQUAL}, + }, + wantMatch: false, + }, + { + name: "empty conditions", + setup: func(a *Agent) {}, + conditions: Conditions{}, + wantMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + tt.setup(&agent) + + if got := tt.conditions.Check(agent.states); got != tt.wantMatch { + t.Errorf("Check() = %v, want %v", got, tt.wantMatch) + } + }) + } +} + +func TestSensors_GetSensor(t *testing.T) { + sensors := Sensors{ + "test": 42, + } + + value := sensors.GetSensor("test") + if value.(int) != 42 { + t.Errorf("Expected 42, got %v", value) + } +} From 25daa0ca940ff5ba5125481f165688796f3294d0 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Wed, 29 Oct 2025 09:10:19 +0100 Subject: [PATCH 02/18] feat: incremental hash to improve performances --- action.go | 9 ++ astar.go | 22 ++++- astar_test.go | 4 - state.go | 106 ++++++++++++---------- state_test.go | 242 ++++++++++++++++++++++++++++++++++++++++++++++---- 5 files changed, 313 insertions(+), 70 deletions(-) diff --git a/action.go b/action.go index 7918495..780a0ae 100644 --- a/action.go +++ b/action.go @@ -45,6 +45,7 @@ func (action *Action) GetEffects() Effects { } type EffectInterface interface { + GetKey() StateKey check(states states) bool apply(data statesData) error } @@ -55,6 +56,10 @@ type Effect[T Numeric] struct { Value T } +func (effect Effect[T]) GetKey() StateKey { + return effect.Key +} + func (effect Effect[T]) check(states states) bool { // Other operators than '=' mean the effect will have an impact of the states if effect.Operator != SET { @@ -115,6 +120,10 @@ type EffectBool struct { Operator arithmetic } +func (effectBool EffectBool) GetKey() StateKey { + return effectBool.Key +} + func (effectBool EffectBool) check(states states) bool { // Other operators than '=' is not allowed if effectBool.Operator != SET { diff --git a/astar.go b/astar.go index 1d358d2..e886239 100644 --- a/astar.go +++ b/astar.go @@ -33,7 +33,6 @@ func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan }() data := slices.Clone(from.data) - data.sort() openNodes = append(openNodes, &node{ Action: &Action{}, states: states{ @@ -182,11 +181,28 @@ func simulateActionState(action *Action, nodeStates states) (states, bool) { return states{}, false } - data.sort() + // Calculate hash incrementally by tracking changes + newHash := nodeStates.hash + + // For each effect, we need to XOR out the old state and XOR in the new state + for _, effect := range action.effects { + // Find old state if it exists + oldIndex := nodeStates.data.GetIndex(effect.GetKey()) + if oldIndex >= 0 { + newHash ^= nodeStates.data[oldIndex].Hash() // Remove old + } + + // Find new state in modified data + newIndex := data.GetIndex(effect.GetKey()) + if newIndex >= 0 { + newHash ^= data[newIndex].Hash() // Add new + } + } + return states{ Agent: nodeStates.Agent, data: data, - hash: data.hashStates(), + hash: newHash, }, true } diff --git a/astar_test.go b/astar_test.go index 642c438..87d1476 100644 --- a/astar_test.go +++ b/astar_test.go @@ -79,12 +79,10 @@ func TestGetLessCostlyNodeKey_Empty(t *testing.T) { func TestFetchNode_Found(t *testing.T) { agent1 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent1, 1, 100) - agent1.states.data.sort() agent1.states.hash = agent1.states.data.hashStates() agent2 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent2, 1, 100) - agent2.states.data.sort() agent2.states.hash = agent2.states.data.hashStates() nodes := []*node{ @@ -104,12 +102,10 @@ func TestFetchNode_Found(t *testing.T) { func TestFetchNode_NotFound(t *testing.T) { agent1 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent1, 1, 100) - agent1.states.data.sort() agent1.states.hash = agent1.states.data.hashStates() agent2 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent2, 1, 200) - agent2.states.data.sort() agent2.states.hash = agent2.states.data.hashStates() nodes := []*node{ diff --git a/state.go b/state.go index 59bc79f..325712e 100644 --- a/state.go +++ b/state.go @@ -3,8 +3,7 @@ package goapai import ( "encoding/binary" "hash/fnv" - "slices" - "strconv" + "math" ) type operator uint8 @@ -20,14 +19,15 @@ const ( type Numeric interface { ~int8 | ~int | - ~uint8 | ~uint64 | - ~float64 + ~uint8 | ~uint64 | + ~float64 } type StateInterface interface { Check(states states, key StateKey) bool GetKey() StateKey GetValue() any + Hash() uint64 } type State[T Numeric | bool | string] struct { Key StateKey @@ -67,6 +67,43 @@ func (state State[T]) GetValue() any { return state.Value } +// Hash returns a unique hash for this state using FNV-64a +func (state State[T]) Hash() uint64 { + h := fnv.New64a() + + // Write the key + buf := make([]byte, 2) + binary.LittleEndian.PutUint16(buf, uint16(state.Key)) + h.Write(buf) + + // Write the value based on type + buf = make([]byte, 8) + switch v := any(state.Value).(type) { + case int8: + binary.LittleEndian.PutUint64(buf, uint64(v)) + case int: + binary.LittleEndian.PutUint64(buf, uint64(v)) + case uint8: + binary.LittleEndian.PutUint64(buf, uint64(v)) + case uint64: + binary.LittleEndian.PutUint64(buf, v) + case float64: + binary.LittleEndian.PutUint64(buf, math.Float64bits(v)) + case bool: + if v { + binary.LittleEndian.PutUint64(buf, 1) + } else { + binary.LittleEndian.PutUint64(buf, 0) + } + case string: + h.Write([]byte(v)) + return h.Sum64() + } + h.Write(buf) + + return h.Sum64() +} + // Check compares states and states2 by their hash. func (states states) Check(states2 states) bool { return states.hash == states2.hash @@ -82,53 +119,26 @@ func (statesData statesData) GetIndex(stateKey StateKey) int { return -1 } -func (statesData statesData) sort() { - slices.SortFunc(statesData, func(a, b StateInterface) int { - if a.GetKey() > b.GetKey() { - return 1 - } else if a.GetKey() < b.GetKey() { - return -1 - } - - return 0 - }) -} - +// hashStates computes the initial hash using XOR of individual state hashes +// This is O(n) but only called once when creating initial state func (statesData statesData) hashStates() uint64 { - hash := fnv.New64() - - buf := make([]byte, binary.MaxVarintLen64) - for _, data := range statesData { - n := binary.PutVarint(buf, int64(data.GetKey())) - hash.Write(buf[:n]) - hash.Write([]byte(":")) - - switch v := data.GetValue().(type) { - case int8: - n = binary.PutVarint(buf, int64(v)) - hash.Write(buf[:n]) - case int: - n = binary.PutVarint(buf, int64(v)) - hash.Write(buf[:n]) - case uint8: - n = binary.PutUvarint(buf, uint64(v)) - hash.Write(buf[:n]) - case uint64: - n = binary.PutUvarint(buf, v) - hash.Write(buf[:n]) - case float64: - hash.Write([]byte(strconv.FormatFloat(v, 'f', -1, 64))) - case string: - hash.Write([]byte(v)) - case []byte: - hash.Write(v) - default: - binary.Write(hash, binary.LittleEndian, data.GetValue()) - } - hash.Write([]byte(";")) + var hash uint64 = 0 + for _, state := range statesData { + hash ^= state.Hash() // XOR for incremental updates } + return hash +} - return hash.Sum64() +// updateHashIncremental updates a hash by removing old state and adding new state +// This is O(1) - the key optimization +func updateHashIncremental(currentHash uint64, oldState, newState StateInterface) uint64 { + if oldState != nil { + currentHash ^= oldState.Hash() // Remove old + } + if newState != nil { + currentHash ^= newState.Hash() // Add new + } + return currentHash } type Sensor any diff --git a/state_test.go b/state_test.go index 9a19c5d..ef3fa4d 100644 --- a/state_test.go +++ b/state_test.go @@ -4,8 +4,8 @@ import "testing" func TestState_Operations(t *testing.T) { tests := []struct { - name string - testFunc func(*testing.T) + name string + testFunc func(*testing.T) }{ { name: "GetKey", @@ -103,12 +103,10 @@ func TestStates_Check(t *testing.T) { t.Run(tt.name, func(t *testing.T) { agent1 := CreateAgent(Goals{}, Actions{}) tt.setup1(&agent1) - agent1.states.data.sort() agent1.states.hash = agent1.states.data.hashStates() agent2 := CreateAgent(Goals{}, Actions{}) tt.setup2(&agent2) - agent2.states.data.sort() agent2.states.hash = agent2.states.data.hashStates() if got := agent1.states.Check(agent2.states); got != tt.wantMatch { @@ -152,33 +150,36 @@ func TestStatesData_Operations(t *testing.T) { }, }, { - name: "sort", + name: "hashStates same data", testFunc: func(t *testing.T) { - data := statesData{ - State[int]{Key: 3, Value: 300}, + data1 := statesData{ State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, } - data.sort() + data2 := statesData{ + State[int]{Key: 1, Value: 100}, + State[int]{Key: 2, Value: 200}, + } + + hash1 := data1.hashStates() + hash2 := data2.hashStates() - keys := []StateKey{1, 2, 3} - for i, expected := range keys { - if data[i].GetKey() != expected { - t.Errorf("Expected key %d at position %d, got %d", expected, i, data[i].GetKey()) - } + if hash1 != hash2 { + t.Error("Expected identical data to produce same hash") } }, }, { - name: "hashStates same data", + name: "hashStates same data, different keys", testFunc: func(t *testing.T) { data1 := statesData{ State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, + State[int]{Key: 3, Value: 300}, } - data2 := statesData{ + State[int]{Key: 3, Value: 300}, State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, } @@ -436,3 +437,214 @@ func TestSensors_GetSensor(t *testing.T) { t.Errorf("Expected 42, got %v", value) } } + +func TestState_Hash(t *testing.T) { + tests := []struct { + name string + state1 StateInterface + state2 StateInterface + wantSame bool + }{ + { + name: "same int states", + state1: State[int]{Key: 1, Value: 100}, + state2: State[int]{Key: 1, Value: 100}, + wantSame: true, + }, + { + name: "different int values", + state1: State[int]{Key: 1, Value: 100}, + state2: State[int]{Key: 1, Value: 200}, + wantSame: false, + }, + { + name: "different keys", + state1: State[int]{Key: 1, Value: 100}, + state2: State[int]{Key: 2, Value: 100}, + wantSame: false, + }, + { + name: "same bool states", + state1: State[bool]{Key: 1, Value: true}, + state2: State[bool]{Key: 1, Value: true}, + wantSame: true, + }, + { + name: "different bool values", + state1: State[bool]{Key: 1, Value: true}, + state2: State[bool]{Key: 1, Value: false}, + wantSame: false, + }, + { + name: "same string states", + state1: State[string]{Key: 1, Value: "test"}, + state2: State[string]{Key: 1, Value: "test"}, + wantSame: true, + }, + { + name: "different string values", + state1: State[string]{Key: 1, Value: "test"}, + state2: State[string]{Key: 1, Value: "other"}, + wantSame: false, + }, + { + name: "same float64 states", + state1: State[float64]{Key: 1, Value: 3.14}, + state2: State[float64]{Key: 1, Value: 3.14}, + wantSame: true, + }, + { + name: "different float64 values", + state1: State[float64]{Key: 1, Value: 3.14}, + state2: State[float64]{Key: 1, Value: 2.71}, + wantSame: false, + }, + { + name: "same uint64 states", + state1: State[uint64]{Key: 1, Value: 12345}, + state2: State[uint64]{Key: 1, Value: 12345}, + wantSame: true, + }, + { + name: "different uint64 values", + state1: State[uint64]{Key: 1, Value: 12345}, + state2: State[uint64]{Key: 1, Value: 54321}, + wantSame: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash1 := tt.state1.Hash() + hash2 := tt.state2.Hash() + + if tt.wantSame && hash1 != hash2 { + t.Errorf("Expected same hash, got %d and %d", hash1, hash2) + } + if !tt.wantSame && hash1 == hash2 { + t.Errorf("Expected different hashes, both got %d", hash1) + } + }) + } +} + +func TestUpdateHashIncremental(t *testing.T) { + tests := []struct { + name string + currentHash uint64 + oldState StateInterface + newState StateInterface + verify func(*testing.T, uint64) + }{ + { + name: "add new state", + currentHash: 0, + oldState: nil, + newState: State[int]{Key: 1, Value: 100}, + verify: func(t *testing.T, result uint64) { + expected := State[int]{Key: 1, Value: 100}.Hash() + if result != expected { + t.Errorf("Expected hash %d, got %d", expected, result) + } + }, + }, + { + name: "remove state", + currentHash: State[int]{Key: 1, Value: 100}.Hash(), + oldState: State[int]{Key: 1, Value: 100}, + newState: nil, + verify: func(t *testing.T, result uint64) { + if result != 0 { + t.Errorf("Expected hash 0, got %d", result) + } + }, + }, + { + name: "replace state", + currentHash: State[int]{Key: 1, Value: 100}.Hash(), + oldState: State[int]{Key: 1, Value: 100}, + newState: State[int]{Key: 1, Value: 200}, + verify: func(t *testing.T, result uint64) { + expected := State[int]{Key: 1, Value: 200}.Hash() + if result != expected { + t.Errorf("Expected hash %d, got %d", expected, result) + } + }, + }, + { + name: "XOR property - multiple states", + currentHash: func() uint64 { + var h uint64 + h ^= State[int]{Key: 1, Value: 100}.Hash() + h ^= State[int]{Key: 2, Value: 200}.Hash() + return h + }(), + oldState: State[int]{Key: 1, Value: 100}, + newState: State[int]{Key: 1, Value: 150}, + verify: func(t *testing.T, result uint64) { + expected := State[int]{Key: 1, Value: 150}.Hash() ^ State[int]{Key: 2, Value: 200}.Hash() + if result != expected { + t.Errorf("Expected hash %d, got %d", expected, result) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := updateHashIncremental(tt.currentHash, tt.oldState, tt.newState) + tt.verify(t, result) + }) + } +} + +func TestStatesData_HashStates_XORProperty(t *testing.T) { + // Test that hash order doesn't matter (XOR is commutative) + data1 := statesData{ + State[int]{Key: 1, Value: 100}, + State[int]{Key: 2, Value: 200}, + State[bool]{Key: 3, Value: true}, + } + + data2 := statesData{ + State[bool]{Key: 3, Value: true}, + State[int]{Key: 1, Value: 100}, + State[int]{Key: 2, Value: 200}, + } + + hash1 := data1.hashStates() + hash2 := data2.hashStates() + + if hash1 != hash2 { + t.Error("Expected XOR hash to be order-independent") + } +} + +func TestIncrementalHashConsistency(t *testing.T) { + // Verify that incremental hash matches full recalculation + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, 1, 100) + SetState[int](&agent, 2, 200) + SetState[bool](&agent, 3, true) + agent.states.hash = agent.states.data.hashStates() + + initialHash := agent.states.hash + + // Modify a state and calculate hash incrementally + oldState := agent.states.data[0] + newState := State[int]{Key: 1, Value: 150} + + incrementalHash := updateHashIncremental(agent.states.hash, oldState, newState) + + // Modify the same state and calculate hash from scratch + agent.states.data[0] = newState + fullHash := agent.states.data.hashStates() + + if incrementalHash != fullHash { + t.Errorf("Incremental hash (%d) doesn't match full recalculation (%d)", incrementalHash, fullHash) + } + + if incrementalHash == initialHash { + t.Error("Hash should change after state modification") + } +} From 2ea0b01c3deeec37cb7ec6e11fd42e91cf2610a9 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Wed, 29 Oct 2025 09:39:08 +0100 Subject: [PATCH 03/18] feat: add effectString.GetKey --- action.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/action.go b/action.go index 780a0ae..8d1dd5a 100644 --- a/action.go +++ b/action.go @@ -170,6 +170,10 @@ type EffectString struct { Operator arithmetic } +func (effectString EffectString) GetKey() StateKey { + return effectString.Key +} + func (effectString EffectString) check(states states) bool { k := states.data.GetIndex(effectString.Key) if k < 0 { From 1c7bec5325c4edeb5a6105170962e852eceb57d8 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Wed, 29 Oct 2025 22:38:11 +0100 Subject: [PATCH 04/18] feat: rename vars --- action.go | 36 ++++++++++++++++---------------- action_test.go | 6 +++--- agent.go | 4 ++-- agent_test.go | 6 +++--- astar.go | 56 +++++++++++++++++++++++++------------------------- astar_test.go | 8 ++++---- state.go | 48 +++++++++++++++++++++---------------------- state_test.go | 16 +++++++-------- 8 files changed, 90 insertions(+), 90 deletions(-) diff --git a/action.go b/action.go index 8d1dd5a..65e80c2 100644 --- a/action.go +++ b/action.go @@ -46,7 +46,7 @@ func (action *Action) GetEffects() Effects { type EffectInterface interface { GetKey() StateKey - check(states states) bool + check(w world) bool apply(data statesData) error } @@ -60,17 +60,17 @@ func (effect Effect[T]) GetKey() StateKey { return effect.Key } -func (effect Effect[T]) check(states states) bool { - // Other operators than '=' mean the effect will have an impact of the states +func (effect Effect[T]) check(w world) bool { + // Other operators than '=' mean the effect will have an impact of the world if effect.Operator != SET { return false } - k := states.data.GetIndex(effect.Key) + k := w.data.GetIndex(effect.Key) if k < 0 { return false } - s := states.data[k] + s := w.data[k] if _, ok := s.(State[T]); !ok { return false @@ -124,21 +124,21 @@ func (effectBool EffectBool) GetKey() StateKey { return effectBool.Key } -func (effectBool EffectBool) check(states states) bool { +func (effectBool EffectBool) check(w world) bool { // Other operators than '=' is not allowed if effectBool.Operator != SET { return false } - k := states.data.GetIndex(effectBool.Key) + k := w.data.GetIndex(effectBool.Key) if k < 0 { return false } - if _, ok := states.data[k].(State[bool]); !ok { + if _, ok := w.data[k].(State[bool]); !ok { return false } - s := states.data[k].(State[bool]) + s := w.data[k].(State[bool]) return s.Value == effectBool.Value } @@ -174,16 +174,16 @@ func (effectString EffectString) GetKey() StateKey { return effectString.Key } -func (effectString EffectString) check(states states) bool { - k := states.data.GetIndex(effectString.Key) +func (effectString EffectString) check(w world) bool { + k := w.data.GetIndex(effectString.Key) if k < 0 { return false } - if _, ok := states.data[k].(State[string]); !ok { + if _, ok := w.data[k].(State[string]); !ok { return false } - s := states.data[k].(State[string]) + s := w.data[k].(State[string]) return s.Value == effectString.Value } @@ -218,11 +218,11 @@ type EffectFn func(agent *Agent) type Effects []EffectInterface -// If all the effects already exist in states, +// If all the effects already exist in world, // it is probably not a good path -func (effects Effects) satisfyStates(states states) bool { +func (effects Effects) satisfyStates(w world) bool { for _, effect := range effects { - if !effect.check(states) { + if !effect.check(w) { return false } } @@ -230,8 +230,8 @@ func (effects Effects) satisfyStates(states states) bool { return true } -func (effects Effects) apply(states states) (statesData, error) { - data := slices.Clone(states.data) +func (effects Effects) apply(w world) (statesData, error) { + data := slices.Clone(w.data) for _, effect := range effects { err := effect.apply(data) diff --git a/action_test.go b/action_test.go index 51c7c22..107829f 100644 --- a/action_test.go +++ b/action_test.go @@ -350,7 +350,7 @@ func TestEffects_SatisfyStates_AllMatch(t *testing.T) { } if !effects.satisfyStates(agent.states) { - t.Error("Expected effects to satisfy states") + t.Error("Expected effects to satisfy world") } } @@ -365,7 +365,7 @@ func TestEffects_SatisfyStates_OneFails(t *testing.T) { } if effects.satisfyStates(agent.states) { - t.Error("Expected effects not to satisfy states when one doesn't match") + t.Error("Expected effects not to satisfy world when one doesn't match") } } @@ -385,7 +385,7 @@ func TestEffects_Apply(t *testing.T) { } if len(newData) != 2 { - t.Errorf("Expected 2 states, got %d", len(newData)) + t.Errorf("Expected 2 world, got %d", len(newData)) } // Check int value diff --git a/agent.go b/agent.go index fba27a1..c22cb78 100644 --- a/agent.go +++ b/agent.go @@ -2,7 +2,7 @@ package goapai type Agent struct { actions Actions - states states + states world sensors Sensors goals Goals } @@ -21,7 +21,7 @@ func CreateAgent(goals Goals, actions Actions) Agent { sensors: Sensors{}, } - states := states{ + states := world{ Agent: &agent, data: statesData{}, } diff --git a/agent_test.go b/agent_test.go index e8905ee..6021331 100644 --- a/agent_test.go +++ b/agent_test.go @@ -32,7 +32,7 @@ func TestCreateAgent(t *testing.T) { } if agent.states.Agent == nil { - t.Error("Expected states.Agent to be non-nil") + t.Error("Expected world.Agent to be non-nil") } } @@ -88,7 +88,7 @@ func TestSetState(t *testing.T) { }, }, { - name: "multiple states", + name: "multiple world", setupFunc: func(a *Agent) { SetState[int](a, 1, 100) SetState[bool](a, 2, false) @@ -96,7 +96,7 @@ func TestSetState(t *testing.T) { }, checkFunc: func(t *testing.T, a Agent) { if len(a.states.data) != 3 { - t.Errorf("Expected 3 states, got %d", len(a.states.data)) + t.Errorf("Expected 3 world, got %d", len(a.states.data)) } }, }, diff --git a/astar.go b/astar.go index e886239..7aa4a47 100644 --- a/astar.go +++ b/astar.go @@ -7,7 +7,7 @@ import ( type node struct { *Action - states states + world world parentNode *node cost float32 @@ -22,7 +22,7 @@ var nodesPool = sync.Pool{ }, } -func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan { +func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { availableActions := getImpactingActions(from, actions) openNodes := nodesPool.Get().([]*node) closedNodes := nodesPool.Get().([]*node) @@ -35,7 +35,7 @@ func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan data := slices.Clone(from.data) openNodes = append(openNodes, &node{ Action: &Action{}, - states: states{ + world: world{ Agent: from.Agent, data: data, hash: data.hashStates(), @@ -56,7 +56,7 @@ func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan } // Simulate world state, and check if we are at current state - if countMissingGoal(goal, parentNode.states) == 0 { + if countMissingGoal(goal, parentNode.world) == 0 { return buildPlanFromNode(parentNode) } @@ -65,11 +65,11 @@ func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan continue } - if !action.conditions.Check(parentNode.states) { + if !action.conditions.Check(parentNode.world) { continue } - simulatedStates, ok := simulateActionState(action, parentNode.states) + simulatedStates, ok := simulateActionState(action, parentNode.world) if !ok { continue } @@ -78,7 +78,7 @@ func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan node := openNodes[nodeKey] if (parentNode.cost + action.cost) < node.cost { node.Action = action - node.states = simulatedStates + node.world = simulatedStates node.parentNode = parentNode node.cost = parentNode.cost + action.cost node.totalCost = parentNode.cost + action.cost + node.heuristic @@ -90,7 +90,7 @@ func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan node := closedNodes[nodeKey] if (parentNode.cost + action.cost) < node.cost { node.Action = action - node.states = simulatedStates + node.world = simulatedStates node.parentNode = parentNode node.cost = parentNode.cost + action.cost node.totalCost = parentNode.cost + action.cost + node.heuristic @@ -103,7 +103,7 @@ func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan heuristic := computeHeuristic(from, goal, simulatedStates) openNodes = append(openNodes, &node{ Action: action, - states: simulatedStates, + world: simulatedStates, parentNode: parentNode, cost: parentNode.cost + action.cost, totalCost: parentNode.cost + action.cost + heuristic, @@ -120,9 +120,9 @@ func astar(from states, goal goalInterface, actions Actions, maxDepth int) Plan return Plan{} } -// All the actions similar to initial states are useless: +// All the actions similar to initial world are useless: // we consider they are not going towards the goal and are dead end -func getImpactingActions(from states, actions Actions) Actions { +func getImpactingActions(from world, actions Actions) Actions { var availableActions Actions for _, action := range actions { @@ -146,9 +146,9 @@ func getLessCostlyNodeKey(openNodes []*node) int { return lowestKey } -func fetchNode(nodes []*node, states states) (int, bool) { +func fetchNode(nodes []*node, w world) (int, bool) { for k, n := range nodes { - if n.states.Check(states) { + if n.world.Check(w) { return k, true } } @@ -169,27 +169,27 @@ func buildPlanFromNode(node *node) Plan { return plan } -func simulateActionState(action *Action, nodeStates states) (states, bool) { +func simulateActionState(action *Action, w world) (world, bool) { /* If action effects implies no changes to current worldState, then avoid generating huge chunks of memory */ - if action.effects.satisfyStates(nodeStates) { - return states{}, false + if action.effects.satisfyStates(w) { + return world{}, false } - data, err := action.effects.apply(nodeStates) + data, err := action.effects.apply(w) if err != nil { - return states{}, false + return world{}, false } // Calculate hash incrementally by tracking changes - newHash := nodeStates.hash + newHash := w.hash // For each effect, we need to XOR out the old state and XOR in the new state for _, effect := range action.effects { // Find old state if it exists - oldIndex := nodeStates.data.GetIndex(effect.GetKey()) + oldIndex := w.data.GetIndex(effect.GetKey()) if oldIndex >= 0 { - newHash ^= nodeStates.data[oldIndex].Hash() // Remove old + newHash ^= w.data[oldIndex].Hash() // Remove old } // Find new state in modified data @@ -199,8 +199,8 @@ func simulateActionState(action *Action, nodeStates states) (states, bool) { } } - return states{ - Agent: nodeStates.Agent, + return world{ + Agent: w.Agent, data: data, hash: newHash, }, true @@ -223,10 +223,10 @@ func allowedRepetition(action *Action, parentNode *node) bool { return true } -func countMissingGoal(goal goalInterface, states states) int { +func countMissingGoal(goal goalInterface, w world) int { count := 0 for _, condition := range goal.Conditions { - if !condition.Check(states) { + if !condition.Check(w) { count++ } } @@ -236,12 +236,12 @@ func countMissingGoal(goal goalInterface, states states) int { /* A very simple (empiristic) model for h using: - - how much required states are met + - how much required world are met We try to be conservative and reduce the number of steps */ -func computeHeuristic(fromStates states, goal goalInterface, states states) float32 { - missingGoalsCount := float32(countMissingGoal(goal, states)) +func computeHeuristic(from world, goal goalInterface, w world) float32 { + missingGoalsCount := float32(countMissingGoal(goal, w)) h := missingGoalsCount diff --git a/astar_test.go b/astar_test.go index 87d1476..3e1d5e6 100644 --- a/astar_test.go +++ b/astar_test.go @@ -86,8 +86,8 @@ func TestFetchNode_Found(t *testing.T) { agent2.states.hash = agent2.states.data.hashStates() nodes := []*node{ - {states: agent1.states}, - {states: agent2.states}, + {world: agent1.states}, + {world: agent2.states}, } key, found := fetchNode(nodes, agent1.states) @@ -109,7 +109,7 @@ func TestFetchNode_NotFound(t *testing.T) { agent2.states.hash = agent2.states.data.hashStates() nodes := []*node{ - {states: agent1.states}, + {world: agent1.states}, } _, found := fetchNode(nodes, agent2.states) @@ -175,7 +175,7 @@ func TestSimulateActionState(t *testing.T) { idx := newStates.data.GetIndex(1) if idx < 0 { - t.Error("Expected to find key 1 in new states") + t.Error("Expected to find key 1 in new world") } if newStates.data[idx].(State[int]).Value != 150 { diff --git a/state.go b/state.go index 325712e..40b9b3c 100644 --- a/state.go +++ b/state.go @@ -19,12 +19,12 @@ const ( type Numeric interface { ~int8 | ~int | - ~uint8 | ~uint64 | - ~float64 + ~uint8 | ~uint64 | + ~float64 } type StateInterface interface { - Check(states states, key StateKey) bool + Check(w world, key StateKey) bool GetKey() StateKey GetValue() any Hash() uint64 @@ -38,7 +38,7 @@ type StateKey uint16 type statesData []StateInterface -type states struct { +type world struct { Agent *Agent data statesData hash uint64 @@ -48,12 +48,12 @@ func (state State[T]) GetKey() StateKey { return state.Key } -func (state State[T]) Check(states states, key StateKey) bool { - k := states.data.GetIndex(key) +func (state State[T]) Check(w world, key StateKey) bool { + k := w.data.GetIndex(key) if k < 0 { return false } - s := states.data[k] + s := w.data[k] if agentState, ok := s.(State[T]); ok { if agentState.Value == state.Value { return true @@ -104,9 +104,9 @@ func (state State[T]) Hash() uint64 { return h.Sum64() } -// Check compares states and states2 by their hash. -func (states states) Check(states2 states) bool { - return states.hash == states2.hash +// Check compares world and states2 by their hash. +func (world world) Check(world2 world) bool { + return world.hash == world2.hash } func (statesData statesData) GetIndex(stateKey StateKey) int { @@ -150,7 +150,7 @@ func (sensors Sensors) GetSensor(name string) Sensor { type ConditionInterface interface { GetKey() StateKey - Check(states states) bool + Check(w world) bool } type ConditionFn struct { @@ -164,9 +164,9 @@ func (conditionFn *ConditionFn) GetKey() StateKey { return conditionFn.Key } -func (conditionFn *ConditionFn) Check(states states) bool { +func (conditionFn *ConditionFn) Check(w world) bool { if !conditionFn.resolved { - conditionFn.valid = conditionFn.CheckFn(states.Agent.sensors) + conditionFn.valid = conditionFn.CheckFn(w.Agent.sensors) conditionFn.resolved = true } @@ -183,12 +183,12 @@ func (condition *Condition[T]) GetKey() StateKey { return condition.Key } -func (condition *Condition[T]) Check(states states) bool { - k := states.data.GetIndex(condition.Key) +func (condition *Condition[T]) Check(w world) bool { + k := w.data.GetIndex(condition.Key) if k < 0 { return false } - s := states.data[k] + s := w.data[k] if state, ok := s.(State[T]); ok { switch condition.Operator { case EQUAL: @@ -231,12 +231,12 @@ func (conditionBool *ConditionBool) GetKey() StateKey { return conditionBool.Key } -func (conditionBool *ConditionBool) Check(states states) bool { - k := states.data.GetIndex(conditionBool.Key) +func (conditionBool *ConditionBool) Check(w world) bool { + k := w.data.GetIndex(conditionBool.Key) if k < 0 { return false } - s := states.data[k] + s := w.data[k] if state, ok := s.(State[bool]); ok { switch conditionBool.Operator { case EQUAL: @@ -265,12 +265,12 @@ func (conditionString *ConditionString) GetKey() StateKey { return conditionString.Key } -func (conditionString *ConditionString) Check(states states) bool { - k := states.data.GetIndex(conditionString.Key) +func (conditionString *ConditionString) Check(w world) bool { + k := w.data.GetIndex(conditionString.Key) if k < 0 { return false } - s := states.data[k] + s := w.data[k] if state, ok := s.(State[string]); ok { switch conditionString.Operator { case EQUAL: @@ -291,9 +291,9 @@ func (conditionString *ConditionString) Check(states states) bool { type Conditions []ConditionInterface -func (conditions Conditions) Check(states states) bool { +func (conditions Conditions) Check(w world) bool { for _, condition := range conditions { - if !condition.Check(states) { + if !condition.Check(w) { return false } } diff --git a/state_test.go b/state_test.go index ef3fa4d..4384fa6 100644 --- a/state_test.go +++ b/state_test.go @@ -77,7 +77,7 @@ func TestStates_Check(t *testing.T) { wantMatch bool }{ { - name: "matching states", + name: "matching world", setup1: func(a *Agent) { SetState[int](a, 1, 100) }, @@ -87,7 +87,7 @@ func TestStates_Check(t *testing.T) { wantMatch: true, }, { - name: "different states", + name: "different world", setup1: func(a *Agent) { SetState[int](a, 1, 100) }, @@ -446,7 +446,7 @@ func TestState_Hash(t *testing.T) { wantSame bool }{ { - name: "same int states", + name: "same int world", state1: State[int]{Key: 1, Value: 100}, state2: State[int]{Key: 1, Value: 100}, wantSame: true, @@ -464,7 +464,7 @@ func TestState_Hash(t *testing.T) { wantSame: false, }, { - name: "same bool states", + name: "same bool world", state1: State[bool]{Key: 1, Value: true}, state2: State[bool]{Key: 1, Value: true}, wantSame: true, @@ -476,7 +476,7 @@ func TestState_Hash(t *testing.T) { wantSame: false, }, { - name: "same string states", + name: "same string world", state1: State[string]{Key: 1, Value: "test"}, state2: State[string]{Key: 1, Value: "test"}, wantSame: true, @@ -488,7 +488,7 @@ func TestState_Hash(t *testing.T) { wantSame: false, }, { - name: "same float64 states", + name: "same float64 world", state1: State[float64]{Key: 1, Value: 3.14}, state2: State[float64]{Key: 1, Value: 3.14}, wantSame: true, @@ -500,7 +500,7 @@ func TestState_Hash(t *testing.T) { wantSame: false, }, { - name: "same uint64 states", + name: "same uint64 world", state1: State[uint64]{Key: 1, Value: 12345}, state2: State[uint64]{Key: 1, Value: 12345}, wantSame: true, @@ -572,7 +572,7 @@ func TestUpdateHashIncremental(t *testing.T) { }, }, { - name: "XOR property - multiple states", + name: "XOR property - multiple world", currentHash: func() uint64 { var h uint64 h ^= State[int]{Key: 1, Value: 100}.Hash() From 6bd45f492febca9df59bc2c28e3a84bbf35981cd Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Wed, 29 Oct 2025 23:39:45 +0100 Subject: [PATCH 05/18] feat: wip --- action.go | 30 +++++++++++++++--------------- action_test.go | 26 +++++++++++++------------- agent.go | 6 +++--- agent_test.go | 22 +++++++++++----------- astar.go | 20 ++++++++++---------- astar_test.go | 20 ++++++++++---------- state.go | 32 ++++++++++++++++---------------- state_test.go | 44 ++++++++++++++++++++++---------------------- 8 files changed, 100 insertions(+), 100 deletions(-) diff --git a/action.go b/action.go index 65e80c2..5bda5e4 100644 --- a/action.go +++ b/action.go @@ -47,7 +47,7 @@ func (action *Action) GetEffects() Effects { type EffectInterface interface { GetKey() StateKey check(w world) bool - apply(data statesData) error + apply(data states) error } type Effect[T Numeric] struct { @@ -66,11 +66,11 @@ func (effect Effect[T]) check(w world) bool { return false } - k := w.data.GetIndex(effect.Key) + k := w.states.GetIndex(effect.Key) if k < 0 { return false } - s := w.data[k] + s := w.states[k] if _, ok := s.(State[T]); !ok { return false @@ -79,7 +79,7 @@ func (effect Effect[T]) check(w world) bool { return s.(State[T]).Value == effect.Value } -func (effect Effect[T]) apply(data statesData) error { +func (effect Effect[T]) apply(data states) error { k := data.GetIndex(effect.Key) if k < 0 { if slices.Contains([]arithmetic{SET, ADD}, effect.Operator) { @@ -89,7 +89,7 @@ func (effect Effect[T]) apply(data statesData) error { data = append(data, State[T]{Value: -effect.Value}) return nil } - return fmt.Errorf("data does not exist") + return fmt.Errorf("states does not exist") } if _, ok := data[k].(State[T]); !ok { return fmt.Errorf("type does not match") @@ -130,20 +130,20 @@ func (effectBool EffectBool) check(w world) bool { return false } - k := w.data.GetIndex(effectBool.Key) + k := w.states.GetIndex(effectBool.Key) if k < 0 { return false } - if _, ok := w.data[k].(State[bool]); !ok { + if _, ok := w.states[k].(State[bool]); !ok { return false } - s := w.data[k].(State[bool]) + s := w.states[k].(State[bool]) return s.Value == effectBool.Value } -func (effectBool EffectBool) apply(data statesData) error { +func (effectBool EffectBool) apply(data states) error { if effectBool.Operator != SET { return fmt.Errorf("operation %v not allowed on bool type", effectBool.Operator) } @@ -175,20 +175,20 @@ func (effectString EffectString) GetKey() StateKey { } func (effectString EffectString) check(w world) bool { - k := w.data.GetIndex(effectString.Key) + k := w.states.GetIndex(effectString.Key) if k < 0 { return false } - if _, ok := w.data[k].(State[string]); !ok { + if _, ok := w.states[k].(State[string]); !ok { return false } - s := w.data[k].(State[string]) + s := w.states[k].(State[string]) return s.Value == effectString.Value } -func (effectString EffectString) apply(data statesData) error { +func (effectString EffectString) apply(data states) error { if !slices.Contains([]arithmetic{SET, ADD}, effectString.Operator) { return fmt.Errorf("arithmetic operation %v not allowed on string type", effectString.Operator) } @@ -230,8 +230,8 @@ func (effects Effects) satisfyStates(w world) bool { return true } -func (effects Effects) apply(w world) (statesData, error) { - data := slices.Clone(w.data) +func (effects Effects) apply(w world) (states, error) { + data := slices.Clone(w.states) for _, effect := range effects { err := effect.apply(data) diff --git a/action_test.go b/action_test.go index 107829f..8b2a1cf 100644 --- a/action_test.go +++ b/action_test.go @@ -93,7 +93,7 @@ func TestEffect_Check_NonSetOperator(t *testing.T) { } func TestEffect_Apply_Set(t *testing.T) { - data := statesData{ + data := states{ State[int]{Key: 1, Value: 100}, } @@ -110,7 +110,7 @@ func TestEffect_Apply_Set(t *testing.T) { } func TestEffect_Apply_Add(t *testing.T) { - data := statesData{ + data := states{ State[int]{Key: 1, Value: 100}, } @@ -127,7 +127,7 @@ func TestEffect_Apply_Add(t *testing.T) { } func TestEffect_Apply_Subtract(t *testing.T) { - data := statesData{ + data := states{ State[int]{Key: 1, Value: 100}, } @@ -144,7 +144,7 @@ func TestEffect_Apply_Subtract(t *testing.T) { } func TestEffect_Apply_Multiply(t *testing.T) { - data := statesData{ + data := states{ State[int]{Key: 1, Value: 10}, } @@ -161,7 +161,7 @@ func TestEffect_Apply_Multiply(t *testing.T) { } func TestEffect_Apply_Divide(t *testing.T) { - data := statesData{ + data := states{ State[int]{Key: 1, Value: 100}, } @@ -178,7 +178,7 @@ func TestEffect_Apply_Divide(t *testing.T) { } func TestEffect_Apply_NewKey(t *testing.T) { - data := statesData{} + data := states{} effect := Effect[int]{Key: 1, Value: 42, Operator: SET} err := effect.apply(data) @@ -213,7 +213,7 @@ func TestEffectBool_Check_NoMatch(t *testing.T) { } func TestEffectBool_Apply_Set(t *testing.T) { - data := statesData{ + data := states{ State[bool]{Key: 1, Value: false}, } @@ -230,7 +230,7 @@ func TestEffectBool_Apply_Set(t *testing.T) { } func TestEffectBool_Apply_InvalidOperator(t *testing.T) { - data := statesData{ + data := states{ State[bool]{Key: 1, Value: true}, } @@ -243,7 +243,7 @@ func TestEffectBool_Apply_InvalidOperator(t *testing.T) { } func TestEffectBool_Apply_NewKey(t *testing.T) { - data := statesData{} + data := states{} effect := EffectBool{Key: 1, Value: true, Operator: SET} err := effect.apply(data) @@ -278,7 +278,7 @@ func TestEffectString_Check_NoMatch(t *testing.T) { } func TestEffectString_Apply_Set(t *testing.T) { - data := statesData{ + data := states{ State[string]{Key: 1, Value: "old"}, } @@ -295,7 +295,7 @@ func TestEffectString_Apply_Set(t *testing.T) { } func TestEffectString_Apply_Add_Concatenate(t *testing.T) { - data := statesData{ + data := states{ State[string]{Key: 1, Value: "hello"}, } @@ -312,7 +312,7 @@ func TestEffectString_Apply_Add_Concatenate(t *testing.T) { } func TestEffectString_Apply_InvalidOperator(t *testing.T) { - data := statesData{ + data := states{ State[string]{Key: 1, Value: "test"}, } @@ -325,7 +325,7 @@ func TestEffectString_Apply_InvalidOperator(t *testing.T) { } func TestEffectString_Apply_NewKey(t *testing.T) { - data := statesData{} + data := states{} effect := EffectString{Key: 1, Value: "new", Operator: SET} err := effect.apply(data) diff --git a/agent.go b/agent.go index c22cb78..b5c8283 100644 --- a/agent.go +++ b/agent.go @@ -22,8 +22,8 @@ func CreateAgent(goals Goals, actions Actions) Agent { } states := world{ - Agent: &agent, - data: statesData{}, + Agent: &agent, + states: states{}, } agent.states = states @@ -31,7 +31,7 @@ func CreateAgent(goals Goals, actions Actions) Agent { } func SetState[T Numeric | bool | string](agent *Agent, key StateKey, value T) { - agent.states.data = append(agent.states.data, State[T]{ + agent.states.states = append(agent.states.states, State[T]{ Key: key, Value: value, }) diff --git a/agent_test.go b/agent_test.go index 6021331..576eb6b 100644 --- a/agent_test.go +++ b/agent_test.go @@ -48,10 +48,10 @@ func TestSetState(t *testing.T) { SetState[int](a, 1, 42) }, checkFunc: func(t *testing.T, a Agent) { - if len(a.states.data) != 1 { - t.Errorf("Expected 1 state, got %d", len(a.states.data)) + if len(a.states.states) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.states.states)) } - state := a.states.data[0].(State[int]) + state := a.states.states[0].(State[int]) if state.Key != 1 || state.Value != 42 { t.Errorf("Expected key=1, value=42, got key=%d, value=%d", state.Key, state.Value) } @@ -63,10 +63,10 @@ func TestSetState(t *testing.T) { SetState[bool](a, 2, true) }, checkFunc: func(t *testing.T, a Agent) { - if len(a.states.data) != 1 { - t.Errorf("Expected 1 state, got %d", len(a.states.data)) + if len(a.states.states) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.states.states)) } - state := a.states.data[0].(State[bool]) + state := a.states.states[0].(State[bool]) if state.Key != 2 || state.Value != true { t.Errorf("Expected key=2, value=true, got key=%d, value=%v", state.Key, state.Value) } @@ -78,10 +78,10 @@ func TestSetState(t *testing.T) { SetState[string](a, 3, "test") }, checkFunc: func(t *testing.T, a Agent) { - if len(a.states.data) != 1 { - t.Errorf("Expected 1 state, got %d", len(a.states.data)) + if len(a.states.states) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.states.states)) } - state := a.states.data[0].(State[string]) + state := a.states.states[0].(State[string]) if state.Key != 3 || state.Value != "test" { t.Errorf("Expected key=3, value='test', got key=%d, value='%s'", state.Key, state.Value) } @@ -95,8 +95,8 @@ func TestSetState(t *testing.T) { SetState[string](a, 3, "hello") }, checkFunc: func(t *testing.T, a Agent) { - if len(a.states.data) != 3 { - t.Errorf("Expected 3 world, got %d", len(a.states.data)) + if len(a.states.states) != 3 { + t.Errorf("Expected 3 world, got %d", len(a.states.states)) } }, }, diff --git a/astar.go b/astar.go index 7aa4a47..4eb479e 100644 --- a/astar.go +++ b/astar.go @@ -32,13 +32,13 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { nodesPool.Put(closedNodes[:0]) }() - data := slices.Clone(from.data) + data := slices.Clone(from.states) openNodes = append(openNodes, &node{ Action: &Action{}, world: world{ - Agent: from.Agent, - data: data, - hash: data.hashStates(), + Agent: from.Agent, + states: data, + hash: data.hashStates(), }, parentNode: nil, cost: 0, @@ -187,12 +187,12 @@ func simulateActionState(action *Action, w world) (world, bool) { // For each effect, we need to XOR out the old state and XOR in the new state for _, effect := range action.effects { // Find old state if it exists - oldIndex := w.data.GetIndex(effect.GetKey()) + oldIndex := w.states.GetIndex(effect.GetKey()) if oldIndex >= 0 { - newHash ^= w.data[oldIndex].Hash() // Remove old + newHash ^= w.states[oldIndex].Hash() // Remove old } - // Find new state in modified data + // Find new state in modified states newIndex := data.GetIndex(effect.GetKey()) if newIndex >= 0 { newHash ^= data[newIndex].Hash() // Add new @@ -200,9 +200,9 @@ func simulateActionState(action *Action, w world) (world, bool) { } return world{ - Agent: w.Agent, - data: data, - hash: newHash, + Agent: w.Agent, + states: data, + hash: newHash, }, true } diff --git a/astar_test.go b/astar_test.go index 3e1d5e6..210eac8 100644 --- a/astar_test.go +++ b/astar_test.go @@ -79,11 +79,11 @@ func TestGetLessCostlyNodeKey_Empty(t *testing.T) { func TestFetchNode_Found(t *testing.T) { agent1 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent1, 1, 100) - agent1.states.hash = agent1.states.data.hashStates() + agent1.states.hash = agent1.states.states.hashStates() agent2 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent2, 1, 100) - agent2.states.hash = agent2.states.data.hashStates() + agent2.states.hash = agent2.states.states.hashStates() nodes := []*node{ {world: agent1.states}, @@ -102,11 +102,11 @@ func TestFetchNode_Found(t *testing.T) { func TestFetchNode_NotFound(t *testing.T) { agent1 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent1, 1, 100) - agent1.states.hash = agent1.states.data.hashStates() + agent1.states.hash = agent1.states.states.hashStates() agent2 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent2, 1, 200) - agent2.states.hash = agent2.states.data.hashStates() + agent2.states.hash = agent2.states.states.hashStates() nodes := []*node{ {world: agent1.states}, @@ -173,13 +173,13 @@ func TestSimulateActionState(t *testing.T) { t.Error("Expected simulation to succeed") } - idx := newStates.data.GetIndex(1) + idx := newStates.states.GetIndex(1) if idx < 0 { t.Error("Expected to find key 1 in new world") } - if newStates.data[idx].(State[int]).Value != 150 { - t.Errorf("Expected value 150, got %d", newStates.data[idx].(State[int]).Value) + if newStates.states[idx].(State[int]).Value != 150 { + t.Errorf("Expected value 150, got %d", newStates.states[idx].(State[int]).Value) } } @@ -509,7 +509,7 @@ func TestAstar_DataCloning(t *testing.T) { agent := CreateAgent(Goals{}, Actions{}) SetState[int](&agent, 1, 100) - originalData := slices.Clone(agent.states.data) + originalData := slices.Clone(agent.states.states) actions := Actions{} actions.AddAction("modify", 1.0, false, Conditions{}, Effects{ @@ -525,10 +525,10 @@ func TestAstar_DataCloning(t *testing.T) { _ = astar(agent.states, goal, actions, 10) // Original state should not be modified - if len(agent.states.data) != len(originalData) { + if len(agent.states.states) != len(originalData) { t.Error("Original state was modified") } - if agent.states.data[0].(State[int]).Value != originalData[0].(State[int]).Value { + if agent.states.states[0].(State[int]).Value != originalData[0].(State[int]).Value { t.Error("Original state values were modified") } } diff --git a/state.go b/state.go index 40b9b3c..bea0095 100644 --- a/state.go +++ b/state.go @@ -19,8 +19,8 @@ const ( type Numeric interface { ~int8 | ~int | - ~uint8 | ~uint64 | - ~float64 + ~uint8 | ~uint64 | + ~float64 } type StateInterface interface { @@ -36,12 +36,12 @@ type State[T Numeric | bool | string] struct { type StateKey uint16 -type statesData []StateInterface +type states []StateInterface type world struct { - Agent *Agent - data statesData - hash uint64 + Agent *Agent + states states + hash uint64 } func (state State[T]) GetKey() StateKey { @@ -49,11 +49,11 @@ func (state State[T]) GetKey() StateKey { } func (state State[T]) Check(w world, key StateKey) bool { - k := w.data.GetIndex(key) + k := w.states.GetIndex(key) if k < 0 { return false } - s := w.data[k] + s := w.states[k] if agentState, ok := s.(State[T]); ok { if agentState.Value == state.Value { return true @@ -109,7 +109,7 @@ func (world world) Check(world2 world) bool { return world.hash == world2.hash } -func (statesData statesData) GetIndex(stateKey StateKey) int { +func (statesData states) GetIndex(stateKey StateKey) int { for k, stateData := range statesData { if stateData.GetKey() == stateKey { return k @@ -121,7 +121,7 @@ func (statesData statesData) GetIndex(stateKey StateKey) int { // hashStates computes the initial hash using XOR of individual state hashes // This is O(n) but only called once when creating initial state -func (statesData statesData) hashStates() uint64 { +func (statesData states) hashStates() uint64 { var hash uint64 = 0 for _, state := range statesData { hash ^= state.Hash() // XOR for incremental updates @@ -184,11 +184,11 @@ func (condition *Condition[T]) GetKey() StateKey { } func (condition *Condition[T]) Check(w world) bool { - k := w.data.GetIndex(condition.Key) + k := w.states.GetIndex(condition.Key) if k < 0 { return false } - s := w.data[k] + s := w.states[k] if state, ok := s.(State[T]); ok { switch condition.Operator { case EQUAL: @@ -232,11 +232,11 @@ func (conditionBool *ConditionBool) GetKey() StateKey { } func (conditionBool *ConditionBool) Check(w world) bool { - k := w.data.GetIndex(conditionBool.Key) + k := w.states.GetIndex(conditionBool.Key) if k < 0 { return false } - s := w.data[k] + s := w.states[k] if state, ok := s.(State[bool]); ok { switch conditionBool.Operator { case EQUAL: @@ -266,11 +266,11 @@ func (conditionString *ConditionString) GetKey() StateKey { } func (conditionString *ConditionString) Check(w world) bool { - k := w.data.GetIndex(conditionString.Key) + k := w.states.GetIndex(conditionString.Key) if k < 0 { return false } - s := w.data[k] + s := w.states[k] if state, ok := s.(State[string]); ok { switch conditionString.Operator { case EQUAL: diff --git a/state_test.go b/state_test.go index 4384fa6..ba1f75b 100644 --- a/state_test.go +++ b/state_test.go @@ -103,11 +103,11 @@ func TestStates_Check(t *testing.T) { t.Run(tt.name, func(t *testing.T) { agent1 := CreateAgent(Goals{}, Actions{}) tt.setup1(&agent1) - agent1.states.hash = agent1.states.data.hashStates() + agent1.states.hash = agent1.states.states.hashStates() agent2 := CreateAgent(Goals{}, Actions{}) tt.setup2(&agent2) - agent2.states.hash = agent2.states.data.hashStates() + agent2.states.hash = agent2.states.states.hashStates() if got := agent1.states.Check(agent2.states); got != tt.wantMatch { t.Errorf("Check() = %v, want %v", got, tt.wantMatch) @@ -124,7 +124,7 @@ func TestStatesData_Operations(t *testing.T) { { name: "GetIndex found", testFunc: func(t *testing.T) { - data := statesData{ + data := states{ State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, State[bool]{Key: 3, Value: true}, @@ -139,7 +139,7 @@ func TestStatesData_Operations(t *testing.T) { { name: "GetIndex not found", testFunc: func(t *testing.T) { - data := statesData{ + data := states{ State[int]{Key: 1, Value: 100}, } @@ -150,14 +150,14 @@ func TestStatesData_Operations(t *testing.T) { }, }, { - name: "hashStates same data", + name: "hashStates same states", testFunc: func(t *testing.T) { - data1 := statesData{ + data1 := states{ State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, } - data2 := statesData{ + data2 := states{ State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, } @@ -166,19 +166,19 @@ func TestStatesData_Operations(t *testing.T) { hash2 := data2.hashStates() if hash1 != hash2 { - t.Error("Expected identical data to produce same hash") + t.Error("Expected identical states to produce same hash") } }, }, { - name: "hashStates same data, different keys", + name: "hashStates same states, different keys", testFunc: func(t *testing.T) { - data1 := statesData{ + data1 := states{ State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, State[int]{Key: 3, Value: 300}, } - data2 := statesData{ + data2 := states{ State[int]{Key: 3, Value: 300}, State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, @@ -188,19 +188,19 @@ func TestStatesData_Operations(t *testing.T) { hash2 := data2.hashStates() if hash1 != hash2 { - t.Error("Expected identical data to produce same hash") + t.Error("Expected identical states to produce same hash") } }, }, { - name: "hashStates different data", + name: "hashStates different states", testFunc: func(t *testing.T) { - data1 := statesData{ + data1 := states{ State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, } - data2 := statesData{ + data2 := states{ State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 999}, } @@ -209,7 +209,7 @@ func TestStatesData_Operations(t *testing.T) { hash2 := data2.hashStates() if hash1 == hash2 { - t.Error("Expected different data to produce different hash") + t.Error("Expected different states to produce different hash") } }, }, @@ -600,13 +600,13 @@ func TestUpdateHashIncremental(t *testing.T) { func TestStatesData_HashStates_XORProperty(t *testing.T) { // Test that hash order doesn't matter (XOR is commutative) - data1 := statesData{ + data1 := states{ State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, State[bool]{Key: 3, Value: true}, } - data2 := statesData{ + data2 := states{ State[bool]{Key: 3, Value: true}, State[int]{Key: 1, Value: 100}, State[int]{Key: 2, Value: 200}, @@ -626,19 +626,19 @@ func TestIncrementalHashConsistency(t *testing.T) { SetState[int](&agent, 1, 100) SetState[int](&agent, 2, 200) SetState[bool](&agent, 3, true) - agent.states.hash = agent.states.data.hashStates() + agent.states.hash = agent.states.states.hashStates() initialHash := agent.states.hash // Modify a state and calculate hash incrementally - oldState := agent.states.data[0] + oldState := agent.states.states[0] newState := State[int]{Key: 1, Value: 150} incrementalHash := updateHashIncremental(agent.states.hash, oldState, newState) // Modify the same state and calculate hash from scratch - agent.states.data[0] = newState - fullHash := agent.states.data.hashStates() + agent.states.states[0] = newState + fullHash := agent.states.states.hashStates() if incrementalHash != fullHash { t.Errorf("Incremental hash (%d) doesn't match full recalculation (%d)", incrementalHash, fullHash) From 3cbe57f2a1843b524832d2493f04cd99b831d822 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 00:39:30 +0100 Subject: [PATCH 06/18] feat: wip perfs --- action.go | 54 ++++++++++++++++++------------------- action_test.go | 22 +++++++-------- agent.go | 6 ++--- agent_test.go | 24 ++++++++--------- astar.go | 32 ++++------------------ astar_test.go | 58 ++++++++++++++++++++-------------------- benchmark/goapai_test.go | 36 ++++++++++++++----------- planer.go | 6 ++++- state.go | 42 ++++++++++++++++------------- state_test.go | 52 +++++++++++++++++------------------ 10 files changed, 163 insertions(+), 169 deletions(-) diff --git a/action.go b/action.go index 5bda5e4..38fd5eb 100644 --- a/action.go +++ b/action.go @@ -47,7 +47,7 @@ func (action *Action) GetEffects() Effects { type EffectInterface interface { GetKey() StateKey check(w world) bool - apply(data states) error + apply(w *world) error } type Effect[T Numeric] struct { @@ -79,23 +79,23 @@ func (effect Effect[T]) check(w world) bool { return s.(State[T]).Value == effect.Value } -func (effect Effect[T]) apply(data states) error { - k := data.GetIndex(effect.Key) +func (effect Effect[T]) apply(w *world) error { + k := w.states.GetIndex(effect.Key) if k < 0 { if slices.Contains([]arithmetic{SET, ADD}, effect.Operator) { - data = append(data, State[T]{Value: effect.Value}) + w.states = append(w.states, State[T]{Value: effect.Value}) return nil } else if slices.Contains([]arithmetic{SUBSTRACT}, effect.Operator) { - data = append(data, State[T]{Value: -effect.Value}) + w.states = append(w.states, State[T]{Value: -effect.Value}) return nil } - return fmt.Errorf("states does not exist") + return fmt.Errorf("w does not exist") } - if _, ok := data[k].(State[T]); !ok { + if _, ok := w.states[k].(State[T]); !ok { return fmt.Errorf("type does not match") } - state := data[k].(State[T]) + state := w.states[k].(State[T]) switch effect.Operator { case SET: state.Value = effect.Value @@ -109,7 +109,7 @@ func (effect Effect[T]) apply(data states) error { state.Value /= effect.Value } - data[k] = state + state.Store(w) return nil } @@ -143,23 +143,24 @@ func (effectBool EffectBool) check(w world) bool { return s.Value == effectBool.Value } -func (effectBool EffectBool) apply(data states) error { +func (effectBool EffectBool) apply(w *world) error { if effectBool.Operator != SET { return fmt.Errorf("operation %v not allowed on bool type", effectBool.Operator) } - k := data.GetIndex(effectBool.Key) + k := w.states.GetIndex(effectBool.Key) if k < 0 { - data = append(data, State[bool]{Value: effectBool.Value}) + w.states = append(w.states, State[bool]{Value: effectBool.Value}) return nil } - if _, ok := data[k].(State[bool]); !ok { + if _, ok := w.states[k].(State[bool]); !ok { return fmt.Errorf("type does not match") } - state := data[k].(State[bool]) + state := w.states[k].(State[bool]) state.Value = effectBool.Value - data[k] = state + + state.Store(w) return nil } @@ -188,28 +189,29 @@ func (effectString EffectString) check(w world) bool { return s.Value == effectString.Value } -func (effectString EffectString) apply(data states) error { +func (effectString EffectString) apply(w *world) error { if !slices.Contains([]arithmetic{SET, ADD}, effectString.Operator) { return fmt.Errorf("arithmetic operation %v not allowed on string type", effectString.Operator) } - k := data.GetIndex(effectString.Key) + k := w.states.GetIndex(effectString.Key) if k < 0 { - data = append(data, State[string]{Value: effectString.Value}) + w.states = append(w.states, State[string]{Value: effectString.Value}) return nil } - if _, ok := data[k].(State[string]); !ok { + if _, ok := w.states[k].(State[string]); !ok { return fmt.Errorf("type does not match") } - state := data[k].(State[string]) + state := w.states[k].(State[string]) switch effectString.Operator { case SET: state.Value = effectString.Value case ADD: state.Value = fmt.Sprint(state.Value, effectString.Value) } - data[k] = state + + state.Store(w) return nil } @@ -230,16 +232,14 @@ func (effects Effects) satisfyStates(w world) bool { return true } -func (effects Effects) apply(w world) (states, error) { - data := slices.Clone(w.states) - +func (effects Effects) apply(w *world) error { for _, effect := range effects { - err := effect.apply(data) + err := effect.apply(w) if err != nil { - return nil, err + return err } } - return data, nil + return nil } diff --git a/action_test.go b/action_test.go index 8b2a1cf..7f8fc78 100644 --- a/action_test.go +++ b/action_test.go @@ -67,7 +67,7 @@ func TestEffect_Check_Match(t *testing.T) { SetState[int](&agent, 1, 100) effect := Effect[int]{Key: 1, Value: 100, Operator: SET} - if !effect.check(agent.states) { + if !effect.check(agent.w) { t.Error("Expected effect to match state") } } @@ -77,7 +77,7 @@ func TestEffect_Check_NoMatch(t *testing.T) { SetState[int](&agent, 1, 100) effect := Effect[int]{Key: 1, Value: 200, Operator: SET} - if effect.check(agent.states) { + if effect.check(agent.w) { t.Error("Expected effect not to match state") } } @@ -87,7 +87,7 @@ func TestEffect_Check_NonSetOperator(t *testing.T) { SetState[int](&agent, 1, 100) effect := Effect[int]{Key: 1, Value: 100, Operator: ADD} - if effect.check(agent.states) { + if effect.check(agent.w) { t.Error("Expected non-SET operator to always return false") } } @@ -197,7 +197,7 @@ func TestEffectBool_Check_Match(t *testing.T) { SetState[bool](&agent, 1, true) effect := EffectBool{Key: 1, Value: true, Operator: SET} - if !effect.check(agent.states) { + if !effect.check(agent.w) { t.Error("Expected effect to match state") } } @@ -207,7 +207,7 @@ func TestEffectBool_Check_NoMatch(t *testing.T) { SetState[bool](&agent, 1, true) effect := EffectBool{Key: 1, Value: false, Operator: SET} - if effect.check(agent.states) { + if effect.check(agent.w) { t.Error("Expected effect not to match state") } } @@ -262,7 +262,7 @@ func TestEffectString_Check_Match(t *testing.T) { SetState[string](&agent, 1, "test") effect := EffectString{Key: 1, Value: "test", Operator: SET} - if !effect.check(agent.states) { + if !effect.check(agent.w) { t.Error("Expected effect to match state") } } @@ -272,7 +272,7 @@ func TestEffectString_Check_NoMatch(t *testing.T) { SetState[string](&agent, 1, "test") effect := EffectString{Key: 1, Value: "other", Operator: SET} - if effect.check(agent.states) { + if effect.check(agent.w) { t.Error("Expected effect not to match state") } } @@ -349,7 +349,7 @@ func TestEffects_SatisfyStates_AllMatch(t *testing.T) { EffectBool{Key: 2, Value: true, Operator: SET}, } - if !effects.satisfyStates(agent.states) { + if !effects.satisfyStates(agent.w) { t.Error("Expected effects to satisfy world") } } @@ -364,7 +364,7 @@ func TestEffects_SatisfyStates_OneFails(t *testing.T) { EffectBool{Key: 2, Value: false, Operator: SET}, } - if effects.satisfyStates(agent.states) { + if effects.satisfyStates(agent.w) { t.Error("Expected effects not to satisfy world when one doesn't match") } } @@ -379,7 +379,7 @@ func TestEffects_Apply(t *testing.T) { EffectBool{Key: 2, Value: true, Operator: SET}, } - newData, err := effects.apply(agent.states) + newData, err := effects.apply(agent.w) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -415,7 +415,7 @@ func TestEffects_Apply_Error(t *testing.T) { Effect[float64]{Key: 1, Value: 50.0, Operator: SET}, // Type mismatch } - _, err := effects.apply(agent.states) + _, err := effects.apply(agent.w) if err == nil { t.Error("Expected error for type mismatch") } diff --git a/agent.go b/agent.go index b5c8283..8ad3db7 100644 --- a/agent.go +++ b/agent.go @@ -2,7 +2,7 @@ package goapai type Agent struct { actions Actions - states world + w world sensors Sensors goals Goals } @@ -25,13 +25,13 @@ func CreateAgent(goals Goals, actions Actions) Agent { Agent: &agent, states: states{}, } - agent.states = states + agent.w = states return agent } func SetState[T Numeric | bool | string](agent *Agent, key StateKey, value T) { - agent.states.states = append(agent.states.states, State[T]{ + agent.w.states = append(agent.w.states, State[T]{ Key: key, Value: value, }) diff --git a/agent_test.go b/agent_test.go index 576eb6b..808bc2f 100644 --- a/agent_test.go +++ b/agent_test.go @@ -31,7 +31,7 @@ func TestCreateAgent(t *testing.T) { t.Error("Expected sensors to be initialized") } - if agent.states.Agent == nil { + if agent.w.Agent == nil { t.Error("Expected world.Agent to be non-nil") } } @@ -48,10 +48,10 @@ func TestSetState(t *testing.T) { SetState[int](a, 1, 42) }, checkFunc: func(t *testing.T, a Agent) { - if len(a.states.states) != 1 { - t.Errorf("Expected 1 state, got %d", len(a.states.states)) + if len(a.w.states) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.w.states)) } - state := a.states.states[0].(State[int]) + state := a.w.states[0].(State[int]) if state.Key != 1 || state.Value != 42 { t.Errorf("Expected key=1, value=42, got key=%d, value=%d", state.Key, state.Value) } @@ -63,10 +63,10 @@ func TestSetState(t *testing.T) { SetState[bool](a, 2, true) }, checkFunc: func(t *testing.T, a Agent) { - if len(a.states.states) != 1 { - t.Errorf("Expected 1 state, got %d", len(a.states.states)) + if len(a.w.states) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.w.states)) } - state := a.states.states[0].(State[bool]) + state := a.w.states[0].(State[bool]) if state.Key != 2 || state.Value != true { t.Errorf("Expected key=2, value=true, got key=%d, value=%v", state.Key, state.Value) } @@ -78,10 +78,10 @@ func TestSetState(t *testing.T) { SetState[string](a, 3, "test") }, checkFunc: func(t *testing.T, a Agent) { - if len(a.states.states) != 1 { - t.Errorf("Expected 1 state, got %d", len(a.states.states)) + if len(a.w.states) != 1 { + t.Errorf("Expected 1 state, got %d", len(a.w.states)) } - state := a.states.states[0].(State[string]) + state := a.w.states[0].(State[string]) if state.Key != 3 || state.Value != "test" { t.Errorf("Expected key=3, value='test', got key=%d, value='%s'", state.Key, state.Value) } @@ -95,8 +95,8 @@ func TestSetState(t *testing.T) { SetState[string](a, 3, "hello") }, checkFunc: func(t *testing.T, a Agent) { - if len(a.states.states) != 3 { - t.Errorf("Expected 3 world, got %d", len(a.states.states)) + if len(a.w.states) != 3 { + t.Errorf("Expected 3 world, got %d", len(a.w.states)) } }, }, diff --git a/astar.go b/astar.go index 4eb479e..87d7b8a 100644 --- a/astar.go +++ b/astar.go @@ -32,13 +32,12 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { nodesPool.Put(closedNodes[:0]) }() - data := slices.Clone(from.states) openNodes = append(openNodes, &node{ Action: &Action{}, world: world{ Agent: from.Agent, - states: data, - hash: data.hashStates(), + states: slices.Clone(from.states), + hash: from.hash, }, parentNode: nil, cost: 0, @@ -176,34 +175,13 @@ func simulateActionState(action *Action, w world) (world, bool) { return world{}, false } - data, err := action.effects.apply(w) + w.states = slices.Clone(w.states) + err := action.effects.apply(&w) if err != nil { return world{}, false } - // Calculate hash incrementally by tracking changes - newHash := w.hash - - // For each effect, we need to XOR out the old state and XOR in the new state - for _, effect := range action.effects { - // Find old state if it exists - oldIndex := w.states.GetIndex(effect.GetKey()) - if oldIndex >= 0 { - newHash ^= w.states[oldIndex].Hash() // Remove old - } - - // Find new state in modified states - newIndex := data.GetIndex(effect.GetKey()) - if newIndex >= 0 { - newHash ^= data[newIndex].Hash() // Add new - } - } - - return world{ - Agent: w.Agent, - states: data, - hash: newHash, - }, true + return w, true } func allowedRepetition(action *Action, parentNode *node) bool { diff --git a/astar_test.go b/astar_test.go index 210eac8..0bfe4e1 100644 --- a/astar_test.go +++ b/astar_test.go @@ -22,7 +22,7 @@ func TestGetImpactingActions(t *testing.T) { EffectBool{Key: 2, Value: false, Operator: SET}, }) - impacting := getImpactingActions(agent.states, actions) + impacting := getImpactingActions(agent.w, actions) if len(impacting) != 1 { t.Errorf("Expected 1 impacting action, got %d", len(impacting)) @@ -45,7 +45,7 @@ func TestGetImpactingActions_AllImpacting(t *testing.T) { Effect[int]{Key: 1, Value: 50, Operator: SET}, }) - impacting := getImpactingActions(agent.states, actions) + impacting := getImpactingActions(agent.w, actions) if len(impacting) != 2 { t.Errorf("Expected 2 impacting actions, got %d", len(impacting)) @@ -79,18 +79,18 @@ func TestGetLessCostlyNodeKey_Empty(t *testing.T) { func TestFetchNode_Found(t *testing.T) { agent1 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent1, 1, 100) - agent1.states.hash = agent1.states.states.hashStates() + agent1.w.hash = agent1.w.states.hashStates() agent2 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent2, 1, 100) - agent2.states.hash = agent2.states.states.hashStates() + agent2.w.hash = agent2.w.states.hashStates() nodes := []*node{ - {world: agent1.states}, - {world: agent2.states}, + {world: agent1.w}, + {world: agent2.w}, } - key, found := fetchNode(nodes, agent1.states) + key, found := fetchNode(nodes, agent1.w) if !found { t.Error("Expected to find node") } @@ -102,17 +102,17 @@ func TestFetchNode_Found(t *testing.T) { func TestFetchNode_NotFound(t *testing.T) { agent1 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent1, 1, 100) - agent1.states.hash = agent1.states.states.hashStates() + agent1.w.hash = agent1.w.states.hashStates() agent2 := CreateAgent(Goals{}, Actions{}) SetState[int](&agent2, 1, 200) - agent2.states.hash = agent2.states.states.hashStates() + agent2.w.hash = agent2.w.states.hashStates() nodes := []*node{ - {world: agent1.states}, + {world: agent1.w}, } - _, found := fetchNode(nodes, agent2.states) + _, found := fetchNode(nodes, agent2.w) if found { t.Error("Expected not to find node") } @@ -168,7 +168,7 @@ func TestSimulateActionState(t *testing.T) { }, } - newStates, ok := simulateActionState(action, agent.states) + newStates, ok := simulateActionState(action, agent.w) if !ok { t.Error("Expected simulation to succeed") } @@ -193,7 +193,7 @@ func TestSimulateActionState_NoChange(t *testing.T) { }, } - _, ok := simulateActionState(action, agent.states) + _, ok := simulateActionState(action, agent.w) if ok { t.Error("Expected simulation to fail when effects match current state") } @@ -244,7 +244,7 @@ func TestCountMissingGoal_AllMet(t *testing.T) { }, } - count := countMissingGoal(goal, agent.states) + count := countMissingGoal(goal, agent.w) if count != 0 { t.Errorf("Expected 0 missing goals, got %d", count) } @@ -262,7 +262,7 @@ func TestCountMissingGoal_OneMissing(t *testing.T) { }, } - count := countMissingGoal(goal, agent.states) + count := countMissingGoal(goal, agent.w) if count != 1 { t.Errorf("Expected 1 missing goal, got %d", count) } @@ -280,7 +280,7 @@ func TestCountMissingGoal_AllMissing(t *testing.T) { }, } - count := countMissingGoal(goal, agent.states) + count := countMissingGoal(goal, agent.w) if count != 2 { t.Errorf("Expected 2 missing goals, got %d", count) } @@ -300,7 +300,7 @@ func TestComputeHeuristic(t *testing.T) { }, } - heuristic := computeHeuristic(fromAgent.states, goal, currentAgent.states) + heuristic := computeHeuristic(fromAgent.w, goal, currentAgent.w) if heuristic <= 0 { t.Error("Expected positive heuristic for unmet goal") } @@ -319,7 +319,7 @@ func TestComputeHeuristic_GoalMet(t *testing.T) { }, } - heuristic := computeHeuristic(fromAgent.states, goal, currentAgent.states) + heuristic := computeHeuristic(fromAgent.w, goal, currentAgent.w) if heuristic != 0 { t.Errorf("Expected 0 heuristic for met goal, got %f", heuristic) } @@ -341,7 +341,7 @@ func TestAstar_SimpleGoal(t *testing.T) { }, } - plan := astar(agent.states, goal, actions, 10) + plan := astar(agent.w, goal, actions, 10) // Plan includes the root node with empty action if len(plan) != 4 { @@ -362,7 +362,7 @@ func TestAstar_UnreachableGoal(t *testing.T) { }, } - plan := astar(agent.states, goal, actions, 10) + plan := astar(agent.w, goal, actions, 10) if len(plan) != 0 { t.Errorf("Expected empty plan for unreachable goal, got %d actions", len(plan)) @@ -385,7 +385,7 @@ func TestAstar_MaxDepth(t *testing.T) { } // Max depth of 5 should prevent reaching goal of 100 - plan := astar(agent.states, goal, actions, 5) + plan := astar(agent.w, goal, actions, 5) if len(plan) > 5 { t.Errorf("Expected plan to respect maxDepth of 5, got %d actions", len(plan)) @@ -404,7 +404,7 @@ func TestAstar_AlreadyAtGoal(t *testing.T) { }, } - plan := astar(agent.states, goal, actions, 10) + plan := astar(agent.w, goal, actions, 10) // Plan includes root node with empty action when already at goal if len(plan) != 1 { @@ -430,7 +430,7 @@ func TestAstar_PreferLowerCost(t *testing.T) { }, } - plan := astar(agent.states, goal, actions, 10) + plan := astar(agent.w, goal, actions, 10) // Plan includes root + action if len(plan) != 2 { @@ -465,7 +465,7 @@ func TestAstar_RespectConditions(t *testing.T) { }, } - plan := astar(agent.states, goal, actions, 10) + plan := astar(agent.w, goal, actions, 10) // Plan includes root + 2 actions if len(plan) != 3 { @@ -497,7 +497,7 @@ func TestAstar_NonRepeatableActions(t *testing.T) { }, } - plan := astar(agent.states, goal, actions, 10) + plan := astar(agent.w, goal, actions, 10) // Non-repeatable action can only be used once, so goal is unreachable if len(plan) != 0 { @@ -509,7 +509,7 @@ func TestAstar_DataCloning(t *testing.T) { agent := CreateAgent(Goals{}, Actions{}) SetState[int](&agent, 1, 100) - originalData := slices.Clone(agent.states.states) + originalData := slices.Clone(agent.w.states) actions := Actions{} actions.AddAction("modify", 1.0, false, Conditions{}, Effects{ @@ -522,13 +522,13 @@ func TestAstar_DataCloning(t *testing.T) { }, } - _ = astar(agent.states, goal, actions, 10) + _ = astar(agent.w, goal, actions, 10) // Original state should not be modified - if len(agent.states.states) != len(originalData) { + if len(agent.w.states) != len(originalData) { t.Error("Original state was modified") } - if agent.states.states[0].(State[int]).Value != originalData[0].(State[int]).Value { + if agent.w.states[0].(State[int]).Value != originalData[0].(State[int]).Value { t.Error("Original state values were modified") } } diff --git a/benchmark/goapai_test.go b/benchmark/goapai_test.go index 0a00426..89a8ed9 100644 --- a/benchmark/goapai_test.go +++ b/benchmark/goapai_test.go @@ -3,6 +3,9 @@ package benchmark import ( "fmt" "goapai" + "os" + "runtime/pprof" + "runtime/trace" "testing" ) @@ -67,28 +70,31 @@ func BenchmarkGoapAI(b *testing.B) { goapai.SetState[int](&entity.agent, ATTRIBUTE_1, 80) goapai.SetState[int](&entity.agent, ATTRIBUTE_3, 0) - // Write to the trace file. - //f, _ := os.Create("trace.out") - //fcpu, _ := os.Create(`cpu.prof`) - //fheap, _ := os.Create(`heap.prof`) - // - //pprof.StartCPUProfile(fcpu) - //pprof.WriteHeapProfile(fheap) - //trace.Start(f) + //Write to the trace file. + f, _ := os.Create("trace.out") + fcpu, _ := os.Create(`cpu.prof`) + fheap, _ := os.Create(`heap.prof`) + + pprof.StartCPUProfile(fcpu) + pprof.WriteHeapProfile(fheap) + trace.Start(f) var lastPlan goapai.Plan for b.Loop() { //goapai.GetPlan(entity.agent, 15) _, lastPlan = goapai.GetPlan(entity.agent, 15) } - fmt.Println(len(lastPlan)) + fmt.Println("Actions for plan", len(lastPlan)) + for _, j := range lastPlan { + fmt.Printf(" - %v\n", j.GetName()) + } + + defer f.Close() + defer fcpu.Close() + defer fheap.Close() - //defer f.Close() - //defer fcpu.Close() - //defer fheap.Close() - // - //trace.Stop() - //pprof.StopCPUProfile() + trace.Stop() + pprof.StopCPUProfile() b.ReportAllocs() } diff --git a/planer.go b/planer.go index 8f8e1bf..ff283da 100644 --- a/planer.go +++ b/planer.go @@ -32,7 +32,11 @@ func GetPlan(agent Agent, maxDepth int) (GoalName, Plan) { return "", Plan{} } - return goalName, astar(agent.states, agent.goals[goalName], agent.actions, maxDepth) + for _, state := range agent.w.states { + state.Store(&agent.w) + } + + return goalName, astar(agent.w, agent.goals[goalName], agent.actions, maxDepth) } func (agent *Agent) getPrioritizedGoalName() (GoalName, error) { diff --git a/state.go b/state.go index bea0095..5f35c75 100644 --- a/state.go +++ b/state.go @@ -27,11 +27,14 @@ type StateInterface interface { Check(w world, key StateKey) bool GetKey() StateKey GetValue() any + Store(w *world) + GetHash() uint64 Hash() uint64 } type State[T Numeric | bool | string] struct { Key StateKey Value T + hash uint64 } type StateKey uint16 @@ -44,6 +47,12 @@ type world struct { hash uint64 } +func createState[T Numeric | bool | string](s State[T]) State[T] { + s.hash = s.Hash() + + return s +} + func (state State[T]) GetKey() StateKey { return state.Key } @@ -67,6 +76,17 @@ func (state State[T]) GetValue() any { return state.Value } +func (state State[T]) Store(w *world) { + oldHash := state.hash + state.hash = state.Hash() + w.hash = updateHashIncremental(w.hash, oldHash, state.hash) + w.states[state.Key] = state +} + +func (state State[T]) GetHash() uint64 { + return state.hash +} + // Hash returns a unique hash for this state using FNV-64a func (state State[T]) Hash() uint64 { h := fnv.New64a() @@ -119,25 +139,11 @@ func (statesData states) GetIndex(stateKey StateKey) int { return -1 } -// hashStates computes the initial hash using XOR of individual state hashes -// This is O(n) but only called once when creating initial state -func (statesData states) hashStates() uint64 { - var hash uint64 = 0 - for _, state := range statesData { - hash ^= state.Hash() // XOR for incremental updates - } - return hash -} - // updateHashIncremental updates a hash by removing old state and adding new state -// This is O(1) - the key optimization -func updateHashIncremental(currentHash uint64, oldState, newState StateInterface) uint64 { - if oldState != nil { - currentHash ^= oldState.Hash() // Remove old - } - if newState != nil { - currentHash ^= newState.Hash() // Add new - } +func updateHashIncremental(currentHash uint64, oldStateHash, newStateHash uint64) uint64 { + currentHash ^= oldStateHash // Remove old + currentHash ^= newStateHash // Add new + return currentHash } diff --git a/state_test.go b/state_test.go index ba1f75b..219ffbe 100644 --- a/state_test.go +++ b/state_test.go @@ -32,7 +32,7 @@ func TestState_Operations(t *testing.T) { SetState[int](&agent, 1, 100) state := State[int]{Key: 1, Value: 100} - if !state.Check(agent.states, 1) { + if !state.Check(agent.w, 1) { t.Error("Expected state to match") } }, @@ -44,7 +44,7 @@ func TestState_Operations(t *testing.T) { SetState[int](&agent, 1, 100) wrongState := State[int]{Key: 1, Value: 200} - if wrongState.Check(agent.states, 1) { + if wrongState.Check(agent.w, 1) { t.Error("Expected state not to match") } }, @@ -55,7 +55,7 @@ func TestState_Operations(t *testing.T) { agent := CreateAgent(Goals{}, Actions{}) state := State[int]{Key: 99, Value: 100} - if state.Check(agent.states, 99) { + if state.Check(agent.w, 99) { t.Error("Expected false for non-existent key") } }, @@ -103,13 +103,13 @@ func TestStates_Check(t *testing.T) { t.Run(tt.name, func(t *testing.T) { agent1 := CreateAgent(Goals{}, Actions{}) tt.setup1(&agent1) - agent1.states.hash = agent1.states.states.hashStates() + agent1.w.hash = agent1.w.states.hashStates() agent2 := CreateAgent(Goals{}, Actions{}) tt.setup2(&agent2) - agent2.states.hash = agent2.states.states.hashStates() + agent2.w.hash = agent2.w.states.hashStates() - if got := agent1.states.Check(agent2.states); got != tt.wantMatch { + if got := agent1.w.Check(agent2.w); got != tt.wantMatch { t.Errorf("Check() = %v, want %v", got, tt.wantMatch) } }) @@ -150,7 +150,7 @@ func TestStatesData_Operations(t *testing.T) { }, }, { - name: "hashStates same states", + name: "hashStates same w", testFunc: func(t *testing.T) { data1 := states{ State[int]{Key: 1, Value: 100}, @@ -166,12 +166,12 @@ func TestStatesData_Operations(t *testing.T) { hash2 := data2.hashStates() if hash1 != hash2 { - t.Error("Expected identical states to produce same hash") + t.Error("Expected identical w to produce same hash") } }, }, { - name: "hashStates same states, different keys", + name: "hashStates same w, different keys", testFunc: func(t *testing.T) { data1 := states{ State[int]{Key: 1, Value: 100}, @@ -188,12 +188,12 @@ func TestStatesData_Operations(t *testing.T) { hash2 := data2.hashStates() if hash1 != hash2 { - t.Error("Expected identical states to produce same hash") + t.Error("Expected identical w to produce same hash") } }, }, { - name: "hashStates different states", + name: "hashStates different w", testFunc: func(t *testing.T) { data1 := states{ State[int]{Key: 1, Value: 100}, @@ -209,7 +209,7 @@ func TestStatesData_Operations(t *testing.T) { hash2 := data2.hashStates() if hash1 == hash2 { - t.Error("Expected different states to produce different hash") + t.Error("Expected different w to produce different hash") } }, }, @@ -254,7 +254,7 @@ func TestCondition_Operators(t *testing.T) { SetState[int](&agent, 1, tt.stateVal) condition := Condition[int]{Key: 1, Value: tt.condVal, Operator: tt.operator} - if got := condition.Check(agent.states); got != tt.wantMatch { + if got := condition.Check(agent.w); got != tt.wantMatch { t.Errorf("Check() = %v, want %v for state=%d, cond=%d, op=%v", got, tt.wantMatch, tt.stateVal, tt.condVal, tt.operator) } @@ -266,7 +266,7 @@ func TestCondition_KeyNotFound(t *testing.T) { agent := CreateAgent(Goals{}, Actions{}) condition := Condition[int]{Key: 99, Value: 100, Operator: EQUAL} - if condition.Check(agent.states) { + if condition.Check(agent.w) { t.Error("Expected condition to fail when key not found") } } @@ -292,7 +292,7 @@ func TestConditionBool(t *testing.T) { SetState[bool](&agent, 1, tt.stateVal) condition := ConditionBool{Key: 1, Value: tt.condVal, Operator: tt.operator} - if got := condition.Check(agent.states); got != tt.wantMatch { + if got := condition.Check(agent.w); got != tt.wantMatch { t.Errorf("Check() = %v, want %v", got, tt.wantMatch) } }) @@ -302,7 +302,7 @@ func TestConditionBool(t *testing.T) { agent := CreateAgent(Goals{}, Actions{}) condition := ConditionBool{Key: 99, Value: true, Operator: EQUAL} - if condition.Check(agent.states) { + if condition.Check(agent.w) { t.Error("Expected condition to fail when key not found") } }) @@ -328,7 +328,7 @@ func TestConditionString(t *testing.T) { SetState[string](&agent, 1, tt.stateVal) condition := ConditionString{Key: 1, Value: tt.condVal, Operator: tt.operator} - if got := condition.Check(agent.states); got != tt.wantMatch { + if got := condition.Check(agent.w); got != tt.wantMatch { t.Errorf("Check() = %v, want %v", got, tt.wantMatch) } }) @@ -359,7 +359,7 @@ func TestConditionFn(t *testing.T) { }, } - if got := condition.Check(agent.states); got != tt.wantResult { + if got := condition.Check(agent.w); got != tt.wantResult { t.Errorf("Check() = %v, want %v", got, tt.wantResult) } @@ -369,7 +369,7 @@ func TestConditionFn(t *testing.T) { } // Call again to test cache - if got := condition.Check(agent.states); got != tt.wantResult { + if got := condition.Check(agent.w); got != tt.wantResult { t.Error("Expected cached result to match") } }) @@ -420,7 +420,7 @@ func TestConditions_Check(t *testing.T) { agent := CreateAgent(Goals{}, Actions{}) tt.setup(&agent) - if got := tt.conditions.Check(agent.states); got != tt.wantMatch { + if got := tt.conditions.Check(agent.w); got != tt.wantMatch { t.Errorf("Check() = %v, want %v", got, tt.wantMatch) } }) @@ -626,19 +626,19 @@ func TestIncrementalHashConsistency(t *testing.T) { SetState[int](&agent, 1, 100) SetState[int](&agent, 2, 200) SetState[bool](&agent, 3, true) - agent.states.hash = agent.states.states.hashStates() + agent.w.hash = agent.w.states.hashStates() - initialHash := agent.states.hash + initialHash := agent.w.hash // Modify a state and calculate hash incrementally - oldState := agent.states.states[0] + oldState := agent.w.states[0] newState := State[int]{Key: 1, Value: 150} - incrementalHash := updateHashIncremental(agent.states.hash, oldState, newState) + incrementalHash := updateHashIncremental(agent.w.hash, oldState, newState) // Modify the same state and calculate hash from scratch - agent.states.states[0] = newState - fullHash := agent.states.states.hashStates() + agent.w.states[0] = newState + fullHash := agent.w.states.hashStates() if incrementalHash != fullHash { t.Errorf("Incremental hash (%d) doesn't match full recalculation (%d)", incrementalHash, fullHash) From 913ccb08543687663665e2b3b6af6494388e3e13 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 10:55:05 +0100 Subject: [PATCH 07/18] feat: use multiplicative hash for perfs --- action_test.go | 329 ------------------------------------------------- astar_test.go | 43 ------- state.go | 50 ++++---- state_test.go | 274 ---------------------------------------- 4 files changed, 27 insertions(+), 669 deletions(-) diff --git a/action_test.go b/action_test.go index 7f8fc78..c02a970 100644 --- a/action_test.go +++ b/action_test.go @@ -91,332 +91,3 @@ func TestEffect_Check_NonSetOperator(t *testing.T) { t.Error("Expected non-SET operator to always return false") } } - -func TestEffect_Apply_Set(t *testing.T) { - data := states{ - State[int]{Key: 1, Value: 100}, - } - - effect := Effect[int]{Key: 1, Value: 200, Operator: SET} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if data[0].(State[int]).Value != 200 { - t.Errorf("Expected value 200, got %d", data[0].(State[int]).Value) - } -} - -func TestEffect_Apply_Add(t *testing.T) { - data := states{ - State[int]{Key: 1, Value: 100}, - } - - effect := Effect[int]{Key: 1, Value: 50, Operator: ADD} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if data[0].(State[int]).Value != 150 { - t.Errorf("Expected value 150, got %d", data[0].(State[int]).Value) - } -} - -func TestEffect_Apply_Subtract(t *testing.T) { - data := states{ - State[int]{Key: 1, Value: 100}, - } - - effect := Effect[int]{Key: 1, Value: 30, Operator: SUBSTRACT} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if data[0].(State[int]).Value != 70 { - t.Errorf("Expected value 70, got %d", data[0].(State[int]).Value) - } -} - -func TestEffect_Apply_Multiply(t *testing.T) { - data := states{ - State[int]{Key: 1, Value: 10}, - } - - effect := Effect[int]{Key: 1, Value: 5, Operator: MULTIPLY} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if data[0].(State[int]).Value != 50 { - t.Errorf("Expected value 50, got %d", data[0].(State[int]).Value) - } -} - -func TestEffect_Apply_Divide(t *testing.T) { - data := states{ - State[int]{Key: 1, Value: 100}, - } - - effect := Effect[int]{Key: 1, Value: 4, Operator: DIVIDE} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if data[0].(State[int]).Value != 25 { - t.Errorf("Expected value 25, got %d", data[0].(State[int]).Value) - } -} - -func TestEffect_Apply_NewKey(t *testing.T) { - data := states{} - - effect := Effect[int]{Key: 1, Value: 42, Operator: SET} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - // Note: apply modifies the slice in place but the original reference doesn't change - // This is actually expected behavior - effects.apply() returns a new slice -} - -// Test EffectBool -func TestEffectBool_Check_Match(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[bool](&agent, 1, true) - - effect := EffectBool{Key: 1, Value: true, Operator: SET} - if !effect.check(agent.w) { - t.Error("Expected effect to match state") - } -} - -func TestEffectBool_Check_NoMatch(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[bool](&agent, 1, true) - - effect := EffectBool{Key: 1, Value: false, Operator: SET} - if effect.check(agent.w) { - t.Error("Expected effect not to match state") - } -} - -func TestEffectBool_Apply_Set(t *testing.T) { - data := states{ - State[bool]{Key: 1, Value: false}, - } - - effect := EffectBool{Key: 1, Value: true, Operator: SET} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if data[0].(State[bool]).Value != true { - t.Error("Expected value to be true") - } -} - -func TestEffectBool_Apply_InvalidOperator(t *testing.T) { - data := states{ - State[bool]{Key: 1, Value: true}, - } - - effect := EffectBool{Key: 1, Value: false, Operator: ADD} - err := effect.apply(data) - - if err == nil { - t.Error("Expected error for invalid operator on bool") - } -} - -func TestEffectBool_Apply_NewKey(t *testing.T) { - data := states{} - - effect := EffectBool{Key: 1, Value: true, Operator: SET} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - // Note: apply modifies the slice in place but the original reference doesn't change - // This is actually expected behavior - effects.apply() returns a new slice -} - -// Test EffectString -func TestEffectString_Check_Match(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[string](&agent, 1, "test") - - effect := EffectString{Key: 1, Value: "test", Operator: SET} - if !effect.check(agent.w) { - t.Error("Expected effect to match state") - } -} - -func TestEffectString_Check_NoMatch(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[string](&agent, 1, "test") - - effect := EffectString{Key: 1, Value: "other", Operator: SET} - if effect.check(agent.w) { - t.Error("Expected effect not to match state") - } -} - -func TestEffectString_Apply_Set(t *testing.T) { - data := states{ - State[string]{Key: 1, Value: "old"}, - } - - effect := EffectString{Key: 1, Value: "new", Operator: SET} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if data[0].(State[string]).Value != "new" { - t.Errorf("Expected value 'new', got '%s'", data[0].(State[string]).Value) - } -} - -func TestEffectString_Apply_Add_Concatenate(t *testing.T) { - data := states{ - State[string]{Key: 1, Value: "hello"}, - } - - effect := EffectString{Key: 1, Value: " world", Operator: ADD} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if data[0].(State[string]).Value != "hello world" { - t.Errorf("Expected 'hello world', got '%s'", data[0].(State[string]).Value) - } -} - -func TestEffectString_Apply_InvalidOperator(t *testing.T) { - data := states{ - State[string]{Key: 1, Value: "test"}, - } - - effect := EffectString{Key: 1, Value: "x", Operator: MULTIPLY} - err := effect.apply(data) - - if err == nil { - t.Error("Expected error for invalid operator on string") - } -} - -func TestEffectString_Apply_NewKey(t *testing.T) { - data := states{} - - effect := EffectString{Key: 1, Value: "new", Operator: SET} - err := effect.apply(data) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - // Note: apply modifies the slice in place but the original reference doesn't change - // This is actually expected behavior - effects.apply() returns a new slice -} - -// Test Effects (slice) operations -func TestEffects_SatisfyStates_AllMatch(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent, 1, 100) - SetState[bool](&agent, 2, true) - - effects := Effects{ - Effect[int]{Key: 1, Value: 100, Operator: SET}, - EffectBool{Key: 2, Value: true, Operator: SET}, - } - - if !effects.satisfyStates(agent.w) { - t.Error("Expected effects to satisfy world") - } -} - -func TestEffects_SatisfyStates_OneFails(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent, 1, 100) - SetState[bool](&agent, 2, true) - - effects := Effects{ - Effect[int]{Key: 1, Value: 100, Operator: SET}, - EffectBool{Key: 2, Value: false, Operator: SET}, - } - - if effects.satisfyStates(agent.w) { - t.Error("Expected effects not to satisfy world when one doesn't match") - } -} - -func TestEffects_Apply(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent, 1, 100) - SetState[bool](&agent, 2, false) - - effects := Effects{ - Effect[int]{Key: 1, Value: 50, Operator: ADD}, - EffectBool{Key: 2, Value: true, Operator: SET}, - } - - newData, err := effects.apply(agent.w) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if len(newData) != 2 { - t.Errorf("Expected 2 world, got %d", len(newData)) - } - - // Check int value - intIdx := newData.GetIndex(1) - if intIdx < 0 { - t.Error("Expected to find key 1") - } - if newData[intIdx].(State[int]).Value != 150 { - t.Errorf("Expected int value 150, got %d", newData[intIdx].(State[int]).Value) - } - - // Check bool value - boolIdx := newData.GetIndex(2) - if boolIdx < 0 { - t.Error("Expected to find key 2") - } - if newData[boolIdx].(State[bool]).Value != true { - t.Error("Expected bool value true") - } -} - -func TestEffects_Apply_Error(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent, 1, 100) - - effects := Effects{ - Effect[float64]{Key: 1, Value: 50.0, Operator: SET}, // Type mismatch - } - - _, err := effects.apply(agent.w) - if err == nil { - t.Error("Expected error for type mismatch") - } -} diff --git a/astar_test.go b/astar_test.go index 0bfe4e1..0cc493a 100644 --- a/astar_test.go +++ b/astar_test.go @@ -75,49 +75,6 @@ func TestGetLessCostlyNodeKey_Empty(t *testing.T) { } } -// Test fetchNode -func TestFetchNode_Found(t *testing.T) { - agent1 := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent1, 1, 100) - agent1.w.hash = agent1.w.states.hashStates() - - agent2 := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent2, 1, 100) - agent2.w.hash = agent2.w.states.hashStates() - - nodes := []*node{ - {world: agent1.w}, - {world: agent2.w}, - } - - key, found := fetchNode(nodes, agent1.w) - if !found { - t.Error("Expected to find node") - } - if key != 0 { - t.Errorf("Expected key 0, got %d", key) - } -} - -func TestFetchNode_NotFound(t *testing.T) { - agent1 := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent1, 1, 100) - agent1.w.hash = agent1.w.states.hashStates() - - agent2 := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent2, 1, 200) - agent2.w.hash = agent2.w.states.hashStates() - - nodes := []*node{ - {world: agent1.w}, - } - - _, found := fetchNode(nodes, agent2.w) - if found { - t.Error("Expected not to find node") - } -} - // Test buildPlanFromNode func TestBuildPlanFromNode(t *testing.T) { action1 := &Action{name: "action1", cost: 1.0} diff --git a/state.go b/state.go index 5f35c75..e11bfe4 100644 --- a/state.go +++ b/state.go @@ -1,8 +1,6 @@ package goapai import ( - "encoding/binary" - "hash/fnv" "math" ) @@ -80,48 +78,54 @@ func (state State[T]) Store(w *world) { oldHash := state.hash state.hash = state.Hash() w.hash = updateHashIncremental(w.hash, oldHash, state.hash) - w.states[state.Key] = state + k := w.states.GetIndex(state.Key) + if k < 0 { + w.states = append(w.states, state) + } else { + w.states[k] = state + } } func (state State[T]) GetHash() uint64 { return state.hash } -// Hash returns a unique hash for this state using FNV-64a +// Hash returns a unique hash for this state using a fast multiplicative hash +// It implements a fast inline multiplicative hash +// Uses prime multipliers for good distribution without allocations func (state State[T]) Hash() uint64 { - h := fnv.New64a() + const ( + prime1 uint64 = 11400714819323198485 // Large prime for key + prime2 uint64 = 14029467366897019727 // Second prime for value + ) - // Write the key - buf := make([]byte, 2) - binary.LittleEndian.PutUint16(buf, uint16(state.Key)) - h.Write(buf) + // Start with key + hash := uint64(state.Key) * prime1 - // Write the value based on type - buf = make([]byte, 8) + // Mix in value based on type switch v := any(state.Value).(type) { case int8: - binary.LittleEndian.PutUint64(buf, uint64(v)) + hash ^= uint64(v) * prime2 case int: - binary.LittleEndian.PutUint64(buf, uint64(v)) + hash ^= uint64(v) * prime2 case uint8: - binary.LittleEndian.PutUint64(buf, uint64(v)) + hash ^= uint64(v) * prime2 case uint64: - binary.LittleEndian.PutUint64(buf, v) + hash ^= v * prime2 case float64: - binary.LittleEndian.PutUint64(buf, math.Float64bits(v)) + hash ^= math.Float64bits(v) * prime2 case bool: if v { - binary.LittleEndian.PutUint64(buf, 1) - } else { - binary.LittleEndian.PutUint64(buf, 0) + hash ^= prime2 } case string: - h.Write([]byte(v)) - return h.Sum64() + // For strings, hash each byte + for i := 0; i < len(v); i++ { + hash = hash*prime2 ^ uint64(v[i]) + } } - h.Write(buf) - return h.Sum64() + return hash } // Check compares world and states2 by their hash. diff --git a/state_test.go b/state_test.go index 219ffbe..c024f8f 100644 --- a/state_test.go +++ b/state_test.go @@ -69,159 +69,6 @@ func TestState_Operations(t *testing.T) { } } -func TestStates_Check(t *testing.T) { - tests := []struct { - name string - setup1 func(*Agent) - setup2 func(*Agent) - wantMatch bool - }{ - { - name: "matching world", - setup1: func(a *Agent) { - SetState[int](a, 1, 100) - }, - setup2: func(a *Agent) { - SetState[int](a, 1, 100) - }, - wantMatch: true, - }, - { - name: "different world", - setup1: func(a *Agent) { - SetState[int](a, 1, 100) - }, - setup2: func(a *Agent) { - SetState[int](a, 1, 100) - SetState[int](a, 2, 200) - }, - wantMatch: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - agent1 := CreateAgent(Goals{}, Actions{}) - tt.setup1(&agent1) - agent1.w.hash = agent1.w.states.hashStates() - - agent2 := CreateAgent(Goals{}, Actions{}) - tt.setup2(&agent2) - agent2.w.hash = agent2.w.states.hashStates() - - if got := agent1.w.Check(agent2.w); got != tt.wantMatch { - t.Errorf("Check() = %v, want %v", got, tt.wantMatch) - } - }) - } -} - -func TestStatesData_Operations(t *testing.T) { - tests := []struct { - name string - testFunc func(*testing.T) - }{ - { - name: "GetIndex found", - testFunc: func(t *testing.T) { - data := states{ - State[int]{Key: 1, Value: 100}, - State[int]{Key: 2, Value: 200}, - State[bool]{Key: 3, Value: true}, - } - - idx := data.GetIndex(2) - if idx != 1 { - t.Errorf("Expected index 1, got %d", idx) - } - }, - }, - { - name: "GetIndex not found", - testFunc: func(t *testing.T) { - data := states{ - State[int]{Key: 1, Value: 100}, - } - - idx := data.GetIndex(99) - if idx != -1 { - t.Errorf("Expected index -1 for missing key, got %d", idx) - } - }, - }, - { - name: "hashStates same w", - testFunc: func(t *testing.T) { - data1 := states{ - State[int]{Key: 1, Value: 100}, - State[int]{Key: 2, Value: 200}, - } - - data2 := states{ - State[int]{Key: 1, Value: 100}, - State[int]{Key: 2, Value: 200}, - } - - hash1 := data1.hashStates() - hash2 := data2.hashStates() - - if hash1 != hash2 { - t.Error("Expected identical w to produce same hash") - } - }, - }, - { - name: "hashStates same w, different keys", - testFunc: func(t *testing.T) { - data1 := states{ - State[int]{Key: 1, Value: 100}, - State[int]{Key: 2, Value: 200}, - State[int]{Key: 3, Value: 300}, - } - data2 := states{ - State[int]{Key: 3, Value: 300}, - State[int]{Key: 1, Value: 100}, - State[int]{Key: 2, Value: 200}, - } - - hash1 := data1.hashStates() - hash2 := data2.hashStates() - - if hash1 != hash2 { - t.Error("Expected identical w to produce same hash") - } - }, - }, - { - name: "hashStates different w", - testFunc: func(t *testing.T) { - data1 := states{ - State[int]{Key: 1, Value: 100}, - State[int]{Key: 2, Value: 200}, - } - - data2 := states{ - State[int]{Key: 1, Value: 100}, - State[int]{Key: 2, Value: 999}, - } - - hash1 := data1.hashStates() - hash2 := data2.hashStates() - - if hash1 == hash2 { - t.Error("Expected different w to produce different hash") - } - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tt.testFunc(t) - }) - } -} - func TestCondition_Operators(t *testing.T) { tests := []struct { name string @@ -527,124 +374,3 @@ func TestState_Hash(t *testing.T) { }) } } - -func TestUpdateHashIncremental(t *testing.T) { - tests := []struct { - name string - currentHash uint64 - oldState StateInterface - newState StateInterface - verify func(*testing.T, uint64) - }{ - { - name: "add new state", - currentHash: 0, - oldState: nil, - newState: State[int]{Key: 1, Value: 100}, - verify: func(t *testing.T, result uint64) { - expected := State[int]{Key: 1, Value: 100}.Hash() - if result != expected { - t.Errorf("Expected hash %d, got %d", expected, result) - } - }, - }, - { - name: "remove state", - currentHash: State[int]{Key: 1, Value: 100}.Hash(), - oldState: State[int]{Key: 1, Value: 100}, - newState: nil, - verify: func(t *testing.T, result uint64) { - if result != 0 { - t.Errorf("Expected hash 0, got %d", result) - } - }, - }, - { - name: "replace state", - currentHash: State[int]{Key: 1, Value: 100}.Hash(), - oldState: State[int]{Key: 1, Value: 100}, - newState: State[int]{Key: 1, Value: 200}, - verify: func(t *testing.T, result uint64) { - expected := State[int]{Key: 1, Value: 200}.Hash() - if result != expected { - t.Errorf("Expected hash %d, got %d", expected, result) - } - }, - }, - { - name: "XOR property - multiple world", - currentHash: func() uint64 { - var h uint64 - h ^= State[int]{Key: 1, Value: 100}.Hash() - h ^= State[int]{Key: 2, Value: 200}.Hash() - return h - }(), - oldState: State[int]{Key: 1, Value: 100}, - newState: State[int]{Key: 1, Value: 150}, - verify: func(t *testing.T, result uint64) { - expected := State[int]{Key: 1, Value: 150}.Hash() ^ State[int]{Key: 2, Value: 200}.Hash() - if result != expected { - t.Errorf("Expected hash %d, got %d", expected, result) - } - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := updateHashIncremental(tt.currentHash, tt.oldState, tt.newState) - tt.verify(t, result) - }) - } -} - -func TestStatesData_HashStates_XORProperty(t *testing.T) { - // Test that hash order doesn't matter (XOR is commutative) - data1 := states{ - State[int]{Key: 1, Value: 100}, - State[int]{Key: 2, Value: 200}, - State[bool]{Key: 3, Value: true}, - } - - data2 := states{ - State[bool]{Key: 3, Value: true}, - State[int]{Key: 1, Value: 100}, - State[int]{Key: 2, Value: 200}, - } - - hash1 := data1.hashStates() - hash2 := data2.hashStates() - - if hash1 != hash2 { - t.Error("Expected XOR hash to be order-independent") - } -} - -func TestIncrementalHashConsistency(t *testing.T) { - // Verify that incremental hash matches full recalculation - agent := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent, 1, 100) - SetState[int](&agent, 2, 200) - SetState[bool](&agent, 3, true) - agent.w.hash = agent.w.states.hashStates() - - initialHash := agent.w.hash - - // Modify a state and calculate hash incrementally - oldState := agent.w.states[0] - newState := State[int]{Key: 1, Value: 150} - - incrementalHash := updateHashIncremental(agent.w.hash, oldState, newState) - - // Modify the same state and calculate hash from scratch - agent.w.states[0] = newState - fullHash := agent.w.states.hashStates() - - if incrementalHash != fullHash { - t.Errorf("Incremental hash (%d) doesn't match full recalculation (%d)", incrementalHash, fullHash) - } - - if incrementalHash == initialHash { - t.Error("Hash should change after state modification") - } -} From 60c2dc35d7d37723a93a66ccdbf9b30b8d01811b Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 11:29:22 +0100 Subject: [PATCH 08/18] feat: use heap for open & closed nodes --- astar.go | 143 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 76 insertions(+), 67 deletions(-) diff --git a/astar.go b/astar.go index 87d7b8a..a71af97 100644 --- a/astar.go +++ b/astar.go @@ -1,8 +1,8 @@ package goapai import ( + "container/heap" "slices" - "sync" ) type node struct { @@ -14,25 +14,48 @@ type node struct { totalCost float32 heuristic float32 depth uint16 + heapIndex int // Index in the heap, needed for heap.Fix } -var nodesPool = sync.Pool{ - New: func() any { - return make([]*node, 0, 32) - }, +// nodeHeap implements heap.Interface for a min-heap of nodes based on totalCost +type nodeHeap []*node + +func (h nodeHeap) Len() int { return len(h) } + +func (h nodeHeap) Less(i, j int) bool { + return h[i].totalCost < h[j].totalCost +} + +func (h nodeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].heapIndex = i + h[j].heapIndex = j +} + +func (h *nodeHeap) Push(x interface{}) { + n := x.(*node) + n.heapIndex = len(*h) + *h = append(*h, n) +} + +func (h *nodeHeap) Pop() interface{} { + old := *h + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.heapIndex = -1 // mark as removed + *h = old[0 : n-1] + return item } func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { availableActions := getImpactingActions(from, actions) - openNodes := nodesPool.Get().([]*node) - closedNodes := nodesPool.Get().([]*node) - - defer func() { - nodesPool.Put(openNodes[:0]) - nodesPool.Put(closedNodes[:0]) - }() + openNodes := make(map[uint64]*node) + openNodesHeap := &nodeHeap{} + heap.Init(openNodesHeap) + closedNodes := make(map[uint64]*node) - openNodes = append(openNodes, &node{ + startNode := &node{ Action: &Action{}, world: world{ Agent: from.Agent, @@ -44,13 +67,17 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { totalCost: 0, heuristic: 0, depth: 0, - }) + heapIndex: -1, + } + openNodes[startNode.world.hash] = startNode + heap.Push(openNodesHeap, startNode) + + for openNodesHeap.Len() > 0 { + parentNode := heap.Pop(openNodesHeap).(*node) + delete(openNodes, parentNode.world.hash) - for openNodeKey := 0; openNodeKey != -1; openNodeKey = getLessCostlyNodeKey(openNodes) { - parentNode := openNodes[openNodeKey] if parentNode.depth > uint16(maxDepth) { - openNodes = append(openNodes[:openNodeKey], openNodes[openNodeKey+1:]...) - closedNodes = append(closedNodes, parentNode) + closedNodes[parentNode.world.hash] = parentNode continue } @@ -73,34 +100,36 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { continue } - if nodeKey, found := fetchNode(openNodes, simulatedStates); found { - node := openNodes[nodeKey] - if (parentNode.cost + action.cost) < node.cost { - node.Action = action - node.world = simulatedStates - node.parentNode = parentNode - node.cost = parentNode.cost + action.cost - node.totalCost = parentNode.cost + action.cost + node.heuristic - node.depth = parentNode.depth + 1 - - openNodes[nodeKey] = node + if currentNode, found := openNodes[simulatedStates.hash]; found { + if (parentNode.cost + action.cost) < currentNode.cost { + currentNode.Action = action + currentNode.world = simulatedStates + currentNode.parentNode = parentNode + currentNode.cost = parentNode.cost + action.cost + currentNode.totalCost = parentNode.cost + action.cost + currentNode.heuristic + currentNode.depth = parentNode.depth + 1 + + // Fix heap position after cost update + heap.Fix(openNodesHeap, currentNode.heapIndex) } - } else if nodeKey, found := fetchNode(closedNodes, simulatedStates); found { - node := closedNodes[nodeKey] - if (parentNode.cost + action.cost) < node.cost { - node.Action = action - node.world = simulatedStates - node.parentNode = parentNode - node.cost = parentNode.cost + action.cost - node.totalCost = parentNode.cost + action.cost + node.heuristic - node.depth = parentNode.depth + 1 - - openNodes[openNodeKey] = node - closedNodes = append(closedNodes[:nodeKey], closedNodes[nodeKey+1:]...) + } else if currentNode, found := closedNodes[simulatedStates.hash]; found { + if (parentNode.cost + action.cost) < currentNode.cost { + currentNode.Action = action + currentNode.world = simulatedStates + currentNode.parentNode = parentNode + currentNode.cost = parentNode.cost + action.cost + currentNode.totalCost = parentNode.cost + action.cost + currentNode.heuristic + currentNode.depth = parentNode.depth + 1 + + openNodes[simulatedStates.hash] = currentNode + delete(closedNodes, simulatedStates.hash) + + // Re-add to heap + heap.Push(openNodesHeap, currentNode) } } else { heuristic := computeHeuristic(from, goal, simulatedStates) - openNodes = append(openNodes, &node{ + newNode := &node{ Action: action, world: simulatedStates, parentNode: parentNode, @@ -108,12 +137,14 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { totalCost: parentNode.cost + action.cost + heuristic, heuristic: heuristic, depth: parentNode.depth + 1, - }) + heapIndex: -1, + } + openNodes[simulatedStates.hash] = newNode + heap.Push(openNodesHeap, newNode) } } - openNodes = append(openNodes[:openNodeKey], openNodes[openNodeKey+1:]...) - closedNodes = append(closedNodes, parentNode) + closedNodes[parentNode.world.hash] = parentNode } return Plan{} @@ -133,28 +164,6 @@ func getImpactingActions(from world, actions Actions) Actions { return availableActions } -func getLessCostlyNodeKey(openNodes []*node) int { - lowestKey := -1 - - for key, node := range openNodes { - if lowestKey < 0 || node.totalCost < openNodes[lowestKey].totalCost { - lowestKey = key - } - } - - return lowestKey -} - -func fetchNode(nodes []*node, w world) (int, bool) { - for k, n := range nodes { - if n.world.Check(w) { - return k, true - } - } - - return 0, false -} - func buildPlanFromNode(node *node) Plan { plan := make(Plan, 0, node.depth) From 32ed3aa41b0efcb670d5295418c160d904e37c32 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 11:52:31 +0100 Subject: [PATCH 09/18] feat: use heap for open & closed nodes --- astar.go | 63 ++++++++++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/astar.go b/astar.go index a71af97..51da02a 100644 --- a/astar.go +++ b/astar.go @@ -14,7 +14,8 @@ type node struct { totalCost float32 heuristic float32 depth uint16 - heapIndex int // Index in the heap, needed for heap.Fix + heapIndex int // Index in the heap, needed for heap.Fix + closed bool // true = closed node, false = open node } // nodeHeap implements heap.Interface for a min-heap of nodes based on totalCost @@ -50,10 +51,8 @@ func (h *nodeHeap) Pop() interface{} { func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { availableActions := getImpactingActions(from, actions) - openNodes := make(map[uint64]*node) - openNodesHeap := &nodeHeap{} - heap.Init(openNodesHeap) - closedNodes := make(map[uint64]*node) + nodesHeap := &nodeHeap{} + heap.Init(nodesHeap) startNode := &node{ Action: &Action{}, @@ -63,21 +62,23 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { hash: from.hash, }, parentNode: nil, - cost: 0, - totalCost: 0, - heuristic: 0, - depth: 0, heapIndex: -1, + closed: false, } - openNodes[startNode.world.hash] = startNode - heap.Push(openNodesHeap, startNode) + heap.Push(nodesHeap, startNode) - for openNodesHeap.Len() > 0 { - parentNode := heap.Pop(openNodesHeap).(*node) - delete(openNodes, parentNode.world.hash) + for nodesHeap.Len() > 0 { + parentNode := heap.Pop(nodesHeap).(*node) + + // Skip if already closed (lazy deletion) + if parentNode.closed { + continue + } + + // Mark as closed + parentNode.closed = true if parentNode.depth > uint16(maxDepth) { - closedNodes[parentNode.world.hash] = parentNode continue } @@ -100,7 +101,8 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { continue } - if currentNode, found := openNodes[simulatedStates.hash]; found { + // Check if node exists in open nodes (closed=false) + if currentNode, found := fetchNodeInHeap(nodesHeap, simulatedStates, false); found { if (parentNode.cost + action.cost) < currentNode.cost { currentNode.Action = action currentNode.world = simulatedStates @@ -110,9 +112,10 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { currentNode.depth = parentNode.depth + 1 // Fix heap position after cost update - heap.Fix(openNodesHeap, currentNode.heapIndex) + heap.Fix(nodesHeap, currentNode.heapIndex) } - } else if currentNode, found := closedNodes[simulatedStates.hash]; found { + } else if currentNode, found := fetchNodeInHeap(nodesHeap, simulatedStates, true); found { + // Node was closed, reopen it with better cost if (parentNode.cost + action.cost) < currentNode.cost { currentNode.Action = action currentNode.world = simulatedStates @@ -120,14 +123,13 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { currentNode.cost = parentNode.cost + action.cost currentNode.totalCost = parentNode.cost + action.cost + currentNode.heuristic currentNode.depth = parentNode.depth + 1 + currentNode.closed = false // Reopen - openNodes[simulatedStates.hash] = currentNode - delete(closedNodes, simulatedStates.hash) - - // Re-add to heap - heap.Push(openNodesHeap, currentNode) + // Fix heap position + heap.Fix(nodesHeap, currentNode.heapIndex) } } else { + // New node heuristic := computeHeuristic(from, goal, simulatedStates) newNode := &node{ Action: action, @@ -138,13 +140,11 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { heuristic: heuristic, depth: parentNode.depth + 1, heapIndex: -1, + closed: false, } - openNodes[simulatedStates.hash] = newNode - heap.Push(openNodesHeap, newNode) + heap.Push(nodesHeap, newNode) } } - - closedNodes[parentNode.world.hash] = parentNode } return Plan{} @@ -164,6 +164,15 @@ func getImpactingActions(from world, actions Actions) Actions { return availableActions } +func fetchNodeInHeap(heap *nodeHeap, w world, closed bool) (*node, bool) { + for _, n := range *heap { + if n.closed == closed && n.world.Check(w) { + return n, true + } + } + return nil, false +} + func buildPlanFromNode(node *node) Plan { plan := make(Plan, 0, node.depth) From 2d6f7a33d126a14daaaa1cb22186a79bf6f24ac0 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 15:41:23 +0100 Subject: [PATCH 10/18] feat: heap file & new heuristic --- astar.go | 97 +++++++++++++++++++++------------------------------- heap.go | 32 ++++++++++++++++++ state.go | 101 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 58 deletions(-) create mode 100644 heap.go diff --git a/astar.go b/astar.go index 51da02a..310df65 100644 --- a/astar.go +++ b/astar.go @@ -18,41 +18,8 @@ type node struct { closed bool // true = closed node, false = open node } -// nodeHeap implements heap.Interface for a min-heap of nodes based on totalCost -type nodeHeap []*node - -func (h nodeHeap) Len() int { return len(h) } - -func (h nodeHeap) Less(i, j int) bool { - return h[i].totalCost < h[j].totalCost -} - -func (h nodeHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] - h[i].heapIndex = i - h[j].heapIndex = j -} - -func (h *nodeHeap) Push(x interface{}) { - n := x.(*node) - n.heapIndex = len(*h) - *h = append(*h, n) -} - -func (h *nodeHeap) Pop() interface{} { - old := *h - n := len(old) - item := old[n-1] - old[n-1] = nil // avoid memory leak - item.heapIndex = -1 // mark as removed - *h = old[0 : n-1] - return item -} - func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { availableActions := getImpactingActions(from, actions) - nodesHeap := &nodeHeap{} - heap.Init(nodesHeap) startNode := &node{ Action: &Action{}, @@ -65,20 +32,17 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { heapIndex: -1, closed: false, } - heap.Push(nodesHeap, startNode) - for nodesHeap.Len() > 0 { - parentNode := heap.Pop(nodesHeap).(*node) + nodesHeap := nodeHeap{} + heap.Init(&nodesHeap) + heap.Push(&nodesHeap, startNode) - // Skip if already closed (lazy deletion) - if parentNode.closed { - continue - } - - // Mark as closed - parentNode.closed = true + for nodesHeap.Len() > 0 { + parentNode := heap.Pop(&nodesHeap).(*node) if parentNode.depth > uint16(maxDepth) { + parentNode.closed = true + heap.Fix(&nodesHeap, parentNode.heapIndex) continue } @@ -101,8 +65,9 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { continue } + currentNode, found := fetchNodeInHeap(nodesHeap, simulatedStates) // Check if node exists in open nodes (closed=false) - if currentNode, found := fetchNodeInHeap(nodesHeap, simulatedStates, false); found { + if found && !currentNode.closed { if (parentNode.cost + action.cost) < currentNode.cost { currentNode.Action = action currentNode.world = simulatedStates @@ -112,9 +77,9 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { currentNode.depth = parentNode.depth + 1 // Fix heap position after cost update - heap.Fix(nodesHeap, currentNode.heapIndex) + heap.Fix(&nodesHeap, currentNode.heapIndex) } - } else if currentNode, found := fetchNodeInHeap(nodesHeap, simulatedStates, true); found { + } else if found && currentNode.closed { // Node was closed, reopen it with better cost if (parentNode.cost + action.cost) < currentNode.cost { currentNode.Action = action @@ -126,7 +91,7 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { currentNode.closed = false // Reopen // Fix heap position - heap.Fix(nodesHeap, currentNode.heapIndex) + heap.Fix(&nodesHeap, currentNode.heapIndex) } } else { // New node @@ -142,7 +107,7 @@ func astar(from world, goal goalInterface, actions Actions, maxDepth int) Plan { heapIndex: -1, closed: false, } - heap.Push(nodesHeap, newNode) + heap.Push(&nodesHeap, newNode) } } } @@ -164,9 +129,9 @@ func getImpactingActions(from world, actions Actions) Actions { return availableActions } -func fetchNodeInHeap(heap *nodeHeap, w world, closed bool) (*node, bool) { - for _, n := range *heap { - if n.closed == closed && n.world.Check(w) { +func fetchNodeInHeap(heap nodeHeap, w world) (*node, bool) { + for _, n := range heap { + if n.world.Check(w) { return n, true } } @@ -231,15 +196,31 @@ func countMissingGoal(goal goalInterface, w world) int { } /* -A very simple (empiristic) model for h using: - - how much required world are met - -We try to be conservative and reduce the number of steps +Improved heuristic using numeric distance calculation: + - For each goal condition, calculate the numeric distance between current value and target + - Sum all distances to get total heuristic + - This provides much better guidance than simple binary satisfied/unsatisfied check */ func computeHeuristic(from world, goal goalInterface, w world) float32 { - missingGoalsCount := float32(countMissingGoal(goal, w)) + var totalDistance float32 - h := missingGoalsCount + for _, condition := range goal.Conditions { + key := condition.GetKey() + stateIndex := w.states.GetIndex(key) + + if stateIndex >= 0 { + // State exists, calculate actual distance + state := w.states[stateIndex] + distance := state.Distance(condition) + totalDistance += distance + } else { + // State doesn't exist, use pessimistic estimate + // If the condition is not satisfied and state doesn't exist, assume distance of 1 + if !condition.Check(w) { + totalDistance += 1.0 + } + } + } - return h + return totalDistance } diff --git a/heap.go b/heap.go new file mode 100644 index 0000000..012805c --- /dev/null +++ b/heap.go @@ -0,0 +1,32 @@ +package goapai + +// nodeHeap implements heap.Interface for a min-heap of nodes based on totalCost +type nodeHeap []*node + +func (h nodeHeap) Len() int { return len(h) } + +func (h nodeHeap) Less(i, j int) bool { + return h[i].totalCost < h[j].totalCost +} + +func (h nodeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].heapIndex = i + h[j].heapIndex = j +} + +func (h *nodeHeap) Push(x interface{}) { + n := x.(*node) + n.heapIndex = len(*h) + *h = append(*h, n) +} + +func (h *nodeHeap) Pop() interface{} { + old := *h + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.heapIndex = -1 // mark as removed + *h = old[0 : n-1] + return item +} diff --git a/state.go b/state.go index e11bfe4..a462028 100644 --- a/state.go +++ b/state.go @@ -28,6 +28,7 @@ type StateInterface interface { Store(w *world) GetHash() uint64 Hash() uint64 + Distance(condition ConditionInterface) float32 } type State[T Numeric | bool | string] struct { Key StateKey @@ -128,6 +129,106 @@ func (state State[T]) Hash() uint64 { return hash } +// Distance calculates the distance between the current state value and a condition target value +// Returns 0 if the condition is already satisfied, otherwise returns the numeric distance +func (state State[T]) Distance(condition ConditionInterface) float32 { + // Check if the condition key matches + if state.Key != condition.GetKey() { + return 0 + } + + // Handle different condition types + switch cond := condition.(type) { + case *Condition[int8]: + if v, ok := any(state.Value).(int8); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[int]: + if v, ok := any(state.Value).(int); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[uint8]: + if v, ok := any(state.Value).(uint8); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[uint64]: + if v, ok := any(state.Value).(uint64); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[float64]: + if v, ok := any(state.Value).(float64); ok { + return calculateNumericDistance(v, cond.Value, cond.Operator) + } + case *ConditionBool: + if v, ok := any(state.Value).(bool); ok { + if cond.Operator == EQUAL { + if v == cond.Value { + return 0 + } + return 1 + } else if cond.Operator == NOT_EQUAL { + if v != cond.Value { + return 0 + } + return 1 + } + } + case *ConditionString: + if v, ok := any(state.Value).(string); ok { + if cond.Operator == EQUAL { + if v == cond.Value { + return 0 + } + return 1 + } else if cond.Operator == NOT_EQUAL { + if v != cond.Value { + return 0 + } + return 1 + } + } + } + + return 0 +} + +// calculateNumericDistance computes the distance for numeric conditions based on operator +func calculateNumericDistance(current, target float64, op operator) float32 { + switch op { + case EQUAL: + if current < target { + return float32(target - current) + } + return float32(current - target) + case NOT_EQUAL: + if current == target { + return 1.0 + } + return 0.0 + case UPPER_OR_EQUAL: + if current < target { + return float32(target - current) + } + return 0.0 + case UPPER: + if current <= target { + return float32(target - current + 1) + } + return 0.0 + case LOWER_OR_EQUAL: + if current > target { + return float32(current - target) + } + return 0.0 + case LOWER: + if current >= target { + return float32(current - target + 1) + } + return 0.0 + } + return 0.0 +} + // Check compares world and states2 by their hash. func (world world) Check(world2 world) bool { return world.hash == world2.hash From c4596a1523f703ca2942db584b934cc136837c9c Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 15:44:39 +0100 Subject: [PATCH 11/18] chore: remove todo in readme --- README.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/README.md b/README.md index 54f979c..c5b4ad8 100644 --- a/README.md +++ b/README.md @@ -174,10 +174,7 @@ GOAP needs to be benchmarked and monitored regularly because of exponential risk Though if well scoped you can manage hundred of Actions for 200µs per Agent. ## What's next? -- The simulateActionState functions, to check the effect of an action on a node, takes up to 40% of CPU and 40% of memory. -We need to refactorize this part, or find another logical path. -- Heuristic calculation in A* is done poorly, we need a better algorithm to improve performances. -- Benchmark a backward implementation like D*, to improve performances. +- Benchmark a backward implementation like D*. It might improve performances. ## Sources - https://web.archive.org/web/20230912145018/http://alumni.media.mit.edu/~jorkin/goap.html From 260b056f3044beb6a5b111288e20491c298ca8cb Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 15:53:40 +0100 Subject: [PATCH 12/18] feat: clean unit tests --- action_test.go | 202 +++++++++++++++++++++++++++++++++---------------- agent_test.go | 63 +++++++++------ astar_test.go | 23 ------ 3 files changed, 179 insertions(+), 109 deletions(-) diff --git a/action_test.go b/action_test.go index c02a970..9f8485f 100644 --- a/action_test.go +++ b/action_test.go @@ -2,92 +2,166 @@ package goapai import "testing" -// Test Actions.AddAction func TestActions_AddAction(t *testing.T) { - actions := Actions{} - - actions.AddAction("test", 1.5, false, Conditions{}, Effects{}) - - if len(actions) != 1 { - t.Errorf("Expected 1 action, got %d", len(actions)) + tests := []struct { + name string + actionsToAdd []struct{ name string; cost float32; repeatable bool } + wantCount int + checkFirst bool + wantFirstName string + wantFirstCost float32 + wantRepeatble bool + }{ + { + name: "single action", + actionsToAdd: []struct{ name string; cost float32; repeatable bool }{ + {name: "test", cost: 1.5, repeatable: false}, + }, + wantCount: 1, + checkFirst: true, + wantFirstName: "test", + wantFirstCost: 1.5, + wantRepeatble: false, + }, + { + name: "multiple actions", + actionsToAdd: []struct{ name string; cost float32; repeatable bool }{ + {name: "action1", cost: 1.0, repeatable: true}, + {name: "action2", cost: 2.0, repeatable: false}, + }, + wantCount: 2, + checkFirst: false, + }, } - action := actions[0] - if action.name != "test" { - t.Errorf("Expected name 'test', got '%s'", action.name) - } - if action.cost != 1.5 { - t.Errorf("Expected cost 1.5, got %f", action.cost) - } - if action.repeatable { - t.Error("Expected repeatable to be false") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actions := Actions{} + + for _, a := range tt.actionsToAdd { + actions.AddAction(a.name, a.cost, a.repeatable, Conditions{}, Effects{}) + } + + if len(actions) != tt.wantCount { + t.Errorf("Expected %d actions, got %d", tt.wantCount, len(actions)) + } + + if tt.checkFirst { + action := actions[0] + if action.name != tt.wantFirstName { + t.Errorf("Expected name '%s', got '%s'", tt.wantFirstName, action.name) + } + if action.cost != tt.wantFirstCost { + t.Errorf("Expected cost %f, got %f", tt.wantFirstCost, action.cost) + } + if action.repeatable != tt.wantRepeatble { + t.Errorf("Expected repeatable to be %v", tt.wantRepeatble) + } + } + }) } } -func TestActions_AddAction_Multiple(t *testing.T) { - actions := Actions{} - - actions.AddAction("action1", 1.0, true, Conditions{}, Effects{}) - actions.AddAction("action2", 2.0, false, Conditions{}, Effects{}) - - if len(actions) != 2 { - t.Errorf("Expected 2 actions, got %d", len(actions)) +func TestAction_GetName(t *testing.T) { + tests := []struct { + name string + actionName string + want string + }{ + { + name: "basic name", + actionName: "my_action", + want: "my_action", + }, } -} -// Test Action.GetName -func TestAction_GetName(t *testing.T) { - actions := Actions{} - actions.AddAction("my_action", 1.0, false, Conditions{}, Effects{}) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actions := Actions{} + actions.AddAction(tt.actionName, 1.0, false, Conditions{}, Effects{}) - name := actions[0].GetName() - if name != "my_action" { - t.Errorf("Expected 'my_action', got '%s'", name) + got := actions[0].GetName() + if got != tt.want { + t.Errorf("Expected '%s', got '%s'", tt.want, got) + } + }) } } -// Test Action.GetEffects func TestAction_GetEffects(t *testing.T) { - effects := Effects{ - Effect[int]{Key: 1, Value: 10, Operator: SET}, + tests := []struct { + name string + effects Effects + wantCount int + }{ + { + name: "single effect", + effects: Effects{ + Effect[int]{Key: 1, Value: 10, Operator: SET}, + }, + wantCount: 1, + }, } - actions := Actions{} - actions.AddAction("test", 1.0, false, Conditions{}, effects) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actions := Actions{} + actions.AddAction("test", 1.0, false, Conditions{}, tt.effects) - retrieved := actions[0].GetEffects() - if len(retrieved) != 1 { - t.Errorf("Expected 1 effect, got %d", len(retrieved)) + retrieved := actions[0].GetEffects() + if len(retrieved) != tt.wantCount { + t.Errorf("Expected %d effect(s), got %d", tt.wantCount, len(retrieved)) + } + }) } } -// Test Effect[T Numeric] operations -func TestEffect_Check_Match(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent, 1, 100) - - effect := Effect[int]{Key: 1, Value: 100, Operator: SET} - if !effect.check(agent.w) { - t.Error("Expected effect to match state") +func TestEffect_Check(t *testing.T) { + tests := []struct { + name string + stateValue int + effectKey StateKey + effectVal int + operator arithmetic + want bool + }{ + { + name: "match with SET operator", + stateValue: 100, + effectKey: 1, + effectVal: 100, + operator: SET, + want: true, + }, + { + name: "no match with SET operator", + stateValue: 100, + effectKey: 1, + effectVal: 200, + operator: SET, + want: false, + }, + { + name: "non-SET operator always false", + stateValue: 100, + effectKey: 1, + effectVal: 100, + operator: ADD, + want: false, + }, } -} -func TestEffect_Check_NoMatch(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent, 1, 100) - - effect := Effect[int]{Key: 1, Value: 200, Operator: SET} - if effect.check(agent.w) { - t.Error("Expected effect not to match state") - } -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[int](&agent, tt.effectKey, tt.stateValue) -func TestEffect_Check_NonSetOperator(t *testing.T) { - agent := CreateAgent(Goals{}, Actions{}) - SetState[int](&agent, 1, 100) + effect := Effect[int]{Key: tt.effectKey, Value: tt.effectVal, Operator: tt.operator} + got := effect.check(agent.w) - effect := Effect[int]{Key: 1, Value: 100, Operator: ADD} - if effect.check(agent.w) { - t.Error("Expected non-SET operator to always return false") + if got != tt.want { + t.Errorf("Expected %v, got %v", tt.want, got) + } + }) } } diff --git a/agent_test.go b/agent_test.go index 808bc2f..30810de 100644 --- a/agent_test.go +++ b/agent_test.go @@ -3,36 +3,55 @@ package goapai import "testing" func TestCreateAgent(t *testing.T) { - goals := Goals{ - "test_goal": { - Conditions: Conditions{ - &ConditionBool{Key: 1, Value: true}, - }, - PriorityFn: func(sensors Sensors) float32 { - return 1.0 + tests := []struct { + name string + goals Goals + actions Actions + wantActionCnt int + wantGoalCnt int + }{ + { + name: "agent with goal and action", + goals: Goals{ + "test_goal": { + Conditions: Conditions{ + &ConditionBool{Key: 1, Value: true}, + }, + PriorityFn: func(sensors Sensors) float32 { + return 1.0 + }, + }, }, + actions: func() Actions { + actions := Actions{} + actions.AddAction("test_action", 1.0, false, Conditions{}, Effects{}) + return actions + }(), + wantActionCnt: 1, + wantGoalCnt: 1, }, } - actions := Actions{} - actions.AddAction("test_action", 1.0, false, Conditions{}, Effects{}) - - agent := CreateAgent(goals, actions) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(tt.goals, tt.actions) - if len(agent.actions) != 1 { - t.Errorf("Expected 1 action, got %d", len(agent.actions)) - } + if len(agent.actions) != tt.wantActionCnt { + t.Errorf("Expected %d action(s), got %d", tt.wantActionCnt, len(agent.actions)) + } - if len(agent.goals) != 1 { - t.Errorf("Expected 1 goal, got %d", len(agent.goals)) - } + if len(agent.goals) != tt.wantGoalCnt { + t.Errorf("Expected %d goal(s), got %d", tt.wantGoalCnt, len(agent.goals)) + } - if agent.sensors == nil { - t.Error("Expected sensors to be initialized") - } + if agent.sensors == nil { + t.Error("Expected sensors to be initialized") + } - if agent.w.Agent == nil { - t.Error("Expected world.Agent to be non-nil") + if agent.w.Agent == nil { + t.Error("Expected world.Agent to be non-nil") + } + }) } } diff --git a/astar_test.go b/astar_test.go index 0cc493a..a42888f 100644 --- a/astar_test.go +++ b/astar_test.go @@ -52,29 +52,6 @@ func TestGetImpactingActions_AllImpacting(t *testing.T) { } } -// Test getLessCostlyNodeKey -func TestGetLessCostlyNodeKey(t *testing.T) { - nodes := []*node{ - {totalCost: 10.0}, - {totalCost: 5.0}, - {totalCost: 15.0}, - } - - key := getLessCostlyNodeKey(nodes) - if key != 1 { - t.Errorf("Expected key 1 (lowest cost), got %d", key) - } -} - -func TestGetLessCostlyNodeKey_Empty(t *testing.T) { - nodes := []*node{} - - key := getLessCostlyNodeKey(nodes) - if key != -1 { - t.Errorf("Expected -1 for empty list, got %d", key) - } -} - // Test buildPlanFromNode func TestBuildPlanFromNode(t *testing.T) { action1 := &Action{name: "action1", cost: 1.0} From dabd6633bb3424afb6f0a0758f54163f51842398 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 16:04:22 +0100 Subject: [PATCH 13/18] fix: set key in effect.apply --- action.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/action.go b/action.go index 38fd5eb..89406a2 100644 --- a/action.go +++ b/action.go @@ -83,10 +83,10 @@ func (effect Effect[T]) apply(w *world) error { k := w.states.GetIndex(effect.Key) if k < 0 { if slices.Contains([]arithmetic{SET, ADD}, effect.Operator) { - w.states = append(w.states, State[T]{Value: effect.Value}) + w.states = append(w.states, State[T]{Key: effect.Key, Value: effect.Value}) return nil } else if slices.Contains([]arithmetic{SUBSTRACT}, effect.Operator) { - w.states = append(w.states, State[T]{Value: -effect.Value}) + w.states = append(w.states, State[T]{Key: effect.Key, Value: -effect.Value}) return nil } return fmt.Errorf("w does not exist") @@ -150,7 +150,7 @@ func (effectBool EffectBool) apply(w *world) error { k := w.states.GetIndex(effectBool.Key) if k < 0 { - w.states = append(w.states, State[bool]{Value: effectBool.Value}) + w.states = append(w.states, State[bool]{Key: effectBool.Key, Value: effectBool.Value}) return nil } if _, ok := w.states[k].(State[bool]); !ok { @@ -196,7 +196,7 @@ func (effectString EffectString) apply(w *world) error { k := w.states.GetIndex(effectString.Key) if k < 0 { - w.states = append(w.states, State[string]{Value: effectString.Value}) + w.states = append(w.states, State[string]{Key: effectString.Key, Value: effectString.Value}) return nil } if _, ok := w.states[k].(State[string]); !ok { From e6fdbba4cfedb8e152146968f026b74506d8e8a9 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 16:09:09 +0100 Subject: [PATCH 14/18] chore: add unit tests --- distance_test.go | 216 ++++++++++++++++++++++++++++++ effect_test.go | 339 +++++++++++++++++++++++++++++++++++++++++++++++ heap_test.go | 221 ++++++++++++++++++++++++++++++ 3 files changed, 776 insertions(+) create mode 100644 distance_test.go create mode 100644 effect_test.go create mode 100644 heap_test.go diff --git a/distance_test.go b/distance_test.go new file mode 100644 index 0000000..fbcbafb --- /dev/null +++ b/distance_test.go @@ -0,0 +1,216 @@ +package goapai + +import "testing" + +func TestState_Distance(t *testing.T) { + tests := []struct { + name string + stateVal interface{} + condition ConditionInterface + want float32 + }{ + // Numeric types - EQUAL operator + { + name: "int EQUAL satisfied", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + want: 0.0, + }, + { + name: "int EQUAL below target", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + want: 50.0, + }, + { + name: "int EQUAL above target", + stateVal: 150, + condition: &Condition[int]{Key: 1, Value: 100, Operator: EQUAL}, + want: 50.0, + }, + // UPPER_OR_EQUAL operator + { + name: "int UPPER_OR_EQUAL satisfied", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 80, Operator: UPPER_OR_EQUAL}, + want: 0.0, + }, + { + name: "int UPPER_OR_EQUAL below target", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 80, Operator: UPPER_OR_EQUAL}, + want: 30.0, + }, + // UPPER operator + { + name: "int UPPER satisfied", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 80, Operator: UPPER}, + want: 0.0, + }, + { + name: "int UPPER at target", + stateVal: 80, + condition: &Condition[int]{Key: 1, Value: 80, Operator: UPPER}, + want: 1.0, + }, + // LOWER_OR_EQUAL operator + { + name: "int LOWER_OR_EQUAL satisfied", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 80, Operator: LOWER_OR_EQUAL}, + want: 0.0, + }, + { + name: "int LOWER_OR_EQUAL above target", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 80, Operator: LOWER_OR_EQUAL}, + want: 20.0, + }, + // LOWER operator + { + name: "int LOWER satisfied", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 80, Operator: LOWER}, + want: 0.0, + }, + { + name: "int LOWER at target", + stateVal: 80, + condition: &Condition[int]{Key: 1, Value: 80, Operator: LOWER}, + want: 1.0, + }, + // NOT_EQUAL operator + { + name: "int NOT_EQUAL satisfied", + stateVal: 50, + condition: &Condition[int]{Key: 1, Value: 100, Operator: NOT_EQUAL}, + want: 0.0, + }, + { + name: "int NOT_EQUAL not satisfied", + stateVal: 100, + condition: &Condition[int]{Key: 1, Value: 100, Operator: NOT_EQUAL}, + want: 1.0, + }, + // Float64 tests + { + name: "float64 EQUAL", + stateVal: 75.5, + condition: &Condition[float64]{Key: 1, Value: 100.0, Operator: EQUAL}, + want: 24.5, + }, + // uint64 tests + { + name: "uint64 UPPER_OR_EQUAL", + stateVal: uint64(50), + condition: &Condition[uint64]{Key: 1, Value: uint64(100), Operator: UPPER_OR_EQUAL}, + want: 50.0, + }, + // int8 tests + { + name: "int8 EQUAL", + stateVal: int8(50), + condition: &Condition[int8]{Key: 1, Value: int8(100), Operator: EQUAL}, + want: 50.0, + }, + // uint8 tests + { + name: "uint8 LOWER_OR_EQUAL", + stateVal: uint8(100), + condition: &Condition[uint8]{Key: 1, Value: uint8(80), Operator: LOWER_OR_EQUAL}, + want: 20.0, + }, + // Bool tests + { + name: "bool EQUAL satisfied", + stateVal: true, + condition: &ConditionBool{Key: 1, Value: true, Operator: EQUAL}, + want: 0.0, + }, + { + name: "bool EQUAL not satisfied", + stateVal: true, + condition: &ConditionBool{Key: 1, Value: false, Operator: EQUAL}, + want: 1.0, + }, + { + name: "bool NOT_EQUAL satisfied", + stateVal: true, + condition: &ConditionBool{Key: 1, Value: false, Operator: NOT_EQUAL}, + want: 0.0, + }, + { + name: "bool NOT_EQUAL not satisfied", + stateVal: true, + condition: &ConditionBool{Key: 1, Value: true, Operator: NOT_EQUAL}, + want: 1.0, + }, + // String tests + { + name: "string EQUAL satisfied", + stateVal: "test", + condition: &ConditionString{Key: 1, Value: "test", Operator: EQUAL}, + want: 0.0, + }, + { + name: "string EQUAL not satisfied", + stateVal: "test", + condition: &ConditionString{Key: 1, Value: "other", Operator: EQUAL}, + want: 1.0, + }, + { + name: "string NOT_EQUAL satisfied", + stateVal: "test", + condition: &ConditionString{Key: 1, Value: "other", Operator: NOT_EQUAL}, + want: 0.0, + }, + { + name: "string NOT_EQUAL not satisfied", + stateVal: "test", + condition: &ConditionString{Key: 1, Value: "test", Operator: NOT_EQUAL}, + want: 1.0, + }, + // Key mismatch + { + name: "key mismatch", + stateVal: 100, + condition: &Condition[int]{Key: 99, Value: 100, Operator: EQUAL}, + want: 0.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + + switch v := tt.stateVal.(type) { + case int: + SetState[int](&agent, 1, v) + case int8: + SetState[int8](&agent, 1, v) + case uint8: + SetState[uint8](&agent, 1, v) + case uint64: + SetState[uint64](&agent, 1, v) + case float64: + SetState[float64](&agent, 1, v) + case bool: + SetState[bool](&agent, 1, v) + case string: + SetState[string](&agent, 1, v) + } + + if len(agent.w.states) == 0 { + t.Fatal("Failed to set state") + } + + state := agent.w.states[0] + got := state.Distance(tt.condition) + + if got != tt.want { + t.Errorf("Distance() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/effect_test.go b/effect_test.go new file mode 100644 index 0000000..40d80f8 --- /dev/null +++ b/effect_test.go @@ -0,0 +1,339 @@ +package goapai + +import "testing" + +func TestEffectString_GetKey(t *testing.T) { + effect := EffectString{Key: 42, Value: "test"} + if got := effect.GetKey(); got != 42 { + t.Errorf("GetKey() = %v, want 42", got) + } +} + +func TestEffectString_Check(t *testing.T) { + tests := []struct { + name string + stateVal string + effectKey StateKey + effectVal string + operator arithmetic + want bool + }{ + { + name: "SET match", + stateVal: "hello", + effectKey: 1, + effectVal: "hello", + operator: SET, + want: true, + }, + { + name: "SET no match", + stateVal: "hello", + effectKey: 1, + effectVal: "world", + operator: SET, + want: false, + }, + { + name: "key not found", + stateVal: "hello", + effectKey: 99, + effectVal: "test", + operator: SET, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + if tt.effectKey == tt.effectKey { // Only set state if key matches + SetState[string](&agent, 1, tt.stateVal) + } + + effect := EffectString{Key: tt.effectKey, Value: tt.effectVal, Operator: tt.operator} + got := effect.check(agent.w) + + if got != tt.want { + t.Errorf("check() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEffectString_Apply(t *testing.T) { + tests := []struct { + name string + setupState bool + stateVal string + effectKey StateKey + effectVal string + operator arithmetic + wantVal string + wantErr bool + }{ + { + name: "SET new state", + setupState: false, + effectKey: 1, + effectVal: "hello", + operator: SET, + wantVal: "hello", + wantErr: false, + }, + { + name: "SET existing state", + setupState: true, + stateVal: "old", + effectKey: 1, + effectVal: "new", + operator: SET, + wantVal: "new", + wantErr: false, + }, + { + name: "ADD concatenation", + setupState: true, + stateVal: "hello", + effectKey: 1, + effectVal: " world", + operator: ADD, + wantVal: "hello world", + wantErr: false, + }, + { + name: "ADD new state", + setupState: false, + effectKey: 1, + effectVal: "test", + operator: ADD, + wantVal: "test", + wantErr: false, + }, + { + name: "SUBSTRACT not allowed", + setupState: true, + stateVal: "hello", + effectKey: 1, + effectVal: "test", + operator: SUBSTRACT, + wantErr: true, + }, + { + name: "MULTIPLY not allowed", + setupState: true, + stateVal: "hello", + effectKey: 1, + effectVal: "test", + operator: MULTIPLY, + wantErr: true, + }, + { + name: "DIVIDE not allowed", + setupState: true, + stateVal: "hello", + effectKey: 1, + effectVal: "test", + operator: DIVIDE, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + if tt.setupState { + SetState[string](&agent, tt.effectKey, tt.stateVal) + } + + effect := EffectString{Key: tt.effectKey, Value: tt.effectVal, Operator: tt.operator} + err := effect.apply(&agent.w) + + if tt.wantErr { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + idx := agent.w.states.GetIndex(tt.effectKey) + if idx < 0 { + t.Fatal("State not found after apply") + } + + gotState := agent.w.states[idx].(State[string]) + if gotState.Value != tt.wantVal { + t.Errorf("Value = %v, want %v", gotState.Value, tt.wantVal) + } + }) + } +} + +func TestEffectBool_GetKey(t *testing.T) { + effect := EffectBool{Key: 42, Value: true} + if got := effect.GetKey(); got != 42 { + t.Errorf("GetKey() = %v, want 42", got) + } +} + +func TestEffectBool_Apply_Errors(t *testing.T) { + tests := []struct { + name string + operator arithmetic + wantErr bool + }{ + { + name: "SET allowed", + operator: SET, + wantErr: false, + }, + { + name: "ADD not allowed", + operator: ADD, + wantErr: true, + }, + { + name: "SUBSTRACT not allowed", + operator: SUBSTRACT, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + SetState[bool](&agent, 1, true) + + effect := EffectBool{Key: 1, Value: false, Operator: tt.operator} + err := effect.apply(&agent.w) + + if tt.wantErr && err == nil { + t.Error("Expected error, got nil") + } + if !tt.wantErr && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} + +func TestEffect_Apply_AllOperators(t *testing.T) { + tests := []struct { + name string + initial int + operator arithmetic + value int + wantVal int + wantErr bool + }{ + { + name: "SET", + initial: 100, + operator: SET, + value: 50, + wantVal: 50, + }, + { + name: "ADD", + initial: 100, + operator: ADD, + value: 50, + wantVal: 150, + }, + { + name: "SUBSTRACT", + initial: 100, + operator: SUBSTRACT, + value: 30, + wantVal: 70, + }, + { + name: "MULTIPLY", + initial: 10, + operator: MULTIPLY, + value: 5, + wantVal: 50, + }, + { + name: "DIVIDE", + initial: 100, + operator: DIVIDE, + value: 4, + wantVal: 25, + }, + { + name: "ADD on non-existing key", + initial: 0, // Not set + operator: ADD, + value: 50, + wantVal: 50, + }, + { + name: "SUBSTRACT on non-existing key", + initial: 0, // Not set + operator: SUBSTRACT, + value: 50, + wantVal: -50, + }, + { + name: "MULTIPLY on non-existing key error", + initial: 0, // Not set + operator: MULTIPLY, + value: 50, + wantErr: true, + }, + { + name: "DIVIDE on non-existing key error", + initial: 0, // Not set + operator: DIVIDE, + value: 50, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + agent := CreateAgent(Goals{}, Actions{}) + if tt.initial != 0 || (tt.operator != ADD && tt.operator != SUBSTRACT && tt.operator != MULTIPLY && tt.operator != DIVIDE) { + SetState[int](&agent, 1, tt.initial) + } + + effect := Effect[int]{Key: 1, Value: tt.value, Operator: tt.operator} + err := effect.apply(&agent.w) + + if tt.wantErr { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + idx := agent.w.states.GetIndex(1) + if idx < 0 { + t.Fatal("State not found after apply") + } + + gotState := agent.w.states[idx].(State[int]) + if gotState.Value != tt.wantVal { + t.Errorf("Value = %v, want %v", gotState.Value, tt.wantVal) + } + }) + } +} + +func TestEffect_GetKey(t *testing.T) { + effect := Effect[int]{Key: 42, Value: 100} + if got := effect.GetKey(); got != 42 { + t.Errorf("GetKey() = %v, want 42", got) + } +} diff --git a/heap_test.go b/heap_test.go new file mode 100644 index 0000000..03850b6 --- /dev/null +++ b/heap_test.go @@ -0,0 +1,221 @@ +package goapai + +import "testing" + +func TestNodeHeap_Less(t *testing.T) { + tests := []struct { + name string + nodes nodeHeap + i, j int + wantLess bool + }{ + { + name: "first node has lower cost", + nodes: nodeHeap{ + {totalCost: 5.0}, + {totalCost: 10.0}, + }, + i: 0, + j: 1, + wantLess: true, + }, + { + name: "first node has higher cost", + nodes: nodeHeap{ + {totalCost: 15.0}, + {totalCost: 10.0}, + }, + i: 0, + j: 1, + wantLess: false, + }, + { + name: "equal costs", + nodes: nodeHeap{ + {totalCost: 10.0}, + {totalCost: 10.0}, + }, + i: 0, + j: 1, + wantLess: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.nodes.Less(tt.i, tt.j) + if got != tt.wantLess { + t.Errorf("Less(%d, %d) = %v, want %v", tt.i, tt.j, got, tt.wantLess) + } + }) + } +} + +func TestNodeHeap_Swap(t *testing.T) { + tests := []struct { + name string + initialNodes []*node + i, j int + wantI, wantJ int + wantIIndex int + wantJIndex int + }{ + { + name: "swap first and second", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + {totalCost: 10.0, heapIndex: 1}, + }, + i: 0, + j: 1, + wantI: 1, + wantJ: 0, + wantIIndex: 0, + wantJIndex: 1, + }, + { + name: "swap first and third", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + {totalCost: 10.0, heapIndex: 1}, + {totalCost: 15.0, heapIndex: 2}, + }, + i: 0, + j: 2, + wantI: 2, + wantJ: 0, + wantIIndex: 0, + wantJIndex: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := nodeHeap(tt.initialNodes) + originalI := h[tt.i] + originalJ := h[tt.j] + + h.Swap(tt.i, tt.j) + + // Verify nodes were swapped + if h[tt.i] != originalJ { + t.Errorf("After swap, h[%d] should be originalJ", tt.i) + } + if h[tt.j] != originalI { + t.Errorf("After swap, h[%d] should be originalI", tt.j) + } + + // Verify heapIndex was updated + if h[tt.i].heapIndex != tt.wantIIndex { + t.Errorf("h[%d].heapIndex = %d, want %d", tt.i, h[tt.i].heapIndex, tt.wantIIndex) + } + if h[tt.j].heapIndex != tt.wantJIndex { + t.Errorf("h[%d].heapIndex = %d, want %d", tt.j, h[tt.j].heapIndex, tt.wantJIndex) + } + }) + } +} + +func TestNodeHeap_Push(t *testing.T) { + tests := []struct { + name string + initialNodes []*node + pushNode *node + wantLen int + wantIndex int + }{ + { + name: "push to empty heap", + initialNodes: []*node{}, + pushNode: &node{totalCost: 5.0}, + wantLen: 1, + wantIndex: 0, + }, + { + name: "push to non-empty heap", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + {totalCost: 10.0, heapIndex: 1}, + }, + pushNode: &node{totalCost: 15.0}, + wantLen: 3, + wantIndex: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := nodeHeap(tt.initialNodes) + h.Push(tt.pushNode) + + if len(h) != tt.wantLen { + t.Errorf("Len() = %d, want %d", len(h), tt.wantLen) + } + + if tt.pushNode.heapIndex != tt.wantIndex { + t.Errorf("heapIndex = %d, want %d", tt.pushNode.heapIndex, tt.wantIndex) + } + + if h[len(h)-1] != tt.pushNode { + t.Error("Pushed node should be at the end of heap") + } + }) + } +} + +func TestNodeHeap_Pop(t *testing.T) { + tests := []struct { + name string + initialNodes []*node + wantNode *node + wantLen int + }{ + { + name: "pop from heap with one element", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + }, + wantNode: &node{totalCost: 5.0, heapIndex: -1}, + wantLen: 0, + }, + { + name: "pop from heap with multiple elements", + initialNodes: []*node{ + {totalCost: 5.0, heapIndex: 0}, + {totalCost: 10.0, heapIndex: 1}, + {totalCost: 15.0, heapIndex: 2}, + }, + wantNode: &node{totalCost: 15.0, heapIndex: -1}, + wantLen: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := nodeHeap(tt.initialNodes) + originalLast := h[len(h)-1] + + popped := h.Pop().(*node) + + if len(h) != tt.wantLen { + t.Errorf("Len() = %d, want %d", len(h), tt.wantLen) + } + + if popped.heapIndex != -1 { + t.Errorf("Popped node heapIndex = %d, want -1", popped.heapIndex) + } + + if popped != originalLast { + t.Error("Pop should return the last element") + } + + // Verify no memory leak - the old last position should be nil + if tt.wantLen > 0 { + // We can't directly check the old slice, but we can verify length is correct + if cap(h) > 0 && len(h) < cap(h) { + // This is expected behavior - capacity unchanged, length reduced + } + } + }) + } +} From 8f183fa4ef9a07de8e4e7d1ac368b4ba5c103a65 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 16:14:51 +0100 Subject: [PATCH 15/18] chore: clean code --- distance.go | 101 ++++++++++++++++++++++++++++++++++++++++ planer.go | 11 ----- planer_test.go | 44 ------------------ state.go | 122 ++++--------------------------------------------- 4 files changed, 109 insertions(+), 169 deletions(-) create mode 100644 distance.go diff --git a/distance.go b/distance.go new file mode 100644 index 0000000..10a5c43 --- /dev/null +++ b/distance.go @@ -0,0 +1,101 @@ +package goapai + +// Distance calculates the distance between the current state value and a condition target value +// Returns 0 if the condition is already satisfied, otherwise returns the numeric distance +func (state State[T]) Distance(condition ConditionInterface) float32 { + // Check if the condition key matches + if state.Key != condition.GetKey() { + return 0 + } + + // Handle different condition types + switch cond := condition.(type) { + case *Condition[int8]: + if v, ok := any(state.Value).(int8); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[int]: + if v, ok := any(state.Value).(int); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[uint8]: + if v, ok := any(state.Value).(uint8); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[uint64]: + if v, ok := any(state.Value).(uint64); ok { + return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) + } + case *Condition[float64]: + if v, ok := any(state.Value).(float64); ok { + return calculateNumericDistance(v, cond.Value, cond.Operator) + } + case *ConditionBool: + if v, ok := any(state.Value).(bool); ok { + if cond.Operator == EQUAL { + if v == cond.Value { + return 0 + } + return 1 + } else if cond.Operator == NOT_EQUAL { + if v != cond.Value { + return 0 + } + return 1 + } + } + case *ConditionString: + if v, ok := any(state.Value).(string); ok { + if cond.Operator == EQUAL { + if v == cond.Value { + return 0 + } + return 1 + } else if cond.Operator == NOT_EQUAL { + if v != cond.Value { + return 0 + } + return 1 + } + } + } + + return 0 +} + +// calculateNumericDistance computes the distance for numeric conditions based on operator +func calculateNumericDistance(current, target float64, op operator) float32 { + switch op { + case EQUAL: + if current < target { + return float32(target - current) + } + return float32(current - target) + case NOT_EQUAL: + if current == target { + return 1.0 + } + return 0.0 + case UPPER_OR_EQUAL: + if current < target { + return float32(target - current) + } + return 0.0 + case UPPER: + if current <= target { + return float32(target - current + 1) + } + return 0.0 + case LOWER_OR_EQUAL: + if current > target { + return float32(current - target) + } + return 0.0 + case LOWER: + if current >= target { + return float32(current - target + 1) + } + return 0.0 + } + return 0.0 +} diff --git a/planer.go b/planer.go index ff283da..d88c6a9 100644 --- a/planer.go +++ b/planer.go @@ -58,14 +58,3 @@ func (agent *Agent) getPrioritizedGoalName() (GoalName, error) { return prioritizedGoalName, fmt.Errorf("no goal available") } } - -// GetNextAction returns the first Action required to achieve the Plan. -// -// An error is returned if no action is available, meaning the Plan is empty. -func (plan Plan) GetNextAction() (Action, error) { - if len(plan) > 0 { - return *plan[0], nil - } - - return Action{}, fmt.Errorf("no action available") -} diff --git a/planer_test.go b/planer_test.go index 6163304..c9aa701 100644 --- a/planer_test.go +++ b/planer_test.go @@ -294,50 +294,6 @@ func TestGetPrioritizedGoalName_UsingSensors(t *testing.T) { } } -// Test GetNextAction -func TestGetNextAction_WithActions(t *testing.T) { - plan := Plan{ - {name: "action1", cost: 1.0}, - {name: "action2", cost: 2.0}, - } - - action, err := plan.GetNextAction() - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if action.name != "action1" { - t.Errorf("Expected 'action1', got '%s'", action.name) - } -} - -func TestGetNextAction_EmptyPlan(t *testing.T) { - plan := Plan{} - - _, err := plan.GetNextAction() - - if err == nil { - t.Error("Expected error for empty plan") - } -} - -func TestGetNextAction_SingleAction(t *testing.T) { - plan := Plan{ - {name: "only_action", cost: 1.0}, - } - - action, err := plan.GetNextAction() - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if action.name != "only_action" { - t.Errorf("Expected 'only_action', got '%s'", action.name) - } -} - // Integration test with complex scenario func TestGetPlan_ComplexScenario(t *testing.T) { actions := Actions{} diff --git a/state.go b/state.go index a462028..0442d09 100644 --- a/state.go +++ b/state.go @@ -46,10 +46,9 @@ type world struct { hash uint64 } -func createState[T Numeric | bool | string](s State[T]) State[T] { - s.hash = s.Hash() - - return s +// Check compares world and states2 by their hash. +func (world world) Check(world2 world) bool { + return world.hash == world2.hash } func (state State[T]) GetKey() StateKey { @@ -129,109 +128,12 @@ func (state State[T]) Hash() uint64 { return hash } -// Distance calculates the distance between the current state value and a condition target value -// Returns 0 if the condition is already satisfied, otherwise returns the numeric distance -func (state State[T]) Distance(condition ConditionInterface) float32 { - // Check if the condition key matches - if state.Key != condition.GetKey() { - return 0 - } - - // Handle different condition types - switch cond := condition.(type) { - case *Condition[int8]: - if v, ok := any(state.Value).(int8); ok { - return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) - } - case *Condition[int]: - if v, ok := any(state.Value).(int); ok { - return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) - } - case *Condition[uint8]: - if v, ok := any(state.Value).(uint8); ok { - return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) - } - case *Condition[uint64]: - if v, ok := any(state.Value).(uint64); ok { - return calculateNumericDistance(float64(v), float64(cond.Value), cond.Operator) - } - case *Condition[float64]: - if v, ok := any(state.Value).(float64); ok { - return calculateNumericDistance(v, cond.Value, cond.Operator) - } - case *ConditionBool: - if v, ok := any(state.Value).(bool); ok { - if cond.Operator == EQUAL { - if v == cond.Value { - return 0 - } - return 1 - } else if cond.Operator == NOT_EQUAL { - if v != cond.Value { - return 0 - } - return 1 - } - } - case *ConditionString: - if v, ok := any(state.Value).(string); ok { - if cond.Operator == EQUAL { - if v == cond.Value { - return 0 - } - return 1 - } else if cond.Operator == NOT_EQUAL { - if v != cond.Value { - return 0 - } - return 1 - } - } - } - - return 0 -} - -// calculateNumericDistance computes the distance for numeric conditions based on operator -func calculateNumericDistance(current, target float64, op operator) float32 { - switch op { - case EQUAL: - if current < target { - return float32(target - current) - } - return float32(current - target) - case NOT_EQUAL: - if current == target { - return 1.0 - } - return 0.0 - case UPPER_OR_EQUAL: - if current < target { - return float32(target - current) - } - return 0.0 - case UPPER: - if current <= target { - return float32(target - current + 1) - } - return 0.0 - case LOWER_OR_EQUAL: - if current > target { - return float32(current - target) - } - return 0.0 - case LOWER: - if current >= target { - return float32(current - target + 1) - } - return 0.0 - } - return 0.0 -} +// updateHashIncremental updates a hash by removing old state and adding new state +func updateHashIncremental(currentHash uint64, oldStateHash, newStateHash uint64) uint64 { + currentHash ^= oldStateHash // Remove old + currentHash ^= newStateHash // Add new -// Check compares world and states2 by their hash. -func (world world) Check(world2 world) bool { - return world.hash == world2.hash + return currentHash } func (statesData states) GetIndex(stateKey StateKey) int { @@ -244,14 +146,6 @@ func (statesData states) GetIndex(stateKey StateKey) int { return -1 } -// updateHashIncremental updates a hash by removing old state and adding new state -func updateHashIncremental(currentHash uint64, oldStateHash, newStateHash uint64) uint64 { - currentHash ^= oldStateHash // Remove old - currentHash ^= newStateHash // Add new - - return currentHash -} - type Sensor any type Sensors map[string]Sensor From 5f50aef452246999a104653b55eafdc94d776719 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 16:16:24 +0100 Subject: [PATCH 16/18] chore: clean benchmark --- benchmark/goapai_test.go | 39 +++++++++++++++------------------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/benchmark/goapai_test.go b/benchmark/goapai_test.go index 89a8ed9..a82f369 100644 --- a/benchmark/goapai_test.go +++ b/benchmark/goapai_test.go @@ -1,11 +1,7 @@ package benchmark import ( - "fmt" "goapai" - "os" - "runtime/pprof" - "runtime/trace" "testing" ) @@ -66,35 +62,30 @@ func BenchmarkGoapAI(b *testing.B) { entity.agent = goapai.CreateAgent(goals, actions) goapai.SetSensor(&entity.agent, "entity", &entity) - goapai.SetState[int](&entity.agent, ATTRIBUTE_2, 0) goapai.SetState[int](&entity.agent, ATTRIBUTE_1, 80) + goapai.SetState[int](&entity.agent, ATTRIBUTE_2, 0) goapai.SetState[int](&entity.agent, ATTRIBUTE_3, 0) //Write to the trace file. - f, _ := os.Create("trace.out") - fcpu, _ := os.Create(`cpu.prof`) - fheap, _ := os.Create(`heap.prof`) - - pprof.StartCPUProfile(fcpu) - pprof.WriteHeapProfile(fheap) - trace.Start(f) + //f, _ := os.Create("trace.out") + //fcpu, _ := os.Create(`cpu.prof`) + //fheap, _ := os.Create(`heap.prof`) + // + //pprof.StartCPUProfile(fcpu) + //pprof.WriteHeapProfile(fheap) + //trace.Start(f) - var lastPlan goapai.Plan for b.Loop() { //goapai.GetPlan(entity.agent, 15) - _, lastPlan = goapai.GetPlan(entity.agent, 15) + goapai.GetPlan(entity.agent, 15) } - fmt.Println("Actions for plan", len(lastPlan)) - for _, j := range lastPlan { - fmt.Printf(" - %v\n", j.GetName()) - } - - defer f.Close() - defer fcpu.Close() - defer fheap.Close() - trace.Stop() - pprof.StopCPUProfile() + //defer f.Close() + //defer fcpu.Close() + //defer fheap.Close() + // + //trace.Stop() + //pprof.StopCPUProfile() b.ReportAllocs() } From 845b0f2993c79f21544502b188d5bc5a6159bb1c Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 16:24:08 +0100 Subject: [PATCH 17/18] chore: code documentation --- action.go | 56 ++++++++++++++++++++++----- agent.go | 86 +++++++++++++++++++++++++++++++++++++++++ distance.go | 13 ++++++- state.go | 109 ++++++++++++++++++++++++++++++++++++++++++++-------- 4 files changed, 237 insertions(+), 27 deletions(-) diff --git a/action.go b/action.go index 89406a2..bae50cd 100644 --- a/action.go +++ b/action.go @@ -15,6 +15,11 @@ const ( DIVIDE ) +// Action represents a single action that an agent can perform to modify the world state. +// +// An action has preconditions (conditions) that must be met before it can be executed, +// and postconditions (effects) that describe how it modifies the world state. +// Actions have a cost that is used by the A* algorithm to find the optimal plan. type Action struct { name string cost float32 @@ -22,8 +27,18 @@ type Action struct { conditions Conditions effects Effects } + +// Actions is a collection of Action pointers. type Actions []*Action +// AddAction creates a new Action and appends it to the Actions collection. +// +// Parameters: +// - name: unique identifier for the action +// - cost: numeric cost used by pathfinding (lower costs are preferred) +// - repeatable: if false, the action can only be used once per plan +// - conditions: preconditions that must be satisfied before the action can be executed +// - effects: postconditions that describe how the action modifies the world state func (actions *Actions) AddAction(name string, cost float32, repeatable bool, conditions Conditions, effects Effects) { action := Action{ name: name, @@ -36,26 +51,35 @@ func (actions *Actions) AddAction(name string, cost float32, repeatable bool, co *actions = append(*actions, &action) } +// GetName returns the action's name identifier. func (action *Action) GetName() string { return action.name } +// GetEffects returns the action's effects (postconditions). func (action *Action) GetEffects() Effects { return action.effects } +// EffectInterface defines the interface that all effect types must implement. +// Effects describe how an action modifies the world state. type EffectInterface interface { GetKey() StateKey check(w world) bool apply(w *world) error } +// Effect represents a numeric state modification for types constrained by Numeric. +// +// It supports arithmetic operators (SET, ADD, SUBTRACT, MULTIPLY, DIVIDE) to modify +// numeric state values. The effect is applied when an action is executed during planning. type Effect[T Numeric] struct { - Key StateKey - Operator arithmetic - Value T + Key StateKey // State key to modify + Operator arithmetic // Arithmetic operation to perform + Value T // Value to use in the operation } +// GetKey returns the state key that this effect modifies. func (effect Effect[T]) GetKey() StateKey { return effect.Key } @@ -114,12 +138,17 @@ func (effect Effect[T]) apply(w *world) error { return nil } +// EffectBool represents a boolean state modification. +// +// Only the SET operator is allowed for boolean effects. Attempting to use other +// operators (ADD, SUBTRACT, etc.) will result in an error when the effect is applied. type EffectBool struct { - Key StateKey - Value bool - Operator arithmetic + Key StateKey // State key to modify + Value bool // Boolean value to set + Operator arithmetic // Must be SET } +// GetKey returns the state key that this effect modifies. func (effectBool EffectBool) GetKey() StateKey { return effectBool.Key } @@ -165,12 +194,17 @@ func (effectBool EffectBool) apply(w *world) error { return nil } +// EffectString represents a string state modification. +// +// Supports SET (replace string) and ADD (concatenate) operators. Other operators +// (SUBTRACT, MULTIPLY, DIVIDE) are not allowed and will result in an error. type EffectString struct { - Key StateKey - Value string - Operator arithmetic + Key StateKey // State key to modify + Value string // String value to use + Operator arithmetic // Allowed: SET, ADD } +// GetKey returns the state key that this effect modifies. func (effectString EffectString) GetKey() StateKey { return effectString.Key } @@ -216,8 +250,12 @@ func (effectString EffectString) apply(w *world) error { return nil } +// EffectFn is a function type for custom procedural effects that directly modify the agent. +// This allows for effects that cannot be expressed through simple state modifications. type EffectFn func(agent *Agent) +// Effects is a collection of EffectInterface implementations that describe how +// an action modifies the world state. type Effects []EffectInterface // If all the effects already exist in world, diff --git a/agent.go b/agent.go index 8ad3db7..9094ed1 100644 --- a/agent.go +++ b/agent.go @@ -1,5 +1,62 @@ +// Package goapai implements a microlithic Goal-Oriented Action Planning (GOAP) system for AI agents. +// +// GOAP is a planning technique where agents dynamically generate action sequences to achieve +// goals based on the current world state. This implementation uses A* pathfinding to find +// optimal plans. +// +// # Key Concepts +// +// Agent: The central entity that maintains world state, goals, actions, and sensors. +// +// State: Key-value pairs representing the world state. States support numeric types, +// booleans, and strings, identified by compact StateKey (uint16) values. +// +// Action: Operations that modify the world state. Each action has preconditions (Conditions) +// and postconditions (Effects), plus a cost for pathfinding optimization. +// +// Goal: Desired world states with priority functions. The planner always works on the +// highest priority goal. +// +// Sensor: External data sources used in goal prioritization and procedural conditions +// without duplicating data during planning. +// +// # Example Usage +// +// // Create an agent +// agent := goapai.CreateAgent( +// goapai.Goals{ +// "survive": { +// Conditions: goapai.Conditions{ +// &goapai.Condition[int]{Key: 1, Value: 50, Operator: goapai.UPPER_OR_EQUAL}, +// }, +// PriorityFn: func(sensors goapai.Sensors) float32 { +// return 1.0 +// }, +// }, +// }, +// goapai.Actions{}, +// ) +// +// // Set initial state +// goapai.SetState[int](&agent, 1, 20) // health = 20 +// +// // Add actions +// agent.actions.AddAction("heal", 1.0, false, +// goapai.Conditions{}, +// goapai.Effects{ +// goapai.Effect[int]{Key: 1, Operator: goapai.ADD, Value: 50}, +// }, +// ) +// +// // Generate plan +// goalName, plan := goapai.GetPlan(agent, 10) package goapai +// Agent represents an AI agent that uses GOAP (Goal-Oriented Action Planning) to make decisions. +// +// An agent maintains a world state, a set of goals with priorities, available actions, +// and sensors for external data. The agent uses A* pathfinding to generate optimal plans +// that achieve its highest priority goal. type Agent struct { actions Actions w world @@ -11,9 +68,17 @@ type goalInterface struct { Conditions []ConditionInterface PriorityFn GoalPriorityFn } + +// GoalName is a unique identifier for a goal. type GoalName string + +// Goals is a map of goal names to their definitions (conditions and priority function). type Goals map[GoalName]goalInterface +// CreateAgent creates and initializes a new Agent with the given goals and actions. +// +// The agent is initialized with an empty world state and no sensors. Use SetState +// to initialize the world state and SetSensor to add sensor data. func CreateAgent(goals Goals, actions Actions) Agent { agent := Agent{ actions: actions, @@ -30,6 +95,17 @@ func CreateAgent(goals Goals, actions Actions) Agent { return agent } +// SetState adds or updates a state value in the agent's world state. +// +// State values can be numeric types (int, int8, uint8, uint64, float64), bool, or string. +// Each state is identified by a unique StateKey. Multiple calls with the same key will +// create duplicate states; this is generally not recommended. +// +// Example: +// +// SetState[int](&agent, 1, 100) // Set state key 1 to integer 100 +// SetState[bool](&agent, 2, true) // Set state key 2 to boolean true +// SetState[string](&agent, 3, "foo") // Set state key 3 to string "foo" func SetState[T Numeric | bool | string](agent *Agent, key StateKey, value T) { agent.w.states = append(agent.w.states, State[T]{ Key: key, @@ -37,6 +113,16 @@ func SetState[T Numeric | bool | string](agent *Agent, key StateKey, value T) { }) } +// SetSensor adds or updates a sensor value for the agent. +// +// Sensors provide external data that can be used in goal priority functions and +// procedural conditions (ConditionFn) without duplicating data during planning. +// Unlike world state, sensors are not modified during plan simulation. +// +// Example: +// +// SetSensor(&agent, "health", 100) +// SetSensor(&agent, "enemy_visible", true) func SetSensor[T Sensor](agent *Agent, name string, value T) { agent.sensors[name] = value } diff --git a/distance.go b/distance.go index 10a5c43..f562d1a 100644 --- a/distance.go +++ b/distance.go @@ -1,7 +1,16 @@ package goapai -// Distance calculates the distance between the current state value and a condition target value -// Returns 0 if the condition is already satisfied, otherwise returns the numeric distance +// Distance calculates the heuristic distance between the current state value and a condition's target value. +// +// It returns 0 if the condition is already satisfied, otherwise returns the numeric distance +// needed to satisfy the condition. This method is used by the A* algorithm to compute heuristics +// for pathfinding. +// +// For numeric types (int, int8, uint8, uint64, float64), the distance is calculated based on +// the operator type (EQUAL, UPPER, LOWER, etc.). For bool and string types, the distance is +// either 0 (satisfied) or 1 (not satisfied). +// +// If the condition's key doesn't match the state's key, or if types don't match, returns 0. func (state State[T]) Distance(condition ConditionInterface) float32 { // Check if the condition key matches if state.Key != condition.GetKey() { diff --git a/state.go b/state.go index 0442d09..c05576b 100644 --- a/state.go +++ b/state.go @@ -15,12 +15,16 @@ const ( UPPER ) +// Numeric is a constraint that defines the numeric types supported by generic State and Condition. +// Supported types are: int8, int, uint8, uint64, and float64. type Numeric interface { ~int8 | ~int | ~uint8 | ~uint64 | ~float64 } +// StateInterface defines the interface that all state types must implement. +// States represent key-value pairs in the world state, with support for hashing and distance calculation. type StateInterface interface { Check(w world, key StateKey) bool GetKey() StateKey @@ -30,12 +34,19 @@ type StateInterface interface { Hash() uint64 Distance(condition ConditionInterface) float32 } + +// State represents a single key-value pair in the world state. +// +// States can hold numeric types (constrained by Numeric), bool, or string values. +// Each state is identified by a unique StateKey and includes a cached hash for performance. type State[T Numeric | bool | string] struct { - Key StateKey - Value T - hash uint64 + Key StateKey // Unique identifier for this state + Value T // The state's value + hash uint64 // Cached hash value for fast comparison } +// StateKey is a compact 16-bit unsigned integer used to identify states. +// Using uint16 instead of strings reduces memory usage and improves performance. type StateKey uint16 type states []StateInterface @@ -146,23 +157,46 @@ func (statesData states) GetIndex(stateKey StateKey) int { return -1 } +// Sensor is an alias for any type, used to store external data accessed by goal priority +// functions and procedural conditions. type Sensor any + +// Sensors is a map of sensor names to their values, providing external data to the agent +// without duplicating it in the world state during planning. type Sensors map[string]Sensor +// GetSensor retrieves a sensor value by name. +// Returns nil if the sensor doesn't exist. func (sensors Sensors) GetSensor(name string) Sensor { return sensors[name] } +// ConditionInterface defines the interface that all condition types must implement. +// Conditions are preconditions that must be satisfied for actions or goals. type ConditionInterface interface { GetKey() StateKey Check(w world) bool } +// ConditionFn represents a procedural condition that evaluates against sensor data. +// +// Unlike state-based conditions, ConditionFn uses a custom function to check sensors. +// The result is cached after the first evaluation to avoid redundant computation during planning. +// +// Example: +// +// condition := &ConditionFn{ +// Key: 100, +// CheckFn: func(sensors Sensors) bool { +// health := sensors["health"].(int) +// return health < 50 +// }, +// } type ConditionFn struct { - Key StateKey - CheckFn func(sensors Sensors) bool - resolved bool - valid bool + Key StateKey // Unique identifier for this condition + CheckFn func(sensors Sensors) bool // Function that evaluates the condition + resolved bool // Whether the condition has been evaluated + valid bool // Cached result of the evaluation } func (conditionFn *ConditionFn) GetKey() StateKey { @@ -178,10 +212,23 @@ func (conditionFn *ConditionFn) Check(w world) bool { return conditionFn.valid } +// Condition represents a numeric state-based condition with comparison operators. +// +// Conditions check if a state value satisfies a comparison (EQUAL, UPPER, LOWER, etc.) +// against a target value. Supported types are constrained by the Numeric interface. +// +// Example: +// +// // Check if state key 1 is greater than or equal to 100 +// condition := &Condition[int]{ +// Key: 1, +// Value: 100, +// Operator: UPPER_OR_EQUAL, +// } type Condition[T Numeric] struct { - Key StateKey - Value T - Operator operator + Key StateKey // State key to check + Value T // Target value to compare against + Operator operator // Comparison operator (EQUAL, UPPER, LOWER, etc.) } func (condition *Condition[T]) GetKey() StateKey { @@ -226,10 +273,23 @@ func (condition *Condition[T]) Check(w world) bool { return false } +// ConditionBool represents a boolean state-based condition. +// +// Only EQUAL and NOT_EQUAL operators are supported for boolean conditions. +// Other operators will cause Check to return false. +// +// Example: +// +// // Check if state key 2 is true +// condition := &ConditionBool{ +// Key: 2, +// Value: true, +// Operator: EQUAL, +// } type ConditionBool struct { - Key StateKey - Value bool - Operator operator + Key StateKey // State key to check + Value bool // Target boolean value + Operator operator // Allowed: EQUAL, NOT_EQUAL } func (conditionBool *ConditionBool) GetKey() StateKey { @@ -260,12 +320,26 @@ func (conditionBool *ConditionBool) Check(w world) bool { return false } +// ConditionString represents a string state-based condition. +// +// Only EQUAL and NOT_EQUAL operators are supported for string conditions. +// Other operators will cause Check to return false. +// +// Example: +// +// // Check if state key 3 equals "ready" +// condition := &ConditionString{ +// Key: 3, +// Value: "ready", +// Operator: EQUAL, +// } type ConditionString struct { - Key StateKey - Value string - Operator operator + Key StateKey // State key to check + Value string // Target string value + Operator operator // Allowed: EQUAL, NOT_EQUAL } +// GetKey returns the state key that this condition checks. func (conditionString *ConditionString) GetKey() StateKey { return conditionString.Key } @@ -294,8 +368,11 @@ func (conditionString *ConditionString) Check(w world) bool { return false } +// Conditions is a collection of ConditionInterface implementations that must all be satisfied. type Conditions []ConditionInterface +// Check returns true if all conditions in the collection are satisfied in the given world state. +// Returns true for an empty condition list (vacuous truth). func (conditions Conditions) Check(w world) bool { for _, condition := range conditions { if !condition.Check(w) { From 4e15eb115e5689aaac858913365ba80807f51f33 Mon Sep 17 00:00:00 2001 From: Ramine Agoune Date: Thu, 30 Oct 2025 16:40:30 +0100 Subject: [PATCH 18/18] chore: fix unit tests --- effect_test.go | 104 ++++++++++++++++++++++++++----------------------- heap_test.go | 28 +++++-------- 2 files changed, 65 insertions(+), 67 deletions(-) diff --git a/effect_test.go b/effect_test.go index 40d80f8..1a9c146 100644 --- a/effect_test.go +++ b/effect_test.go @@ -11,45 +11,43 @@ func TestEffectString_GetKey(t *testing.T) { func TestEffectString_Check(t *testing.T) { tests := []struct { - name string - stateVal string - effectKey StateKey - effectVal string - operator arithmetic - want bool + name string + stateVal string + effectKey StateKey + effectVal string + operator arithmetic + want bool }{ { - name: "SET match", - stateVal: "hello", - effectKey: 1, - effectVal: "hello", - operator: SET, - want: true, + name: "SET match", + stateVal: "hello", + effectKey: 1, + effectVal: "hello", + operator: SET, + want: true, }, { - name: "SET no match", - stateVal: "hello", - effectKey: 1, - effectVal: "world", - operator: SET, - want: false, + name: "SET no match", + stateVal: "hello", + effectKey: 1, + effectVal: "world", + operator: SET, + want: false, }, { - name: "key not found", - stateVal: "hello", - effectKey: 99, - effectVal: "test", - operator: SET, - want: false, + name: "key not found", + stateVal: "hello", + effectKey: 99, + effectVal: "test", + operator: SET, + want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { agent := CreateAgent(Goals{}, Actions{}) - if tt.effectKey == tt.effectKey { // Only set state if key matches - SetState[string](&agent, 1, tt.stateVal) - } + SetState[string](&agent, 1, tt.stateVal) effect := EffectString{Key: tt.effectKey, Value: tt.effectVal, Operator: tt.operator} got := effect.check(agent.w) @@ -183,24 +181,24 @@ func TestEffectBool_GetKey(t *testing.T) { func TestEffectBool_Apply_Errors(t *testing.T) { tests := []struct { - name string - operator arithmetic - wantErr bool + name string + operator arithmetic + wantErr bool }{ { - name: "SET allowed", - operator: SET, - wantErr: false, + name: "SET allowed", + operator: SET, + wantErr: false, }, { - name: "ADD not allowed", - operator: ADD, - wantErr: true, + name: "ADD not allowed", + operator: ADD, + wantErr: true, }, { - name: "SUBSTRACT not allowed", - operator: SUBSTRACT, - wantErr: true, + name: "SUBSTRACT not allowed", + operator: SUBSTRACT, + wantErr: true, }, } @@ -224,14 +222,16 @@ func TestEffectBool_Apply_Errors(t *testing.T) { func TestEffect_Apply_AllOperators(t *testing.T) { tests := []struct { - name string - initial int - operator arithmetic - value int - wantVal int - wantErr bool + key StateKey + name string + initial int + operator arithmetic + value int + wantVal int + wantErr bool }{ { + key: 1, name: "SET", initial: 100, operator: SET, @@ -239,6 +239,7 @@ func TestEffect_Apply_AllOperators(t *testing.T) { wantVal: 50, }, { + key: 1, name: "ADD", initial: 100, operator: ADD, @@ -246,6 +247,7 @@ func TestEffect_Apply_AllOperators(t *testing.T) { wantVal: 150, }, { + key: 1, name: "SUBSTRACT", initial: 100, operator: SUBSTRACT, @@ -253,6 +255,7 @@ func TestEffect_Apply_AllOperators(t *testing.T) { wantVal: 70, }, { + key: 1, name: "MULTIPLY", initial: 10, operator: MULTIPLY, @@ -260,6 +263,7 @@ func TestEffect_Apply_AllOperators(t *testing.T) { wantVal: 50, }, { + key: 1, name: "DIVIDE", initial: 100, operator: DIVIDE, @@ -267,6 +271,7 @@ func TestEffect_Apply_AllOperators(t *testing.T) { wantVal: 25, }, { + key: 1, name: "ADD on non-existing key", initial: 0, // Not set operator: ADD, @@ -274,6 +279,7 @@ func TestEffect_Apply_AllOperators(t *testing.T) { wantVal: 50, }, { + key: 1, name: "SUBSTRACT on non-existing key", initial: 0, // Not set operator: SUBSTRACT, @@ -281,6 +287,7 @@ func TestEffect_Apply_AllOperators(t *testing.T) { wantVal: -50, }, { + key: 99, name: "MULTIPLY on non-existing key error", initial: 0, // Not set operator: MULTIPLY, @@ -288,6 +295,7 @@ func TestEffect_Apply_AllOperators(t *testing.T) { wantErr: true, }, { + key: 99, name: "DIVIDE on non-existing key error", initial: 0, // Not set operator: DIVIDE, @@ -299,11 +307,9 @@ func TestEffect_Apply_AllOperators(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { agent := CreateAgent(Goals{}, Actions{}) - if tt.initial != 0 || (tt.operator != ADD && tt.operator != SUBSTRACT && tt.operator != MULTIPLY && tt.operator != DIVIDE) { - SetState[int](&agent, 1, tt.initial) - } + SetState[int](&agent, 1, tt.initial) - effect := Effect[int]{Key: 1, Value: tt.value, Operator: tt.operator} + effect := Effect[int]{Key: tt.key, Value: tt.value, Operator: tt.operator} err := effect.apply(&agent.w) if tt.wantErr { diff --git a/heap_test.go b/heap_test.go index 03850b6..9403c81 100644 --- a/heap_test.go +++ b/heap_test.go @@ -4,10 +4,10 @@ import "testing" func TestNodeHeap_Less(t *testing.T) { tests := []struct { - name string - nodes nodeHeap - i, j int - wantLess bool + name string + nodes nodeHeap + i, j int + wantLess bool }{ { name: "first node has lower cost", @@ -53,12 +53,12 @@ func TestNodeHeap_Less(t *testing.T) { func TestNodeHeap_Swap(t *testing.T) { tests := []struct { - name string - initialNodes []*node - i, j int - wantI, wantJ int - wantIIndex int - wantJIndex int + name string + initialNodes []*node + i, j int + wantI, wantJ int + wantIIndex int + wantJIndex int }{ { name: "swap first and second", @@ -208,14 +208,6 @@ func TestNodeHeap_Pop(t *testing.T) { if popped != originalLast { t.Error("Pop should return the last element") } - - // Verify no memory leak - the old last position should be nil - if tt.wantLen > 0 { - // We can't directly check the old slice, but we can verify length is correct - if cap(h) > 0 && len(h) < cap(h) { - // This is expected behavior - capacity unchanged, length reduced - } - } }) } }