Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions netforge_rl/agents/green_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class GreenAgent:

def __init__(self, agent_id: str = 'green_agent_0'):
self.agent_id = agent_id
from netforge_rl.siem.event_templates import evid_4624, sysmon_3, evid_4688

self._benign_templates = [evid_4624, sysmon_3, evid_4688]

def generate_noise(self, current_tick: int, global_state: Any) -> Dict[str, Any]:
"""Generates random telemetry alerts based on the current tick's position
Expand Down Expand Up @@ -41,33 +44,31 @@ def generate_noise(self, current_tick: int, global_state: Any) -> Dict[str, Any]
source = random.choice(hosts)
target = random.choice(hosts)
if source.ip != target.ip:
template = random.choice(self._benign_templates)
log_string = template(source.ip, target.ip)
noise_logs.append(
{
'type': 'benign_traffic',
'source': source.ip,
'target': target.ip,
'protocol': random.choice(['TCP', 'UDP', 'HTTP', 'DNS']),
'type': 'benign_xml',
'data': log_string,
'subnet': source.subnet_cidr,
'severity': 0,
}
)

if random.random() < probability_of_false_positive:
# Generate a false positive anomaly that could trip Blue's SIEM
target = random.choice(hosts)
from netforge_rl.siem.event_templates import evid_4625

log_string = evid_4625(
'unknown_external', target.ip, username='Administrator'
)
noise_logs.append(
{
'type': 'anomaly',
'source': 'unknown_external',
'target': target.ip,
'signature': random.choice(
[
'Failed_Login_Spike',
'Malformed_Packet',
'Suspicious_User_Agent',
]
),
'severity': random.randint(1, 4),
'false_positive': True,
'type': 'anomaly_xml',
'data': log_string,
'subnet': target.subnet_cidr,
'severity': 3,
}
)

Expand Down
20 changes: 4 additions & 16 deletions netforge_rl/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

class ActionEffect:
"""Encapsulates the resulting state changes from an action for conflict

resolution.
"""

Expand All @@ -18,18 +17,18 @@ def __init__(
state_deltas: Union[Dict[str, Any], List['IStateDeltaCommand']],
observation_data: Dict[str, Any],
eta: int = 0,
action: Optional['BaseAction'] = None,
):
self.success = success
self.state_deltas = state_deltas
self.observation_data = observation_data
self.eta = eta
self.action = action
self.cost = getattr(action, 'cost', 0) if action else 0


class BaseAction(ABC):
"""Modular Base Action for the MARL CybORG Environment.

All highly specific network attacks (Layer 2 - Layer 7) inherit from this class.
"""
"""Modular Base Action for the MARL CybORG Environment."""

def __init__(
self,
Expand All @@ -52,24 +51,18 @@ def __init__(
self.required_prior_state = required_prior_state

def validate(self, global_state: 'GlobalNetworkState') -> bool:
"""Checks if the action is physically possible in the current network
state (e.g., is there a route, are preconditions met).
"""
if self.target_ip and self.target_ip not in global_state.all_hosts:
return False

if self.required_prior_state:
# Check Action History state logic
agent_history = global_state.action_history.get(self.agent_id, set())
expected_record = f'{self.required_prior_state}:{self.target_ip}'
if expected_record not in agent_history:
return False

if self.target_ip:
host = global_state.all_hosts[self.target_ip]
# Simple declarative Zone constraints example
if 'red' in self.agent_id.lower() and host.subnet_cidr == '10.0.1.0/24':
# Secure Data targets cannot be touched without pivoting via DMZ or Internal User privileges first
has_dmz = any(
h.privilege in ['User', 'Root']
for h in global_state.all_hosts.values()
Expand All @@ -87,9 +80,4 @@ def validate(self, global_state: 'GlobalNetworkState') -> bool:

@abstractmethod
def execute(self, global_state: 'GlobalNetworkState') -> ActionEffect:
"""Computes the theoretical effect of the action.

Note: State is NOT mutated directly here. Mutations are returned via ActionEffect
to allow the Environment to resolve simultaneous multi-agent collisions.
"""
pass
5 changes: 4 additions & 1 deletion netforge_rl/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def to_numpy(self, max_size: int = 256) -> np.ndarray:
vector[idx] = val
idx += 1

for ip, data in self.visible_hosts.items():
# Sort the visible hosts by IP to ensure deterministic tensor mapping for RL models
sorted_ips = sorted(list(self.visible_hosts.keys()))
for ip in sorted_ips:
data = self.visible_hosts[ip]
if idx + 2 >= max_size:
break

Expand Down
4 changes: 2 additions & 2 deletions netforge_rl/core/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def resolve(effects: Dict[str, ActionEffect]) -> Dict[str, ActionEffect]:
blue_defended_nodes = {}
for blue_id in blue_agents:
eff = effects[blue_id]
if eff.success:
if eff is not None and eff.success:
if isinstance(eff.state_deltas, dict):
for delta_key in eff.state_deltas.keys():
if 'hosts/' in delta_key:
Expand All @@ -35,7 +35,7 @@ def resolve(effects: Dict[str, ActionEffect]) -> Dict[str, ActionEffect]:
# 2. Evaluate Red attacks against the compiled simultaneous defenses
for red_id in red_agents:
red_eff = effects[red_id]
if not red_eff.success:
if red_eff is None or not red_eff.success:
continue

collision_detected = False
Expand Down
34 changes: 17 additions & 17 deletions netforge_rl/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
class ActionRegistry:
"""A Factory Registry for dynamically tracking and instantiating
BaseAction subclasses without monolithic if/else blocks.

