diff --git a/pyproject.toml b/pyproject.toml index fef8f7e1..b990a358 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ rich = "^12.6.0" dataclasses-json = "^0.5.6" pandas = "^1.5.1" appdirs = "^1.4.4" +deepdiff = "^6.3.1" [tool.poetry.dev-dependencies] pytest = "^6.2.5" diff --git a/src/aerie_cli/aerie_client.py b/src/aerie_cli/aerie_client.py index b78ac281..59ec29e3 100644 --- a/src/aerie_cli/aerie_client.py +++ b/src/aerie_cli/aerie_client.py @@ -20,6 +20,7 @@ from .schemas.client import ExpansionSet from .schemas.client import ResourceType from .utils.serialization import postgres_duration_to_microseconds +from .utils.serialization import timedelta_to_postgres_interval from .aerie_host import AerieHostSession @@ -78,6 +79,64 @@ def get_activity_plan_by_id(self, plan_id: int, full_args: str = None) -> Activi plan = ActivityPlanRead.from_api_read(api_plan) return self.__expand_activity_arguments(plan, full_args) + def get_activity_plan_subset_by_id( + self, plan_id: int, + start_time: arrow.Arrow, + end_time: arrow.Arrow, + full_args: str = None) -> ActivityPlanRead: + """Download a subset of an activity plan from Aerie + + Args: + plan_id (int): ID of the plan in Aerie + full_args (str): comma separated list of activity types for which to + get full arguments, otherwise only modified arguments are returned. + Set to "true" to get full arguments for all activity types. + Disabled if missing, None, "false", or "". + start_time (arrow.Arrow) + end_time (arrow.Arrow) + + Returns: + ActivityPlanRead: the activity plan subset + """ + + parent = self.get_activity_plan_by_id(plan_id) + parent_start = parent.start_time + start_offset = timedelta_to_postgres_interval(start_time - parent_start) + end_offset = timedelta_to_postgres_interval(end_time - parent_start) + + query = """ + query get_plans ($start_offset: interval!, $end_offset: interval!, $plan_id: Int!) { + plan_by_pk(id: $plan_id) { + id + model_id + name + start_time + duration + simulations{ + id + } + activity_directives(where: {start_offset: {_gte: $start_offset, _lt: $end_offset}}, order_by: { start_offset: asc }) { + id + name + type + start_offset + arguments + metadata + tags + } + } + } + """ + resp = self.host_session.post_to_graphql( + query, + plan_id=plan_id, + start_offset=start_offset, + end_offset=end_offset + ) + plan = ApiActivityPlanRead.from_dict(resp) + plan = ActivityPlanRead.from_api_read(plan) + return self.__expand_activity_arguments(plan, full_args) + def list_all_activity_plans(self) -> List[ActivityPlanRead]: list_all_plans_query = """ query list_all_plans { @@ -278,6 +337,21 @@ def update_activity( ) return resp["id"] + def delete_activity(self, activity_id: int, plan_id: int): + query = """ + mutation DeleteActivityDirective($id: Int!, $plan_id: Int!) { + delete_activity_directive_by_pk(id: $id, plan_id: $plan_id) { + name + } + } + """ + resp = self.host_session.post_to_graphql( + query, + id=activity_id, + plan_id=plan_id + ) + return resp + def simulate_plan(self, plan_id: int, poll_period: int = 5) -> int: simulate_query = """ @@ -1658,7 +1732,7 @@ def add_directive_metadata_schemas(self, schemas: list) -> list: ) return resp - def delete_directive_metadata_schema(self, key) -> list: + def delete_directive_metadata_schema(self, key: str) -> list: """Delete metadata schemas Returns: @@ -1676,4 +1750,108 @@ def delete_directive_metadata_schema(self, key) -> list: delete_schema_query, key=key ) - return resp["key"] \ No newline at end of file + return resp["key"] + + def get_activity_directive_preset(self, activity_id: int, plan_id: int) -> dict: + """Get the preset applied on a activity directive + + Returns: + dict: an object representing the activity directive. + dict["applied_preset"] == None if there is no preset applied to the activity + """ + query = """ + query MyQuery($id: Int!, $plan_id: Int!) { + activity_directive_by_pk(id: $id, plan_id: $plan_id) { + applied_preset { + presets_applied { + name + } + preset_id + } + } + } + """ + resp = self.host_session.post_to_graphql( + query, + id=activity_id, + plan_id=plan_id + ) + return resp + + def apply_activity_directive_preset(self, activity_id, plan_id, preset_id) -> int: + """Apply the a preset to an activity directive + + Returns: + int: the ID of the applied preset + """ + mutation = """ + mutation MyMutation($_activity_id: Int!, $_plan_id: Int!, $_preset_id: Int!) { + apply_preset_to_activity(args: {_activity_id: $_activity_id, _plan_id: $_plan_id, _preset_id: $_preset_id}) { + id + } + } + """ + resp = self.host_session.post_to_graphql( + mutation, + _activity_id=activity_id, + _plan_id=plan_id, + _preset_id=preset_id + ) + return resp["id"] + + def delete_activity_directive_preset(self, activity_id, plan_id) -> dict: + """Remove a preset from an activity directive + + Returns: + dict: an object representing the preset that was deleted + """ + mutation = """ + mutation MyMutation($_activity_id: Int!, $_plan_id: Int!) { + delete_preset_to_directive(where: {activity_id: {_eq: $_activity_id}, _and: {plan_id: {_eq: $_plan_id}}}) { + returning { + preset_id + } + } + } + """ + resp = self.host_session.post_to_graphql( + mutation, + _activity_id=activity_id, + _plan_id=plan_id + ) + return resp + + # TODO: the two functions below are very specific to the plans merge command + # do we still want them here or is there another way to integrate them into the cli? + + def get_plan_created_date(self, plan_id: int): + # for some reason, plan doesnt have a `created_at` field on local, but its there + # in the dev venue. i need to look at the activity `created_at` for now..... + plan_query = """ + query MyQuery($plan_id: Int!) { + activity_directive(where: {plan_id: {_eq: $plan_id}}) { + created_at + } + } + """ + resp = self.host_session.post_to_graphql( + plan_query, + plan_id=plan_id, + ) + return arrow.get(resp[0]["created_at"]) if len(resp) > 0 else None + + def get_plan_recently_updated_activities(self, plan_id: int, time: arrow.Arrow): + activity_query = """ + query GetRecentlyUpdatedActivities($plan_id: Int!, $time: timestamptz!) { + activity_directive(where: {_and: {plan_id: {_eq: $plan_id}}, last_modified_at: {_gt: $time}}) { + id + name + } + } + """ + resp = self.host_session.post_to_graphql( + activity_query, + plan_id=plan_id, + time=str(time) + ) + return resp \ No newline at end of file diff --git a/src/aerie_cli/commands/plans.py b/src/aerie_cli/commands/plans.py index 1d5db0c3..88b6c6f8 100644 --- a/src/aerie_cli/commands/plans.py +++ b/src/aerie_cli/commands/plans.py @@ -6,6 +6,7 @@ import typer from rich.console import Console from rich.table import Table +from deepdiff import DeepDiff from aerie_cli.commands.command_context import CommandContext from aerie_cli.schemas.client import ActivityPlanCreate @@ -315,3 +316,178 @@ def clean(): typer.echo(f"All activity plans have been deleted") + +@app.command() +def subset( + plan_id: int = typer.Option(..., "--plan-id", "-p", help="Plan ID to subset from", prompt=True), + name: str = typer.Option(..., "--name", "-n", help="Name of the child plan", prompt=True), + start_time: str = typer.Option(..., "--start", "-s", help="Start time of child plan", prompt=True), + end_time: str = typer.Option(..., "--end", "-e", help="End time of child plan", prompt=True) +): + """ + Branch off a part of a plan, given start and end times. + + Times are given in the format YYYY-DDDTHH:mm:ss.SSS and must be within the start and + end times of the parent plan. + + This command will also add `parent_activity_id` and `parent_plan_id` metadata to the + child activity directives so they can reference back to the parent plan. To view these in the + UI, please add the following directive metadata schema to the Aerie instance. + + parent_activity_id: INTEGER + parent_plan_id: INTEGER + """ + client = CommandContext.get_client() + + # get parent plan info + parent_id = plan_id + start_time = arrow.get(start_time) + end_time = arrow.get(end_time) + parent = client.get_activity_plan_subset_by_id( + parent_id, + start_time, + end_time + ) + parent_name = parent.name + + # modify child plan + model_id = parent.model_id + parent.name = name + parent.start_time = start_time + parent.end_time = end_time + + # add metadata + for a in parent.activities: + a.metadata["parent_activity_id"] = a.id + a.metadata["parent_plan_id"] = parent_id + + # create child plan + parent = ActivityPlanCreate.from_plan_read(parent) + child_id = client.create_activity_plan(model_id, parent) + child_data = client.get_activity_plan_by_id(child_id) + + # apply presets + for activity in child_data.activities: + preset = client.get_activity_directive_preset(activity.metadata["parent_activity_id"], parent_id) + if preset["applied_preset"]: + client.apply_activity_directive_preset(activity.id, child_id, preset["applied_preset"]["preset_id"]) + + typer.echo(f"Created branch of `{parent_name}` (id {parent_id}) with name `{child_data.name}` (id {child_id}).") + + +@app.command() +def merge( + child_id: int = typer.Option(..., "--child-id", "-c", help="Plan ID of child", prompt=True), + parent_id: int = typer.Option(..., "--parent-id", "-p", help="Plan ID of parent to merge into", prompt=True), +): + """ + Merge a child plan into a parent plan. Supports adding, updating, and deleting + activities. + + This command best works with a child created from the `subset` command merging back + into a parent, since the child contains metadata about which parent activity it + came from. All activities in the child plan must fall within the time range of + the parent plan. + + This command also has limited conflict checking by seeing if there were any changes + made to the parent after the child branch was created. + """ + client = CommandContext.get_client() + + # get and store data + child_data = client.get_activity_plan_by_id(child_id) + parent_data = client.get_activity_plan_subset_by_id(parent_id, child_data.start_time, child_data.end_time) + + child_activity_ids = [a.metadata["parent_activity_id"] for a in child_data.activities if "parent_activity_id" in a.metadata] + parent_activity_ids = [a.id for a in parent_data.activities] + + child_creation_time = client.get_plan_created_date(child_id) + + # check for possible conflicts + if child_creation_time: + conflicts = client.get_plan_recently_updated_activities(parent_id, child_creation_time) + if len(conflicts) > 0: + typer.echo("Warning: potential conflicts detected in the following activities") + for conflict in conflicts: + typer.echo(f"Activity name: {conflict['name']} (id {conflict['id']})") + + typer.echo("If the merge continues, activities in the parent plan will be overwritten, and new activities in the parent plan will not be deleted.") + proceed = typer.confirm("Are you sure you would like to continue merging?") + if not proceed: + typer.echo("Aborting merge") + raise typer.Abort() + typer.echo("Continuing merge") + else: + typer.echo("Warning: the child has no activities, which will delete all activities in the parent for this time frame.") + proceed = typer.confirm("Are you sure you would like to continue merging?") + if not proceed: + typer.echo("Aborting merge") + raise typer.Abort() + typer.echo("Continuing merge") + + # add/update activities in parent plan + for activity in child_data.activities: + + # child activity was not branched from parent + if ("parent_plan_id" in activity.metadata and + activity.metadata["parent_plan_id"] != parent_id): + typer.echo("Warning: the plan you are trying to merge into does not match the plan that this child was pulled from.") + proceed = typer.confirm("Are you sure you would like to continue merging?") + if not proceed: + typer.echo("Aborting merge") + raise typer.Abort() + typer.echo("Continuing merge") + + # add new child activity + if "parent_activity_id" not in activity.metadata: + activity_id = client.create_activity(activity, parent_id, parent_data.start_time) + activity.metadata["parent_activity_id"] = activity_id + activity.metadata["parent_plan_id"] = parent_id + client.update_activity(activity.id, activity, child_id, child_data.start_time) + typer.echo(f"Added activity {activity.name} (id {activity.id} in child) to parent.") + + # or, update existing child activity + else: + # grab the metadata + parent_activity_id = activity.metadata.pop("parent_activity_id") + parent_plan_id = activity.metadata.pop("parent_plan_id") + + # find the corresponding parent activity + if parent_activity_id in parent_activity_ids: + parent_activity = [a for a in parent_data.activities if a.id == parent_activity_id][0] + difference = DeepDiff(parent_activity, activity) + + # only update the activity if there are any changed values + if ("values_changed" in difference and + (len(difference["values_changed"]) > 1 or len(difference) > 1)): # the activity IDs will be different, so we wanna detect any diff besides that + client.update_activity(parent_activity_id, activity, parent_data.id, parent_data.start_time) + typer.echo(f"Updated activity {activity.name} (id {activity.id}) in parent plan.") + typer.echo(f"Difference: \n{difference.pretty()}") + + # restore metadata in child activity + activity.metadata["parent_activity_id"] = parent_activity_id + activity.metadata["parent_plan_id"] = parent_plan_id + + # case where child activity cannot find its parent activity in the plan anymore + else: + typer.echo(f"Warning: activity (id {parent_activity_id}) was deleted in parent after branching. Adding activity back to parent.") + activity_id = client.create_activity(activity, parent_id, parent_data.start_time) + activity.metadata["parent_activity_id"] = activity_id + activity.metadata["parent_plan_id"] = parent_id + client.update_activity(activity.id, activity, child_id, child_data.start_time) + typer.echo(f"Added activity {activity.name} (id {activity.id} in child) to parent.") + + # apply preset, if there are any + preset = client.get_activity_directive_preset(activity.id, child_id) + if preset["applied_preset"]: + client.apply_activity_directive_preset(activity.metadata["parent_activity_id"], parent_id, preset["applied_preset"]["preset_id"]) + else: + client.delete_activity_directive_preset(activity.metadata["parent_activity_id"], parent_id) + + # delete deleted activities from parent + deleted_activities = set(parent_activity_ids) - set(child_activity_ids) + for activity_id in deleted_activities: + activity_name = client.delete_activity(activity_id, parent_id) + typer.echo(f"Deleted activity {activity_name} (id {activity_id}) in parent.") + + typer.echo(f"Finished merging plan {child_data.name} (id {child_id}) into {parent_data.name} (id {parent_id}).")