diff --git a/p2p/network.go b/p2p/network.go index 1409243..cd62498 100644 --- a/p2p/network.go +++ b/p2p/network.go @@ -85,13 +85,13 @@ func (n *Network) PeerIDs() []PeerID { } // Publish sends a message to the specified node's message queue. -func (n *Network) Publish(nodeID PeerID, msg string, protocol BroadcastProtocol) error { +func (n *Network) Publish(nodeID PeerID, msg string, protocol BroadcastProtocol, customProtocol func(msg Message, known []PeerID, sent []PeerID, received []PeerID, params map[string]any) *[]PeerID) error { if node, ok := n.nodes[nodeID]; ok { if !node.alive { return fmt.Errorf("node %d is not alive", nodeID) } - node.msgQueue <- Message{From: nodeID, Content: msg, Protocol: protocol, HopCount: 0} + node.msgQueue <- Message{From: nodeID, Content: msg, Protocol: protocol, HopCount: 0, CustomProtocol: customProtocol} return nil } diff --git a/p2p/node.go b/p2p/node.go index 6ad51c9..e311c5c 100644 --- a/p2p/node.go +++ b/p2p/node.go @@ -51,12 +51,12 @@ func (n *p2pNode) eachRun(network *Network, wg *sync.WaitGroup, ctx context.Cont n.alive = true wg.Done() - for msg := range n.msgQueue { - select { - case <-ctx.Done(): - n.alive = false - return - default: + select { + case <-ctx.Done(): + n.alive = false + return + default: + for msg := range n.msgQueue { first := false n.mu.Lock() @@ -117,7 +117,6 @@ func (n *p2pNode) publish(network *Network, msg Message) { if _, received := n.recvFrom[content][edge.targetID]; received { continue } - n.sentTo[content][edge.targetID] = struct{}{} willSendEdges = append(willSendEdges, edge) } @@ -128,25 +127,52 @@ func (n *p2pNode) publish(network *Network, msg Message) { }) k := int(float64(len(willSendEdges)) * network.cfg.GossipFactor) + willSendEdges = willSendEdges[:k] } } else if protocol == Custom { + allEdges := make([]PeerID, 0) + for _, edge := range n.edges { + allEdges = append(allEdges, edge.targetID) + } + sentEdges := make([]PeerID, 0) + for targetID := range n.sentTo[content] { + sentEdges = append(sentEdges, targetID) + } + + receivedEdges := make([]PeerID, 0) + for senderID := range n.recvFrom[content] { + receivedEdges = append(receivedEdges, senderID) + } + + targets := msg.CustomProtocol(msg, allEdges, sentEdges, receivedEdges, network.cfg.CustomParams) + + for _, targetID := range *targets { + for _, edge := range n.edges { + if edge.targetID == targetID { + willSendEdges = append(willSendEdges, edge) + break + } + } + } } else { return } for _, edge := range willSendEdges { edgeCopy := edge + n.sentTo[content][edge.targetID] = struct{}{} go func(e p2pEdge) { time.Sleep(time.Duration(e.edgeLatency) * time.Millisecond) network.nodes[e.targetID].msgQueue <- Message{ - From: n.id, - Content: content, - Protocol: protocol, - HopCount: hopCount + 1, + From: n.id, + Content: content, + Protocol: protocol, + HopCount: hopCount + 1, + CustomProtocol: msg.CustomProtocol, } }(edgeCopy) } diff --git a/p2p/p2p.go b/p2p/p2p.go index 55db4af..0402fe1 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -6,13 +6,15 @@ type PeerID uint64 // Message represents a message sent between nodes in the P2P network. type Message struct { - From PeerID - Content string - Protocol BroadcastProtocol - HopCount int + From PeerID + Content string + Protocol BroadcastProtocol + HopCount int + CustomProtocol func(Message, []PeerID, []PeerID, []PeerID, map[string]any) *[]PeerID } // Config holds configuration parameters for the P2P network. type Config struct { - GossipFactor float64 // fraction of neighbors to gossip to + GossipFactor float64 // fraction of neighbors to gossip to + CustomParams map[string]any // parameters for custom protocols } diff --git a/p2p/p2p_test.go b/p2p/p2p_test.go index 328a3fb..004f484 100644 --- a/p2p/p2p_test.go +++ b/p2p/p2p_test.go @@ -37,7 +37,7 @@ func TestGenerateNetwork(t *testing.T) { nw.RunNetworkSimulation(ctx) t.Logf("Publishing message '%s' from node %d\n", msg1, nw.PeerIDs()[0]) - err = nw.Publish(nw.PeerIDs()[0], msg1, p2p.Flooding) + err = nw.Publish(nw.PeerIDs()[0], msg1, p2p.Flooding, nil) if err != nil { t.Fatalf("Failed to publish message: %v", err) } @@ -45,7 +45,7 @@ func TestGenerateNetwork(t *testing.T) { t.Logf("Reachability of message '%s': %f\n", msg1, nw.Reachability(msg1)) t.Logf("Publishing message '%s' from node %d\n", msg2, nw.PeerIDs()[1]) - err = nw.Publish(nw.PeerIDs()[1], msg2, p2p.Gossiping) + err = nw.Publish(nw.PeerIDs()[1], msg2, p2p.Gossiping, nil) if err != nil { t.Fatalf("Failed to publish message: %v", err) } @@ -56,7 +56,7 @@ func TestGenerateNetwork(t *testing.T) { nw.RunNetworkSimulation(context.Background()) t.Logf("Publishing message '%s' from node %d\n", msg3, nw.PeerIDs()[2]) - err = nw.Publish(nw.PeerIDs()[2], msg3, p2p.Gossiping) + err = nw.Publish(nw.PeerIDs()[2], msg3, p2p.Gossiping, nil) if err != nil { t.Fatalf("Failed to publish message: %v", err) } @@ -104,7 +104,7 @@ func TestMetrics(t *testing.T) { nw.RunNetworkSimulation(ctx) t.Logf("Publishing message '%s' from node %d\n", msg1, nw.PeerIDs()[0]) - err = nw.Publish(nw.PeerIDs()[0], msg1, p2p.Flooding) + err = nw.Publish(nw.PeerIDs()[0], msg1, p2p.Flooding, nil) if err != nil { t.Fatalf("Failed to publish message: %v", err) }