Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/github-actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish_documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
persist-credentials: false

- name: Set up Python 3.11
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.11

Expand Down
21 changes: 21 additions & 0 deletions mapc_sim/experimental/antenas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import chex
import jax.numpy as jnp


def isotropic(pos: chex.Array) -> chex.Array:
"""Calculate gain matrix for each point to point communication.

Parameters
----------
pos: Array
An array of 2d positions for `n` points

Returns
-------
Array
Antena gain for each point in pos.

"""
npoints = pos.shape[0]

return jnp.zeros((npoints, npoints))
20 changes: 20 additions & 0 deletions mapc_sim/experimental/walls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import chex
import jax.numpy as jnp


def free_space(pos: chex.Array) -> chex.Array:
"""Calculate free space loss

Parameters
----------
pos: Array
An array of 2d positions for `n` points

Returns
-------
Transmission loss for every pair of points in `pos`.

"""
npoints = pos.shape[0]

return jnp.zeros((npoints, npoints))
10 changes: 5 additions & 5 deletions mapc_sim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def network_data_rate(key: PRNGKey, tx: Array, pos: Array, mcs: Array, tx_power: Array, sigma: Scalar,
walls: Array, return_sample: bool = False) -> Union[Scalar,tuple]:
loss_gain: Array, return_sample: bool = False) -> Union[Scalar,tuple]:
r"""
Calculates the aggregated effective data rate based on the nodes' positions, MCS, and tx power.
Channel is modeled using TGax channel model with additive white Gaussian noise. Effective
Expand Down Expand Up @@ -41,9 +41,9 @@ def network_data_rate(key: PRNGKey, tx: Array, pos: Array, mcs: Array, tx_power:
Transmission power of the nodes. Each entry corresponds to the transmission power of the transmitting node.
sigma: Scalar
Standard deviation of the additive white Gaussian noise.
walls: Array
Adjacency matrix of walls. If node i is separated from node j by a wall,
then `walls[i, j] = 1`, otherwise `walls[i, j] = 0`.
loss_gain: Array
Adjacency matrix of combined loss and gain. If node `i` is separated from node `j` by a wall or the antena is directional,
then `loss_gain[i, j] = val`, otherwise `walls[i, j] = 0`.
return_sample: bool
A flag indicating whether the simulator returns raw number of transmitted frames.

Expand All @@ -59,7 +59,7 @@ def network_data_rate(key: PRNGKey, tx: Array, pos: Array, mcs: Array, tx_power:
distance = jnp.sqrt(jnp.sum((pos[:, None, :] - pos[None, ...]) ** 2, axis=-1))
distance = jnp.clip(distance, REFERENCE_DISTANCE, None)

signal_power = tx_power[:, None] - tgax_path_loss(distance, walls)
signal_power = tx_power[:, None] - tgax_path_loss(distance, loss_gain)

interference_matrix = jnp.ones_like(tx) * tx.sum(axis=0) * tx.sum(axis=1, keepdims=True) * (1 - tx)
a = jnp.concatenate([signal_power, jnp.full((1, signal_power.shape[1]), fill_value=NOISE_FLOOR)], axis=0)
Expand Down
7 changes: 3 additions & 4 deletions mapc_sim/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import jax
import jax.numpy as jnp
from chex import Array

from mapc_sim.constants import *


def tgax_path_loss(distance: Array, walls: Array) -> Array:
def tgax_path_loss(distance: Array, loss_gain: Array) -> Array:
r"""
Calculates the path loss according to the TGax channel model [1]_.

Parameters
----------
distance: Array
Distance between nodes
walls: Array
loss_gain: Array
Adjacency matrix describing walls between nodes (1 if there is a wall, 0 otherwise).

Returns
Expand All @@ -27,7 +26,7 @@ def tgax_path_loss(distance: Array, walls: Array) -> Array:
"""

return (40.05 + 20 * jnp.log10((jnp.minimum(distance, BREAKING_POINT) * CENTRAL_FREQUENCY) / 2.4) +
(distance > BREAKING_POINT) * 35 * jnp.log10(distance / BREAKING_POINT) + WALL_LOSS * walls)
(distance > BREAKING_POINT) * 35 * jnp.log10(distance / BREAKING_POINT) + loss_gain)


def logsumexp_db(a: Array, b: Array) -> Array:
Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ authors = [

requires-python = ">=3.9"
dependencies = [
"chex~=0.1.85",
"jax~=0.4.23",
"jaxlib~=0.4.23",
"matplotlib~=3.8.2",
"tensorflow-probability[jax]~=0.23.0"
"chex~=0.1.86",
"jax~=0.4.30",
"jaxlib~=0.4.30",
"matplotlib~=3.9.1",
"tensorflow-probability[jax]~=0.24.0"
]

[build-system]
Expand Down