Skip to content

Commit 964610d

Browse files
committed
changes compile, but tests fail
1 parent fad33e4 commit 964610d

File tree

10 files changed

+72
-77
lines changed

10 files changed

+72
-77
lines changed

include/maxplus/base/fsm/fsm.h

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,9 @@ class State : public WithUniqueID {
153153

154154
void insertOutgoingEdge(Edge &e) { this->outgoingEdges.insert(&e); }
155155

156-
void removeOutgoingEdge(Edge &e) {
157-
this->outgoingEdges.erase(&e); }
156+
void removeOutgoingEdge(EdgeRef e) {
157+
this->outgoingEdges.erase(e);
158+
}
158159

159160
private:
160161
SetOfEdgeRefs outgoingEdges;
@@ -203,6 +204,7 @@ class FiniteStateMachine {
203204
[[nodiscard]] virtual StateRef getInitialState() const = 0;
204205
[[nodiscard]] virtual const SetOfStateRefs &getInitialStates() const = 0;
205206
[[nodiscard]] virtual const SetOfStateRefs &getFinalStates() const = 0;
207+
206208
};
207209

208210
//
@@ -377,6 +379,10 @@ class SetOfStates : public Abstract::SetOfStates {
377379
throw CException("error - state not found in FiniteStateMachine::_withLabel");
378380
}
379381

382+
bool hasStateWithLabel(StateLabelType l) {
383+
return this->labelIndex.find(l) != this->labelIndex.end();
384+
}
385+
380386
void addState(std::shared_ptr<State<StateLabelType, EdgeLabelType>>& s) {
381387
this->addToStateIndex(s->getLabel(), s);
382388
}
@@ -386,7 +392,6 @@ class SetOfStates : public Abstract::SetOfStates {
386392
template <typename StateLabelType, typename EdgeLabelType>
387393
class SetOfStateRefs : public Abstract::SetOfStateRefs {
388394
public:
389-
using CIter = typename SetOfStateRefs::const_iterator;
390395
};
391396

392397
template <typename StateLabelType, typename EdgeLabelType> class State : public Abstract::State {
@@ -439,10 +444,13 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
439444
SetOfStateRefs<StateLabelType, EdgeLabelType> finalStates;
440445

441446
State<StateLabelType, EdgeLabelType>& _getStateLabeled(const StateLabelType &s) {
442-
// try the index first
443447
return this->states.withLabel(s);
444448
};
445449

450+
State<StateLabelType, EdgeLabelType>& _getState(const State<StateLabelType, EdgeLabelType> &s) {
451+
return dynamic_cast<State<StateLabelType, EdgeLabelType>&>(this->states.withId(s.getId()));
452+
};
453+
446454

447455
public:
448456
FiniteStateMachine() : Abstract::FiniteStateMachine() {};
@@ -486,8 +494,8 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
486494
void removeEdge(const Edge<StateLabelType, EdgeLabelType> &e) {
487495
auto csrc = dynamic_cast<StateRef<StateLabelType, EdgeLabelType>>(e.getSource());
488496
// get a non-const version of the state
489-
auto& src = this->_getStateLabeled(csrc->getLabel());
490-
src.removeOutgoingEdge(e);
497+
auto& src = this->_getState(*csrc);
498+
src.removeOutgoingEdge(&e);
491499
this->edges.remove(e);
492500
}
493501

@@ -510,7 +518,7 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
510518

511519
// set initial state to state with label;
512520
void setInitialState(StateLabelType label) {
513-
this->setInitialState(*this->states.withLabel(label));
521+
this->setInitialState(this->states.withLabel(label));
514522
};
515523

516524
void setInitialState(const State<StateLabelType, EdgeLabelType> &s) {
@@ -526,7 +534,7 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
526534

527535
void addFinalState(const State<StateLabelType, EdgeLabelType> &s) {
528536
// we are assuming s is one of our states
529-
this->finalStates.emplace(s.getId(), &s);
537+
this->finalStates.insert(&s);
530538
};
531539

532540

@@ -598,20 +606,9 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
598606
return nullptr;
599607
};
600608

601-
StateRef<StateLabelType, EdgeLabelType> checkStateLabeled(const StateLabelType &s) {
602-
// try the index first
603-
if (this->states.withLabel(s) != nullptr) {
604-
return this->states.withLabel(s);
605-
}
606-
// for now just a linear search
607-
auto i = this->states.begin();
608-
while (i != this->states.end()) {
609-
auto &t = dynamic_cast<State<StateLabelType, EdgeLabelType> &>(*((*i).second));
610-
if ((t.stateLabel) == s) {
611-
this->states.addToStateIndex(s, &t);
612-
return &t;
613-
}
614-
i++;
609+
StateRef<StateLabelType, EdgeLabelType> checkStateLabeled(const StateLabelType &l) {
610+
if (this->states.hasStateWithLabel(l)) {
611+
return &(this->states.withLabel(l));
615612
}
616613
return nullptr;
617614
};
@@ -642,9 +639,8 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
642639

643640
for (const auto &it : outgoingEdges) {
644641
auto e = dynamic_cast<const Edge<StateLabelType, EdgeLabelType> *>(it);
645-
auto dstState = dynamic_cast<const State<StateLabelType, EdgeLabelType> &>(
646-
e->getDestination());
647-
if (e->getLabel() == lbl && dstState.getLabel() == dst) {
642+
auto dstState = dynamic_cast<StateRef<StateLabelType, EdgeLabelType>>(e->getDestination());
643+
if (e->getLabel() == lbl && dstState->getLabel() == dst) {
648644
return e;
649645
}
650646
}
@@ -861,7 +857,7 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
861857
this->newInstance());
862858

863859
// make a state for every equivalence class
864-
std::map<std::shared_ptr<Abstract::SetOfStateRefs>, State<StateLabelType, EdgeLabelType> *>
860+
std::map<std::shared_ptr<Abstract::SetOfStateRefs>, StateRef<StateLabelType, EdgeLabelType>>
865861
newStateMap;
866862
CId sid = 0;
867863
for (const auto &cli : eqClasses) {
@@ -883,7 +879,7 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
883879
auto ed = dynamic_cast<EdgeRef<StateLabelType, EdgeLabelType>>(edi);
884880
result->addEdge(*(newStateMap[cli]),
885881
ed->getLabel(),
886-
*(newStateMap[eqMap[&(ed->getDestination())]]));
882+
*(newStateMap[eqMap[ed->getDestination()]]));
887883
}
888884
}
889885

include/maxplus/game/mpgameautomaton.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,21 @@ class MaxPlusGameAutomatonWithRewards : public MaxPlusAutomatonWithRewards,
6161
~MaxPlusGameAutomatonWithRewards() override = default;
6262
;
6363

64-
std::set<MPARState *> &getV0() override { return this->setV0; }
64+
std::set<MPARStateRef> &getV0() override { return this->setV0; }
6565

66-
std::set<MPARState *> &getV1() override { return this->setV1; }
66+
std::set<MPARStateRef> &getV1() override { return this->setV1; }
6767

68-
void addV0(MPARState *s) { this->setV0.insert(s); }
68+
void addV0(MPARStateRef s) { this->setV0.insert(s); }
6969

70-
void addV1(MPARState *s) { this->setV1.insert(s); }
70+
void addV1(MPARStateRef s) { this->setV1.insert(s); }
7171

72-
MPTime getWeight1(const MPAREdge *e) const override { return MPTime(e->getLabel().reward); }
72+
MPTime getWeight1(const MPAREdgeRef e) const override { return MPTime(e->getLabel().reward); }
7373

74-
MPTime getWeight2(const MPAREdge *e) const override { return e->getLabel().delay; }
74+
MPTime getWeight2(const MPAREdgeRef e) const override { return e->getLabel().delay; }
7575

7676
private:
77-
std::set<MPARState *> setV0;
78-
std::set<MPARState *> setV1;
77+
std::set<MPARStateRef> setV0;
78+
std::set<MPARStateRef> setV1;
7979
};
8080
} // namespace MaxPlus
8181

include/maxplus/game/policyiteration.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ template <typename SL, typename EL> class PolicyIteration {
138138
// Initialize state ids.
139139
std::map<const State<SL, EL> *, CDouble> stateIds;
140140
int cid = 0;
141-
typename SetOfStates<SL, EL>::CIter si;
142141
for (auto &it : states) {
143142
auto &si = *(it.second);
144143
// Source vertex.
@@ -160,16 +159,16 @@ template <typename SL, typename EL> class PolicyIteration {
160159
dw2vector = result.dw2_i_t;
161160

162161
// Improve the strategy of player 0, just one iteration.
163-
std::set<State<SL, EL> *> &states = game.getV0();
162+
std::set<StateRef<SL, EL>> &states = game.getV0();
164163
for (auto &si : states) {
165164
// Source vertex.
166-
auto v = dynamic_cast<State<SL, EL> *>(si);
165+
auto v = dynamic_cast<StateRef<SL, EL>>(si);
167166
// Outgoing edges.
168167
auto es = dynamic_cast<const FSM::Abstract::SetOfEdgeRefs &>(v->getOutgoingEdges());
169168
for (const auto &ei : es) {
170169
auto e = dynamic_cast<EdgeRef<SL, EL>>(ei);
171170

172-
const auto &u = dynamic_cast<const State<SL, EL> *>(&(e->getDestination()));
171+
const auto &u = dynamic_cast<const StateRef<SL, EL>>(e->getDestination());
173172

174173
CDouble mw = ratioVector[u];
175174
CDouble w1 = static_cast<CDouble>(game.getWeight1(e));
@@ -251,18 +250,18 @@ template <typename SL, typename EL> class PolicyIteration {
251250
r_i_t = evalResult.r_i_t;
252251
dw2_i_t = evalResult.dw2_i_t;
253252

254-
std::set<State<SL, EL> *> &states = game.getV1();
253+
std::set<StateRef<SL, EL>> &states = game.getV1();
255254
for (auto &si : states) {
256255
// Source vertex.
257-
auto v = dynamic_cast<State<SL, EL> *>(si);
256+
auto v = dynamic_cast<StateRef<SL, EL>>(si);
258257

259258
// Outgoing edges.
260259
const auto &es =
261260
dynamic_cast<const FSM::Abstract::SetOfEdgeRefs &>(v->getOutgoingEdges());
262261
for (const auto &ei : es) {
263262
auto e = dynamic_cast<EdgeRef<SL, EL>>(ei);
264263

265-
auto u = dynamic_cast<const State<SL, EL> *>(&(e->getDestination()));
264+
auto u = dynamic_cast<const StateRef<SL, EL>>(e->getDestination());
266265

267266
CDouble cycleRatio = r_i_t[u];
268267
CDouble w1 = static_cast<CDouble>(game.getWeight1(e));

include/maxplus/game/ratiogame.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ template <typename SL, typename EL> class RatioGame : virtual public DoubleWeigh
5252
public:
5353
virtual inline ~RatioGame() = default;
5454

55-
virtual std::set<State<SL, EL> *> &getV0() = 0;
55+
virtual std::set<StateRef<SL, EL>> &getV0() = 0;
5656

57-
virtual std::set<State<SL, EL> *> &getV1() = 0;
57+
virtual std::set<StateRef<SL, EL>> &getV1() = 0;
5858
};
5959

6060
}; // namespace MaxPlus

include/maxplus/graph/mpautomaton.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ class MaxPlusAutomatonWithRewards
235235
// compute the maximum cycle ratio of delay over progress
236236
CDouble calculateMCR();
237237
// compute the maximum cycle ratio of delay over progress and also return a critical cycle
238-
CDouble calculateMCRAndCycle(std::shared_ptr<std::vector<const MPAREdge *>> *cycle);
238+
CDouble calculateMCRAndCycle(std::shared_ptr<std::vector<MPAREdgeRef>> *cycle);
239239
};
240240

241241
} // namespace MaxPlus

src/graph/mpautomaton.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ CDouble MaxPlusAutomatonWithRewards::calculateMCR() {
3939
}
4040

4141
CDouble MaxPlusAutomatonWithRewards::calculateMCRAndCycle(
42-
std::shared_ptr<std::vector<const MPAREdge *>> *cycle) {
42+
std::shared_ptr<std::vector<MPAREdgeRef>> *cycle) {
4343

4444
MCMgraph g;
4545

@@ -52,7 +52,7 @@ CDouble MaxPlusAutomatonWithRewards::calculateMCRAndCycle(
5252
}
5353

5454
CId eId = 0;
55-
std::map<const MCMedge *, const MPAREdge *> edgeMap;
55+
std::map<const MCMedge *, MPAREdgeRef> edgeMap;
5656

5757
for (const auto &s : this->getStates()) {
5858
for (auto *e : (s.second)->getOutgoingEdges()) {
@@ -68,7 +68,7 @@ CDouble MaxPlusAutomatonWithRewards::calculateMCRAndCycle(
6868
std::vector<const MCMedge *> mcmCycle;
6969
CDouble mcr = maxCycleRatioAndCriticalCycleYoungTarjanOrlin(g, &mcmCycle);
7070
if (cycle != nullptr) {
71-
*cycle = std::make_shared<std::vector<const MPAREdge *>>();
71+
*cycle = std::make_shared<std::vector<MPAREdgeRef>>();
7272
for (const auto *e : mcmCycle) {
7373
(*cycle)->push_back(edgeMap[e]);
7474
}

src/graph/smpls.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ namespace MaxPlus::SMPLS {
220220
// create the states needed per transition matrix
221221
for (unsigned int k = 0; k < nrTokens; k++)
222222
{
223-
MPAState* s = mpa->addState(makeMPAStateLabel(qId, k));
223+
MPAStateRef s = mpa->addState(makeMPAStateLabel(qId, k));
224224
//std::cout << "DEBUG adding state to mpa id: " << (CString)(s->getLabel().id) << ", tkn: " << (s->getLabel().tokenNr)<< std::endl;
225225
if (isInitial) {
226226
mpa->addInitialState(*s);

src/testbench/game/policyiterationtest.cc

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ void PolicyIterationTest::testPlayer1CycleTest() {
3131
// One FSM state, four tokens:
3232
CId fsm_s0 = 0;
3333

34-
MPARState *s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
35-
MPARState *s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
36-
MPARState *s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
37-
MPARState *s4 = mpa.addState(makeMPAStateLabel(fsm_s0, 3));
34+
MPARStateRef s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
35+
MPARStateRef s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
36+
MPARStateRef s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
37+
MPARStateRef s4 = mpa.addState(makeMPAStateLabel(fsm_s0, 3));
3838
mpa.addV1(s1);
3939
mpa.addV1(s2);
4040
mpa.addV1(s3);
@@ -66,10 +66,10 @@ void PolicyIterationTest::testPlayer1CycleTest2() {
6666
// One FSM state, four tokens:
6767
CId fsm_s0 = 0;
6868

69-
MPARState *s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
70-
MPARState *s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
71-
MPARState *s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
72-
MPARState *s4 = mpa.addState(makeMPAStateLabel(fsm_s0, 3));
69+
MPARStateRef s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
70+
MPARStateRef s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
71+
MPARStateRef s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
72+
MPARStateRef s4 = mpa.addState(makeMPAStateLabel(fsm_s0, 3));
7373

7474
mpa.addV1(s1);
7575
mpa.addV1(s2);
@@ -103,11 +103,11 @@ void PolicyIterationTest::testTwoPlayersTest() {
103103
// One FSM state, four tokens:
104104
CId fsm_s0 = 0;
105105

106-
MPARState *s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
107-
MPARState *s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
108-
MPARState *s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
109-
MPARState *s4 = mpa.addState(makeMPAStateLabel(fsm_s0, 3));
110-
MPARState *s5 = mpa.addState(makeMPAStateLabel(fsm_s0, 4));
106+
MPARStateRef s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
107+
MPARStateRef s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
108+
MPARStateRef s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
109+
MPARStateRef s4 = mpa.addState(makeMPAStateLabel(fsm_s0, 3));
110+
MPARStateRef s5 = mpa.addState(makeMPAStateLabel(fsm_s0, 4));
111111

112112
mpa.addV0(s1);
113113
mpa.addV0(s3);
@@ -146,9 +146,9 @@ void PolicyIterationTest::testSimpleTest() {
146146
// One FSM state, three tokens:
147147
CId fsm_s0 = 0;
148148

149-
MPARState *s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
150-
MPARState *s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
151-
MPARState *s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
149+
MPARStateRef s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
150+
MPARStateRef s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
151+
MPARStateRef s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
152152

153153
mpa.addEdge(*s1, makeRewardEdgeLabel(MPDelay(3.0), CString("A"), 1.0), *s2);
154154
mpa.addEdge(*s1, makeRewardEdgeLabel(MPDelay(3.0), CString("A"), 1.0), *s3);
@@ -172,8 +172,8 @@ void PolicyIterationTest::testInvalidInputGraphTest() {
172172
// One FSM state, three tokens:
173173
CId fsm_s0 = 0;
174174

175-
MPARState *s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
176-
MPARState *s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
175+
MPARStateRef s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
176+
MPARStateRef s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
177177

178178
mpa.addEdge(*s1, makeRewardEdgeLabel(MPDelay(3.0), CString("A"), 1.0), *s2);
179179
mpa.addV0(s1);

src/testbench/game/strategyvectortest.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ void StrategyVectorTest::testSimpleTest() {
2323
// One FSM state, three tokens:
2424
CId fsm_s0 = 0;
2525

26-
MPARState *s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
27-
MPARState *s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
28-
MPARState *s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
26+
MPARStateRef s1 = mpa.addState(makeMPAStateLabel(fsm_s0, 0));
27+
MPARStateRef s2 = mpa.addState(makeMPAStateLabel(fsm_s0, 1));
28+
MPARStateRef s3 = mpa.addState(makeMPAStateLabel(fsm_s0, 2));
2929

3030
mpa.addEdge(*s1, makeRewardEdgeLabel(MPDelay(3.0), CString("A"), 1.0), *s2);
3131
mpa.addEdge(*s1, makeRewardEdgeLabel(MPDelay(3.0), CString("A"), 1.0), *s3);

0 commit comments

Comments
 (0)