Skip to content
102 changes: 102 additions & 0 deletions cli/medperf/comms/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ def __req(self, url, req_func, **kwargs):
"remember to provide the server certificate through --certificate"
)

def __get_count(self, url, filters={}, error_msg="") -> int:
filters = dict(filters)
filters.update({"is_valid": True, "limit": 1, "offset": 0})

query_str = "&".join([f"{k}={v}" for k, v in filters.items()])
paginated_url = f"{url}?{query_str}"

res = self.__auth_get(paginated_url)
if res.status_code != 200:
log_response_error(res)
details = format_errors_dict(res.json())
raise CommunicationRetrievalError(f"{error_msg}: {details}")

return res.json()["count"]

def __get_list(
self,
url,
Expand All @@ -95,6 +110,15 @@ def __get_list(
Returns:
List[dict]: A list of dictionaries representing the retrieved elements.
"""

filters = dict(filters)
if filters.get("limit", None) is not None:
page_size = filters["limit"]
num_elements = filters["limit"]

if filters.get("offset", None) is not None:
offset = filters["offset"]

el_list = []
filters.update({"is_valid": True})
if num_elements is None:
Expand Down Expand Up @@ -1143,3 +1167,81 @@ def get_dataset_benchmarks_associations(
url = f"{self.server_url}/datasets/{dataset_uid}/benchmarks/"
error_msg = "Could not get dataset benchmarks associations"
return self.__get_list(url, filters=filters, error_msg=error_msg)

def get_benchmarks_count(self, filters=dict(), is_owner=False) -> int:
"""Retrieves the count of benchmarks in the platform.

Returns:
int: count of all benchmarks.
"""
if is_owner:
url = f"{self.server_url}/me/benchmarks/"
else:
url = f"{self.server_url}/benchmarks/"
error_msg = "Could not retrieve benchmarks count"
return self.__get_count(url, filters=filters, error_msg=error_msg)

def get_cubes_count(self, filters=dict(), is_owner=False) -> int:
"""Retrieves the count of MLCubes in the platform.

Returns:
int: count of all MLCubes.
"""
if is_owner:
url = f"{self.server_url}/me/mlcubes/"
else:
url = f"{self.server_url}/mlcubes/"
error_msg = "Could not retrieve mlcubes count"
return self.__get_count(url, filters=filters, error_msg=error_msg)

def get_datasets_count(self, filters=dict(), is_owner=False) -> int:
"""Retrieves the count of datasets in the platform.

Returns:
int: count of all datasets.
"""
if is_owner:
url = f"{self.server_url}/me/datasets/"
else:
url = f"{self.server_url}/datasets/"
error_msg = "Could not retrieve datasets count"
return self.__get_count(url, filters=filters, error_msg=error_msg)

def get_models_count(self, filters=dict(), is_owner=False) -> int:
"""Retrieves the count of models in the platform.

Returns:
int: count of all models.
"""
if is_owner:
url = f"{self.server_url}/me/models/"
else:
url = f"{self.server_url}/models/"
error_msg = "Could not retrieve models count"
return self.__get_count(url, filters=filters, error_msg=error_msg)

def get_aggregators_count(self, filters=dict(), is_owner=False) -> int:
"""Retrieves the count of aggregators in the platform.

Returns:
int: count of all aggregators.
"""
if is_owner:
url = f"{self.server_url}/me/aggregators/"
else:
url = f"{self.server_url}/aggregators/"
error_msg = "Could not retrieve aggregators count"
return self.__get_count(url, filters=filters, error_msg=error_msg)

def get_experiments_count(self, filters=dict(), is_owner=False) -> int:
"""Retrieves the count of training experiments in the platform.

Returns:
int: count of all training experiments.
"""
if is_owner:
url = f"{self.server_url}/me/training/"
else:
url = f"{self.server_url}/training/"
error_msg = "Could not retrieve training experiments count"
return self.__get_count(url, filters=filters, error_msg=error_msg)
4 changes: 4 additions & 0 deletions cli/medperf/entities/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def get_metadata_filename():
def get_comms_uploader():
return config.comms.upload_aggregator

@staticmethod
def get_comms_counter():
return config.comms.get_aggregators_count

@handle_validation_error
def __init__(self, **kwargs):
self._model = AggregatorSchema(**kwargs)
Expand Down
4 changes: 4 additions & 0 deletions cli/medperf/entities/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def get_metadata_filename():
def get_comms_uploader():
return config.comms.upload_benchmark

@staticmethod
def get_comms_counter():
return config.comms.get_benchmarks_count

@handle_validation_error
def __init__(self, **kwargs):
"""Creates a new benchmark instance
Expand Down
4 changes: 4 additions & 0 deletions cli/medperf/entities/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def get_metadata_filename():
def get_comms_uploader():
return config.comms.upload_mlcube

@staticmethod
def get_comms_counter():
return config.comms.get_cubes_count

@handle_validation_error
def __init__(self, **kwargs):
"""Creates a Cube instance
Expand Down
4 changes: 4 additions & 0 deletions cli/medperf/entities/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def get_metadata_filename():
def get_comms_uploader():
return config.comms.upload_dataset

@staticmethod
def get_comms_counter():
return config.comms.get_datasets_count

@handle_validation_error
def __init__(self, **kwargs):
self._model = DatasetSchema(**kwargs)
Expand Down
18 changes: 18 additions & 0 deletions cli/medperf/entities/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from medperf.exceptions import MedperfException, InvalidArgumentError
from medperf.entities.schemas import MedperfSchema
from typing import Type, TypeVar
from medperf.account_management import get_medperf_user_data

EntityType = TypeVar("EntityType", bound="Entity")

Expand Down Expand Up @@ -53,6 +54,10 @@ def get_metadata_filename() -> str:
def get_comms_uploader() -> Callable[[dict], dict]:
raise NotImplementedError()

@staticmethod
def get_comms_counter() -> Callable[[dict, bool], int]:
raise NotImplementedError()

@property
def local_id(self) -> str:
raise NotImplementedError()
Expand Down Expand Up @@ -253,3 +258,16 @@ def display_dict(self) -> dict:
dict: the display dictionary
"""
raise NotImplementedError

@classmethod
def get_count(cls: Type[EntityType], filters: dict = {}) -> int:
"""Returns the count of items in the entity
Returns:
int: count of items
"""
logging.info(f"Retrieving the count of {cls.get_type()} entities")
user_data = get_medperf_user_data()
is_owner = "owner" in filters and filters["owner"] == user_data["id"]
comms_fn = cls.get_comms_counter()
count = comms_fn(filters=filters, is_owner=is_owner)
return count
4 changes: 4 additions & 0 deletions cli/medperf/entities/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def get_metadata_filename():
def get_comms_uploader():
return config.comms.upload_model

@staticmethod
def get_comms_counter():
return config.comms.get_models_count

@handle_validation_error
def __init__(self, **kwargs):
self._model = ModelSchema(**kwargs)
Expand Down
4 changes: 4 additions & 0 deletions cli/medperf/entities/training_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def get_metadata_filename():
def get_comms_uploader():
return config.comms.upload_training_exp

@staticmethod
def get_comms_counter():
return config.comms.get_experiments_count

@handle_validation_error
def __init__(self, **kwargs):
"""Creates a new training_exp instance
Expand Down
30 changes: 25 additions & 5 deletions cli/medperf/web_ui/aggregators/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
templates,
)
from medperf.enums import CryptoKeyType
from medperf.web_ui.utils import build_listing_filters, build_pagination_context

router = APIRouter()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,22 +89,41 @@ def register_aggregator(
def aggregators_ui(
request: Request,
mine_only: bool = False,
page: int = 1,
page_size: int = 9,
ordering: str = "created_at_desc",
current_user: bool = Depends(check_user_ui),
):
filters = {}
my_user_id = get_medperf_user_data()["id"]

if mine_only:
filters["owner"] = my_user_id

total_count = Aggregator.get_count(filters=filters)

filters.update(
build_listing_filters(page=page, page_size=page_size, ordering=ordering)
)

aggregators = Aggregator.all(filters=filters)
aggregators = sorted(aggregators, key=lambda x: x.created_at or "", reverse=True)
mine_aggs = [a for a in aggregators if a.owner == my_user_id]
other_aggs = [a for a in aggregators if a.owner != my_user_id]
aggregators = mine_aggs + other_aggs

pagination_context = build_pagination_context(
page=page,
page_size=page_size,
ordering=ordering,
total_count=total_count,
page_items_count=len(aggregators),
)

return templates.TemplateResponse(
"aggregators/aggregators.html",
{"request": request, "aggregators": aggregators, "mine_only": mine_only},
{
"request": request,
"aggregators": aggregators,
"mine_only": mine_only,
**pagination_context,
},
)


Expand Down
39 changes: 30 additions & 9 deletions cli/medperf/web_ui/benchmarks/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
sort_associations_display,
check_user_ui,
)
from medperf.web_ui.utils import get_container_type
from medperf.web_ui.utils import (
get_container_type,
build_listing_filters,
build_pagination_context,
)

from medperf.commands.association.approval import Approval
from medperf.enums import Status
Expand All @@ -38,25 +42,42 @@
def benchmarks_ui(
request: Request,
mine_only: bool = False,
page: int = 1,
page_size: int = 9,
ordering: str = "created_at_desc",
current_user: bool = Depends(check_user_ui),
):

filters = {}
my_user_id = get_medperf_user_data()["id"]

if mine_only:
filters["owner"] = my_user_id

benchmarks = Benchmark.all(
filters=filters,
total_count = Benchmark.get_count(filters=filters)

filters.update(
build_listing_filters(page=page, page_size=page_size, ordering=ordering)
)

benchmarks = Benchmark.all(filters=filters)

pagination_context = build_pagination_context(
page=page,
page_size=page_size,
ordering=ordering,
total_count=total_count,
page_items_count=len(benchmarks),
)

benchmarks = sorted(benchmarks, key=lambda x: x.created_at, reverse=True)
# sort by (mine recent) (mine oldish), (other recent), (other oldish)
mine_benchmarks = [d for d in benchmarks if d.owner == my_user_id]
other_benchmarks = [d for d in benchmarks if d.owner != my_user_id]
benchmarks = mine_benchmarks + other_benchmarks
return templates.TemplateResponse(
"benchmark/benchmarks.html",
{"request": request, "benchmarks": benchmarks, "mine_only": mine_only},
{
"request": request,
"benchmarks": benchmarks,
"mine_only": mine_only,
**pagination_context,
},
)


Expand Down
34 changes: 26 additions & 8 deletions cli/medperf/web_ui/containers/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
check_user_ui,
sanitize_redirect_url,
)
from medperf.web_ui.utils import build_listing_filters, build_pagination_context

router = APIRouter()
logger = logging.getLogger(__name__)
Expand All @@ -33,24 +34,41 @@
def containers_ui(
request: Request,
mine_only: bool = False,
page: int = 1,
page_size: int = 9,
ordering: str = "created_at_desc",
current_user: bool = Depends(check_user_ui),
):
filters = {}
my_user_id = get_medperf_user_data()["id"]

if mine_only:
filters["owner"] = my_user_id

containers = Cube.all(
filters=filters,
total_count = Cube.get_count(filters=filters)

filters.update(
build_listing_filters(page=page, page_size=page_size, ordering=ordering)
)

containers = Cube.all(filters=filters)

pagination_context = build_pagination_context(
page=page,
page_size=page_size,
ordering=ordering,
total_count=total_count,
page_items_count=len(containers),
)
containers = sorted(containers, key=lambda x: x.created_at, reverse=True)
# sort by (mine recent) (mine oldish), (other recent), (other oldish)
my_containers = [c for c in containers if c.owner == my_user_id]
other_containers = [c for c in containers if c.owner != my_user_id]
containers = my_containers + other_containers

return templates.TemplateResponse(
"container/containers.html",
{"request": request, "containers": containers, "mine_only": mine_only},
{
"request": request,
"containers": containers,
"mine_only": mine_only,
**pagination_context,
},
)


Expand Down
Loading
Loading