Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion openfl/experimental/workflow/notebooktools/code_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import ast
import inspect
import re
import shutil
import sys
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import nbformat
from nbdev.export import nb_export
Expand Down Expand Up @@ -45,6 +46,8 @@ def __init__(self, notebook_path: Path, output_path: Path) -> None:
)
).resolve()
self.requirements = self._get_requirements()
user_imports = self.__extract_user_defined_imports(notebook_path)
self.__copy_user_defined_modules(user_imports, notebook_path)
self.__modify_experiment_script()

def __get_exp_name(self, notebook_path: Path) -> str:
Expand Down Expand Up @@ -86,6 +89,58 @@ def __convert_to_python(self, notebook_path: Path, output_path: Path, export_fil

return Path(output_path).joinpath(export_filename).resolve()

def __extract_user_defined_imports(self, notebook_path: Path) -> List[str]:
    """Extract user-defined module imports from the exported notebook script.

    Reads the script at ``self.script_path``, drops lines that start with
    ``!`` or ``%`` (they are not valid Python), parses the remainder and
    collects the top-level name of every absolute import. Each unique name
    is then kept only if ``self._is_user_defined_module`` accepts it.
    Relative imports (``from . import x``, ``from .pkg import y``) are
    ignored.

    Args:
        notebook_path (Path): Path to the Jupyter notebook.

    Returns:
        List[str]: Sorted list of unique user-defined module names used
            in the script.
    """
    # Explicit encoding: Python source defaults to UTF-8, so decoding must not
    # depend on the platform's locale encoding.
    with open(self.script_path, "r", encoding="utf-8") as file:
        code = "".join(line for line in file if not line.lstrip().startswith(("!", "%")))

    # Gather candidate top-level names first so the user-defined check runs
    # only once per unique name.
    candidates = set()
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            candidates.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            candidates.add(node.module.split(".")[0])

    # Sorted so the result is deterministic across runs.
    return sorted(
        name for name in candidates if self._is_user_defined_module(name, notebook_path)
    )

def __copy_user_defined_modules(self, module_names: List[str], notebook_path: Path) -> None:
    """Copy user-defined modules/packages next to the generated script.

    For each name, the candidate ``<name>.py`` file and ``<name>``
    directory are obtained from ``self._get_module_paths``. A single-file
    module is copied into the directory containing ``self.script_path``;
    otherwise a package directory is copied recursively under the same
    directory. Any failure for a name is reported as a warning and the
    remaining names are still processed.

    Args:
        module_names (List[str]): A list of user-defined module names.
        notebook_path (Path): Path to Jupyter notebook.
    """
    destination = self.script_path.parent
    for name in module_names:
        try:
            file_candidate, dir_candidate = self._get_module_paths(name, notebook_path)

            # A single-file module takes precedence over a same-named directory.
            if file_candidate.exists() and file_candidate.is_file():
                shutil.copy(file_candidate, destination)
                print(f"Copied user-defined module: {name}.py")
                continue

            if dir_candidate.exists() and dir_candidate.is_dir():
                shutil.copytree(dir_candidate, destination / name, dirs_exist_ok=True)
                print(f"Copied user-defined directory: {name}/")
        except Exception as err:
            # Best-effort: report the failure and move on to the next module.
            print(f"[WARNING] Failed to copy '{name}':{err}")

def __modify_experiment_script(self) -> None:
"""Modifies the given python script by commenting out following code:
- occurences of flflow.run()
Expand Down Expand Up @@ -294,6 +349,42 @@ def _clean_value(self, value: str) -> str:
value = value.lstrip("[").rstrip("]")
return value

def _is_user_defined_module(self, module_name: str, notebook_path: Path) -> bool:
    """Determine whether a given module is user-defined.

    A module counts as user-defined when the candidate paths returned by
    ``self._get_module_paths`` contain either a regular ``<name>.py`` file
    or a ``<name>`` directory.

    Args:
        module_name (str): Name of the module.
        notebook_path (Path): Path to Jupyter notebook using the module.

    Returns:
        bool: True if the module is user-defined, False otherwise.
    """
    # Reject empty or non-string module names
    if not isinstance(module_name, str) or not module_name.strip():
        return False

    # Expected file path or directory path of the module
    module_path, module_dir = self._get_module_paths(module_name, notebook_path)

    # The package candidate must be a directory: an unrelated extensionless
    # file that merely shares the module's name is not importable as a package.
    # is_file()/is_dir() already return False for paths that do not exist.
    return module_path.is_file() or module_dir.is_dir()

def _get_module_paths(self, module_name: str, notebook_path: Path) -> Tuple[Path, Path]:
    """Build the candidate locations of a user-defined module.

    Both candidates live in the directory that contains the notebook:
    ``<module_name>.py`` for a single-file module and ``<module_name>``
    for a package directory. No filesystem access is performed here.

    Args:
        module_name (str): Name of the module.
        notebook_path (Path): Path to the Jupyter notebook.

    Returns:
        Tuple[Path, Path]: (module_file_path, module_directory_path)
    """
    base = notebook_path.parent
    return base.joinpath(f"{module_name}.py"), base.joinpath(module_name)

def _get_requirements(self) -> List[str]:
"""Extract pip libraries from the script

Expand Down
7 changes: 6 additions & 1 deletion openfl/utilities/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,19 @@ def __enter__(self):
os.chdir(self.experiment_work_dir)

# This is needed for python module finder
sys.path.append(str(self.experiment_work_dir))
for path in [self.experiment_work_dir, self.experiment_work_dir / "src"]:
path_str = str(path)
if path_str not in sys.path:
sys.path.append(path_str)

def __exit__(self, exc_type, exc_value, traceback):
"""Remove the workspace."""
os.chdir(self.cwd)
shutil.rmtree(self.experiment_work_dir, ignore_errors=True)
if str(self.experiment_work_dir) in sys.path:
sys.path.remove(str(self.experiment_work_dir))
if str(self.experiment_work_dir / "src") in sys.path:
sys.path.remove(str(self.experiment_work_dir / "src"))

if self.remove_archive:
logger.debug(
Expand Down
Loading