Adheres strictly to the Open-Closed Principle.
"""

def __init__(self):
# Maps (team, action_group_id) -> ActionClass
# Primary team mappings
self._actions: Dict[str, Dict[int, Type]] = {
'red': {},
'red_commander': {},
Expand All @@ -32,20 +30,25 @@ def decorator(cls):
def get_action_class(self, agent_id: str, group_id: int) -> Optional[Type]:
"""Retrieves the class constructor for a specific integer offset."""
if 'red' in agent_id.lower():
team = 'red_commander' if 'commander' in agent_id.lower() else 'red'
primary_team = 'red_commander' if 'commander' in agent_id.lower() else 'red'
else:
team = 'blue_commander' if 'commander' in agent_id.lower() else 'blue'
primary_team = (
'blue_commander' if 'commander' in agent_id.lower() else 'blue'
)

# Attempt to find the action in the primary team registry
action_cls = self._actions.get(primary_team, {}).get(group_id)

return self._actions.get(team, {}).get(group_id)
# Fallback: Check if the action was registered specifically to the role (e.g., 'red_operator')
if not action_cls:
action_cls = self._actions.get(agent_id.lower(), {}).get(group_id)

return action_cls

def instantiate_action(
self, agent_id: str, action_data: object, target_ips: list
) -> Optional[object]:
"""Factory method to resolve the generic action payload to an instance.

Supports legacy integer decoding or advanced Hierarchical MultiDiscrete
arrays: [action_type_id, target_ip_index].
"""
"""Factory method to resolve the generic action payload to an instance."""
if not target_ips:
target_ips = ['127.0.0.1']

Expand All @@ -64,30 +67,27 @@ def instantiate_action(
action_group = action_int // len(target_ips)

if 'red' in agent_id.lower():
mod = 4 if 'commander' in agent_id.lower() else 11
mod = 12 # Standardized bounds
else:
mod = 5 if 'commander' in agent_id.lower() else 7
mod = 12

action_type_id = action_group % mod

ActionCls = self.get_action_class(agent_id, action_type_id)
if not ActionCls:
return None

# Pass required kwargs dynamically based on the action archetype
# Determine accepted arguments dynamically
# Pass required kwargs dynamically
sig = inspect.signature(ActionCls.__init__)
params = sig.parameters

kwargs = {'agent_id': agent_id}
if 'target_ip' in params:
kwargs['target_ip'] = target_ip
elif 'target_subnet' in params:
# Approximate subnet from target_ip for actions requiring Subnets
parts = target_ip.split('.')
kwargs['target_subnet'] = f'{parts[0]}.{parts[1]}.{parts[2]}.0/24'
elif 'target_agent_id' in params:
# Map target_agent_id randomly or conventionally for Coordination actions
kwargs['target_agent_id'] = (
'red_operator' if agent_id == 'red_commander' else 'red_commander'
)
Expand Down
22 changes: 22 additions & 0 deletions netforge_rl/core/state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from typing import Dict, Set, Any


Expand Down Expand Up @@ -208,6 +209,27 @@ def can_route_to(

return False

def get_adjacency_matrix(self) -> np.ndarray:
"""Returns a 100x100 adjacency matrix representing routing capabilities between all hosts."""
import numpy as np

adj = np.zeros((100, 100), dtype=np.float32)

# We need a stable ordering of IPs, so we sort them
sorted_ips = sorted(list(self.all_hosts.keys()))

for i, src_ip in enumerate(sorted_ips):
for j, dst_ip in enumerate(sorted_ips):
if i == j:
adj[i, j] = 1.0
elif self.can_route_to(dst_ip):
# Simplification: if it can route to dst, we mark an edge.
# A more accurate version would check if src_ip can route to dst_ip,
# but can_route_to doesn't take src_ip. It assumes global routing rules based on subnets.
adj[i, j] = 1.0

return adj

def reallocate_dhcp(self):
"""Simulates dynamic mid-episode restructuring of the network.

Expand Down
Loading
Loading