Skip to content

Commit 8e2d56b

Browse files
authored
Merge pull request #21 from TimeDelta/codex/write-unit-tests-for-guidedpopulation
Add GuidedPopulation tests
2 parents 0c09181 + 8e8a631 commit 8e2d56b

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

tests/test_population.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
import sys
3+
import neat
4+
import torch
5+
import numpy as np
6+
7+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
8+
9+
from population import GuidedPopulation
10+
from genome import OptimizerGenome
11+
from genes import NODE_TYPE_TO_INDEX, ConnectionGene, NodeGene
12+
from attributes import IntAttribute, FloatAttribute
13+
from tasks import RegressionTask
14+
from reproduction import GuidedReproduction
15+
16+
17+
def make_config():
18+
config_path = os.path.join(os.path.dirname(__file__), os.pardir, "neat-config")
19+
return neat.Config(
20+
OptimizerGenome,
21+
GuidedReproduction,
22+
neat.DefaultSpeciesSet,
23+
neat.DefaultStagnation,
24+
config_path,
25+
)
26+
27+
28+
def create_simple_genome(key=0):
29+
genome = OptimizerGenome(key)
30+
ng0 = NodeGene(0, None)
31+
ng0.node_type = "aten::add"
32+
ng0.dynamic_attributes = {IntAttribute("a"): 1}
33+
ng1 = NodeGene(1, None)
34+
ng1.node_type = "aten::mul"
35+
ng1.dynamic_attributes = {FloatAttribute("b"): 0.5}
36+
genome.nodes = {0: ng0, 1: ng1}
37+
cg = ConnectionGene((0, 1))
38+
cg.enabled = True
39+
genome.connections = {(0, 1): cg}
40+
genome.next_node_id = 2
41+
return genome
42+
43+
44+
def test_genome_to_data():
45+
config = make_config()
46+
pop = GuidedPopulation(config)
47+
genome = create_simple_genome()
48+
data = pop.genome_to_data(genome)
49+
50+
assert genome.graph_dict is not None
51+
assert list(data.node_types.tolist()) == [NODE_TYPE_TO_INDEX["aten::add"], NODE_TYPE_TO_INDEX["aten::mul"]]
52+
assert data.edge_index.size(1) == 1
53+
assert data.edge_index[:, 0].tolist() == [0, 1]
54+
assert len(data.node_attributes) == 2
55+
assert "a" in pop.shared_attr_vocab.name_to_index
56+
assert "b" in pop.shared_attr_vocab.name_to_index
57+
58+
59+
def test_generate_guided_offspring():
60+
config = make_config()
61+
pop = GuidedPopulation(config)
62+
pop.guide.decoder.max_nodes = 2
63+
pop.guide.decoder.max_attributes_per_node = 2
64+
65+
g1 = create_simple_genome(0)
66+
g1.fitness = 1.0
67+
g2 = create_simple_genome(1)
68+
g2.fitness = 0.5
69+
pop.genome_to_data(g1)
70+
pop.genome_to_data(g2)
71+
72+
task = RegressionTask.random_init(num_samples=4, silent=True)
73+
offspring = pop.generate_guided_offspring(
74+
task.name(), task.features, [g1, g2], config, n_offspring=2, latent_steps=1
75+
)
76+
77+
assert isinstance(offspring, list)
78+
assert len(offspring) <= 2
79+
for child in offspring:
80+
assert isinstance(child, OptimizerGenome)
81+
assert child.graph_dict is not None

0 commit comments

Comments
 (0)