Skip to content

Commit fad33e4

Browse files
committed
sync
1 parent 02edcca commit fad33e4

File tree

3 files changed

+51
-81
lines changed

3 files changed

+51
-81
lines changed

include/maxplus/base/fsm/fsm.h

Lines changed: 31 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -151,24 +151,26 @@ class State : public WithUniqueID {
151151
return this->outgoingEdges;
152152
}
153153

154-
// TODO (Marc Geilen) get rid of const cast
155-
void removeOutgoingEdge(const Edge &e) {
156-
this->outgoingEdges.erase(&e); }
157154
void insertOutgoingEdge(Edge &e) { this->outgoingEdges.insert(&e); }
158155

156+
void removeOutgoingEdge(Edge &e) {
157+
this->outgoingEdges.erase(&e); }
158+
159159
private:
160160
SetOfEdgeRefs outgoingEdges;
161+
161162
};
162163

163164

164165
// A set of states
165166
// the set is assumed to have unique ownership of the states
166167
class SetOfStates : public std::map<CId, std::shared_ptr<State>> {
167168
public:
168-
using CIter = SetOfStates::const_iterator;
169-
using Iter = SetOfStates::iterator;
170169
void remove(const State &s) { this->erase(s.getId()); }
171170
virtual ~SetOfStates() = default;
171+
State& withId(const CId id) {
172+
return *this->at(id);
173+
}
172174
};
173175

174176
struct StateRefCompareLessThan {
@@ -180,6 +182,7 @@ class SetOfStateRefs : public std::set<const State *, StateRefCompareLessThan> {
180182
public:
181183
using CIter = SetOfEdgeRefs::const_iterator;
182184
bool includesState(const State *s) { return this->find(s) != this->end(); }
185+
virtual ~SetOfStateRefs() = default;
183186
};
184187

185188
// forward declaration of reachable states strategy
@@ -359,20 +362,25 @@ template <typename StateLabelType, typename EdgeLabelType> using StateRef = cons
359362
template <typename StateLabelType, typename EdgeLabelType>
360363
class SetOfStates : public Abstract::SetOfStates {
361364
private:
362-
std::map<StateLabelType, StateRef<StateLabelType, EdgeLabelType>> stateIndex;
365+
std::map<StateLabelType, std::shared_ptr<State<StateLabelType, EdgeLabelType>>> labelIndex;
366+
367+
void addToStateIndex(StateLabelType l, std::shared_ptr<State<StateLabelType, EdgeLabelType>> s) {
368+
this->labelIndex[l] = s;
369+
}
363370

364371
public:
365-
using CIter = typename SetOfStates::const_iterator;
366-
StateRef<StateLabelType, EdgeLabelType> withLabel(StateLabelType l) {
367-
if (this->stateIndex.find(l) != this->stateIndex.end()) {
368-
return this->stateIndex[l];
372+
373+
State<StateLabelType, EdgeLabelType>& withLabel(StateLabelType l) {
374+
if (this->labelIndex.find(l) != this->labelIndex.end()) {
375+
return *this->labelIndex[l];
369376
}
370-
return nullptr;
377+
throw CException("error - state not found in FiniteStateMachine::_withLabel");
371378
}
372379

373-
void addToStateIndex(StateLabelType l, State<StateLabelType, EdgeLabelType> *s) {
374-
this->stateIndex[l] = s;
380+
void addState(std::shared_ptr<State<StateLabelType, EdgeLabelType>>& s) {
381+
this->addToStateIndex(s->getLabel(), s);
375382
}
383+
376384
};
377385

378386
template <typename StateLabelType, typename EdgeLabelType>
@@ -430,27 +438,9 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
430438
SetOfStateRefs<StateLabelType, EdgeLabelType> initialStates;
431439
SetOfStateRefs<StateLabelType, EdgeLabelType> finalStates;
432440

433-
State<StateLabelType, EdgeLabelType> _getStateLabeled(const StateLabelType &s) {
441+
State<StateLabelType, EdgeLabelType>& _getStateLabeled(const StateLabelType &s) {
434442
// try the index first
435-
StateRef<StateLabelType, EdgeLabelType> sp = this->states.withLabel(s);
436-
if (sp != nullptr) {
437-
return sp;
438-
}
439-
// for now just a linear search
440-
// TODO: use index?
441-
for (auto &it : this->states) {
442-
auto &i = *(it.second);
443-
auto &t = dynamic_cast<const State<StateLabelType, EdgeLabelType> &>(i);
444-
// TODO: remove const_cast set of states provides only const iterator, but states are
445-
// identified only by ID.
446-
auto &ct = const_cast<State<StateLabelType, EdgeLabelType> &>(t);
447-
if ((ct.stateLabel) == s) {
448-
// TODO: manage index inside SetOfStates
449-
this->states.addToStateIndex(s, &ct);
450-
return &ct;
451-
}
452-
}
453-
throw CException("error - state not found in FiniteStateMachine::getStateLabeled");
443+
return this->states.withLabel(s);
454444
};
455445

456446

@@ -469,12 +459,11 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
469459
}
470460

471461
// add state with the given label
472-
State<StateLabelType, EdgeLabelType> *addState(StateLabelType label) {
462+
StateRef<StateLabelType, EdgeLabelType> addState(StateLabelType label) {
473463
bool added = false;
474-
typename SetOfStates<StateLabelType, EdgeLabelType>::CIter i;
475464
auto sp = std::make_shared<State<StateLabelType, EdgeLabelType>>(label);
476465
auto &s = *sp;
477-
this->states[sp->getId()] = std::move(sp);
466+
this->states.addState(sp);
478467
return &s;
479468
};
480469

@@ -483,9 +472,9 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
483472
const State<StateLabelType, EdgeLabelType> &dst) {
484473
// lookup state again to drop const qualifier
485474
auto &mySrc =
486-
dynamic_cast<State<StateLabelType, EdgeLabelType> &>(*(this->states[src.getId()]));
475+
dynamic_cast<State<StateLabelType, EdgeLabelType> &>(this->states.withId(src.getId()));
487476
auto &myDst =
488-
dynamic_cast<State<StateLabelType, EdgeLabelType> &>(*(this->states[dst.getId()]));
477+
dynamic_cast<State<StateLabelType, EdgeLabelType> &>(this->states.withId(dst.getId()));
489478
bool added = false;
490479
auto ep = std::make_shared<Edge<StateLabelType, EdgeLabelType>>(mySrc, lbl, myDst);
491480
auto &e = *ep;
@@ -497,8 +486,8 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
497486
void removeEdge(const Edge<StateLabelType, EdgeLabelType> &e) {
498487
auto csrc = dynamic_cast<StateRef<StateLabelType, EdgeLabelType>>(e.getSource());
499488
// get a non-const version of the state
500-
auto src = this->getStateLabeled(csrc->getLabel());
501-
src->removeOutgoingEdge(e);
489+
auto& src = this->_getStateLabeled(csrc->getLabel());
490+
src.removeOutgoingEdge(e);
502491
this->edges.remove(e);
503492
}
504493

@@ -532,7 +521,7 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
532521

533522
void addInitialState(const State<StateLabelType, EdgeLabelType> &s) {
534523
// we are assuming s is one of our states
535-
this->initialStates.emplace(s.getId(), &s);
524+
this->initialStates.insert(&s);
536525
};
537526

538527
void addFinalState(const State<StateLabelType, EdgeLabelType> &s) {
@@ -558,26 +547,7 @@ class FiniteStateMachine : public Abstract::FiniteStateMachine {
558547
};
559548

560549
const StateRef<StateLabelType, EdgeLabelType> getStateLabeled(const StateLabelType &s) {
561-
// try the index first
562-
StateRef<StateLabelType, EdgeLabelType> sp = this->states.withLabel(s);
563-
if (sp != nullptr) {
564-
return sp;
565-
}
566-
// for now just a linear search
567-
// TODO: use index?
568-
for (auto &it : this->states) {
569-
auto &i = *(it.second);
570-
auto &t = dynamic_cast<const State<StateLabelType, EdgeLabelType> &>(i);
571-
// TODO: remove const_cast set of states provides only const iterator, but states are
572-
// identified only by ID.
573-
auto &ct = const_cast<State<StateLabelType, EdgeLabelType> &>(t);
574-
if ((ct.stateLabel) == s) {
575-
// TODO: manage index inside SetOfStates
576-
this->states.addToStateIndex(s, &ct);
577-
return &ct;
578-
}
579-
}
580-
throw CException("error - state not found in FiniteStateMachine::getStateLabeled");
550+
return &(this->_getStateLabeled(s));
581551
};
582552

583553
bool hasStateLabeled(const StateLabelType &s) {

include/maxplus/graph/smpls.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ class SMPLSwithEvents : public SMPLS {
161161

162162
void determinizeUtil(const IOAState &s,
163163
IOASetOfStateRefs& visited,
164-
const IOASetOfStates &finalStates,
164+
const IOASetOfStateRefs &finalStates,
165165
CString &errMsg,
166166
std::ofstream &outfile);
167167

src/graph/smpls.cc

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,10 @@ namespace MaxPlus::SMPLS {
262262
if (!MP_IS_MINUS_INFINITY(d))
263263
{
264264
MPAStateRef src = mpa->getStateLabeled(makeMPAStateLabel(q1Id, static_cast<unsigned int>(col)));
265-
MPAState& dst = mpa->getStateLabeled(makeMPAStateLabel(q2Id, static_cast<unsigned int>(row)));
265+
MPAStateRef dst = mpa->getStateLabeled(makeMPAStateLabel(q2Id, static_cast<unsigned int>(row)));
266266
MPAEdgeLabel el = makeMPAEdgeLabel(d, sc);
267267
el.mode = CString(tr->getLabel());
268-
mpa->addEdge(src, el, dst);
268+
mpa->addEdge(*src, el, *dst);
269269
}
270270
}
271271
}
@@ -293,18 +293,18 @@ namespace MaxPlus::SMPLS {
293293

294294
std::ofstream outfile(file);
295295
outfile << "ioautomaton statespace{ \r\n";
296-
const auto& I = dynamic_cast<const IOASetOfStates&>(this->ioa->getInitialStates());
296+
const auto& I = this->ioa->getInitialStates();
297297

298-
const auto& finalStates =dynamic_cast<const IOASetOfStates&>(this->ioa->getFinalStates());
298+
const auto& finalStates = dynamic_cast<const IOASetOfStateRefs&>(this->ioa->getFinalStates());
299299

300300
CString errMsg = "";
301301
auto i = I.begin();
302-
auto& s = dynamic_cast<IOAState&>(*((*i).second));
302+
auto& s = dynamic_cast<const IOAState&>(*(*i));
303303
i++;
304304
//we remove the rest of the initial states since only one is allowed
305305
for (; i != I.end();)
306306
{
307-
ioa->removeState(dynamic_cast<const IOAState&>(*((*i).second)));
307+
ioa->removeState(dynamic_cast<const IOAState&>(*((*i))));
308308
}
309309
IOASetOfStateRefs visitedStates;
310310
determinizeUtil(s, visitedStates, finalStates, errMsg, outfile);
@@ -449,7 +449,7 @@ namespace MaxPlus::SMPLS {
449449
modeName += e->getLabel().first + "," + e->getLabel().second;
450450

451451
// add the edge with unique name between the corresponding states
452-
this->elsFSM.addEdge(this->elsFSM.getStateLabeled(s.getLabel()), modeName, this->elsFSM.getStateLabeled((dynamic_cast<const IOAState&>(e->getDestination())).getLabel()));
452+
this->elsFSM.addEdge(*this->elsFSM.getStateLabeled(s.getLabel()), modeName, *this->elsFSM.getStateLabeled((dynamic_cast<IOAStateRef>(e->getDestination()))->getLabel()));
453453

454454
// make a copy so that child node can not modify the parent nodes list of events
455455
// only adds and removes and passes it to it's children
@@ -586,7 +586,7 @@ namespace MaxPlus::SMPLS {
586586
// we need this later to make all matrices square
587587
biggestMatrixSize = std::max(biggestMatrixSize, std::max(sMatrix->getCols(), sMatrix->getRows()));
588588

589-
prepareMatrices(dynamic_cast<const IOAState&>(e->getDestination()), eList, visitedEdges);
589+
prepareMatrices(*dynamic_cast<IOAStateRef>(e->getDestination()), eList, visitedEdges);
590590
}
591591
}
592592

@@ -683,17 +683,17 @@ namespace MaxPlus::SMPLS {
683683
}
684684
}
685685
}
686-
const auto& s2 = dynamic_cast<const IOAState&>(e->getDestination());
686+
const auto s2 = dynamic_cast<IOAStateRef>(e->getDestination());
687687

688-
isConsistentUtil(s2, eList, finalStates, errMsg, visited);
688+
isConsistentUtil(*s2, eList, finalStates, errMsg, visited);
689689
}
690690
}
691691
}
692692

693693
/**
694694
* recursive part of determinize
695695
*/
696-
void SMPLSwithEvents::determinizeUtil(const IOAState& s, IOASetOfStateRefs& visited, const IOASetOfStates& finalStates, CString& errMsg, std::ofstream& outfile)
696+
void SMPLSwithEvents::determinizeUtil(const IOAState& s, IOASetOfStateRefs& visited, const IOASetOfStateRefs& finalStates, CString& errMsg, std::ofstream& outfile)
697697
{
698698
/**
699699
* Deterministic IOA is defined with:
@@ -715,8 +715,8 @@ namespace MaxPlus::SMPLS {
715715
InputAction input = e->getLabel().first;
716716
if (input.empty())
717717
{
718-
const auto& s2 = dynamic_cast<const IOAState&>(e->getDestination());
719-
outfile << s.stateLabel << "-," << e->getLabel().second << "->" << s2.stateLabel;
718+
const auto s2 = dynamic_cast<IOAStateRef>(e->getDestination());
719+
outfile << s.stateLabel << "-," << e->getLabel().second << "->" << s2->stateLabel;
720720
ioa->removeEdge(*e);
721721

722722
i++;
@@ -729,13 +729,13 @@ namespace MaxPlus::SMPLS {
729729
//ioa->removeState(dynamic_cast<IOAState*>(e->getDestination()));
730730
}
731731

732-
if (finalStates.count(s2.getId())>0)
732+
if (finalStates.count(s2)>0)
733733
{
734734
outfile << " f\n";
735735
}
736736
else {
737737
outfile << "\n";
738-
determinizeUtil(s2, visited, finalStates, errMsg, outfile);
738+
determinizeUtil(*s2, visited, finalStates, errMsg, outfile);
739739
}
740740
}
741741
else // we have an input action
@@ -758,17 +758,17 @@ namespace MaxPlus::SMPLS {
758758
// only allow edges with the outcome of the same event
759759
if (this->findEventByOutcome(input) == ev)
760760
{
761-
const auto& s2 = dynamic_cast<const IOAState&>(e->getDestination());
762-
outfile << s.stateLabel << "-" << e->getLabel().first << "," << e->getLabel().second << "->" << s2.stateLabel;
761+
const auto s2 = dynamic_cast<IOAStateRef>(e->getDestination());
762+
outfile << s.stateLabel << "-" << e->getLabel().first << "," << e->getLabel().second << "->" << s2->stateLabel;
763763

764764
ioa->removeEdge(*e);
765-
if (finalStates.count(s2.getId())>0)
765+
if (finalStates.count(s2)>0)
766766
{
767767
outfile << " f\n";
768768
}
769769
else {
770770
outfile << "\n";
771-
determinizeUtil(s2, visited, finalStates, errMsg, outfile);
771+
determinizeUtil(*s2, visited, finalStates, errMsg, outfile);
772772
}
773773
}
774774
else

0 commit comments

Comments
 (0)