diff --git a/openfl/experimental/workflow/notebooktools/code_analyzer.py b/openfl/experimental/workflow/notebooktools/code_analyzer.py
index 9fed1e7426..2fb15c844f 100644
--- a/openfl/experimental/workflow/notebooktools/code_analyzer.py
+++ b/openfl/experimental/workflow/notebooktools/code_analyzer.py
@@ -4,10 +4,11 @@
 import ast
 import inspect
 import re
+import shutil
 import sys
 from importlib import import_module
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import nbformat
 from nbdev.export import nb_export
@@ -45,6 +46,8 @@ def __init__(self, notebook_path: Path, output_path: Path) -> None:
             )
         ).resolve()
         self.requirements = self._get_requirements()
+        user_imports = self.__extract_user_defined_imports(notebook_path)
+        self.__copy_user_defined_modules(user_imports, notebook_path)
         self.__modify_experiment_script()
 
     def __get_exp_name(self, notebook_path: Path) -> str:
@@ -86,6 +89,62 @@ def __convert_to_python(self, notebook_path: Path, output_path: Path, export_fil
 
         return Path(output_path).joinpath(export_filename).resolve()
 
+    def __extract_user_defined_imports(self, notebook_path: Path) -> List[str]:
+        """
+        Extract user-defined module imports from the notebook script,
+        excluding standard library and third-party modules.
+
+        Args:
+            notebook_path (Path): Path to the Jupyter notebook.
+
+        Returns:
+            List[str]: A sorted list of user-defined module names used in the notebook.
+        """
+        # Drop IPython shell escapes ("!") and magics ("%"): they are not valid
+        # Python syntax and would make ast.parse() fail.
+        with open(self.script_path, "r", encoding="utf-8") as file:
+            code = "".join(line for line in file if not line.lstrip().startswith(("!", "%")))
+
+        tree = ast.parse(code)
+        user_imports = set()
+
+        for node in ast.walk(tree):
+            if isinstance(node, ast.Import):
+                for alias in node.names:
+                    module_name = alias.name.split(".")[0]
+                    if self._is_user_defined_module(module_name, notebook_path):
+                        user_imports.add(module_name)
+
+            elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
+                # level == 0 means an absolute import; relative imports are skipped.
+                module_name = node.module.split(".")[0]
+                if self._is_user_defined_module(module_name, notebook_path):
+                    user_imports.add(module_name)
+
+        # Sorted so that the copy order and its log output are deterministic.
+        return sorted(user_imports)
+
+    def __copy_user_defined_modules(self, module_names: List[str], notebook_path: Path) -> None:
+        """
+        Copies user-defined modules/packages to the generated workspace's src directory
+
+        Args:
+            module_names (List[str]): A list of user-defined module names
+            notebook_path (Path): Path to Jupyter notebook.
+        """
+        src_dir = self.script_path.parent
+        for module_name in module_names:
+            try:
+                module_path, module_dir = self._get_module_paths(module_name, notebook_path)
+                if module_path.exists() and module_path.is_file():
+                    shutil.copy(module_path, src_dir)
+                    print(f"Copied user-defined module: {module_name}.py")
+                elif module_dir.exists() and module_dir.is_dir():
+                    shutil.copytree(module_dir, src_dir / module_name, dirs_exist_ok=True)
+                    print(f"Copied user-defined directory: {module_name}/")
+            except Exception as e:
+                print(f"[WARNING] Failed to copy '{module_name}':{e}")
+
     def __modify_experiment_script(self) -> None:
         """Modifies the given python script by commenting out following code:
         - occurences of flflow.run()
@@ -294,6 +353,42 @@ def _clean_value(self, value: str) -> str:
             value = value.lstrip("[").rstrip("]")
         return value
 
+    def _is_user_defined_module(self, module_name: str, notebook_path: Path) -> bool:
+        """
+        Determine whether a module is user-defined, i.e. it sits next to the notebook.
+
+        Args:
+            module_name (str): Name of the module.
+            notebook_path (Path): Path to Jupyter notebook using the module.
+
+        Returns:
+            bool: True if the module is user-defined, False otherwise.
+        """
+        # Reject empty or non-string module names
+        if not isinstance(module_name, str) or not module_name.strip():
+            return False
+
+        # Same file/directory tests the copy step uses, so every detected module is copyable
+        module_path, module_dir = self._get_module_paths(module_name, notebook_path)
+
+        return module_path.is_file() or module_dir.is_dir()
+
+    def _get_module_paths(self, module_name: str, notebook_path: Path) -> Tuple[Path, Path]:
+        """
+        Get the file and directory paths for a user-defined module
+
+        Args:
+            module_name (str): Name of the module.
+            notebook_path (Path): Path to the Jupyter notebook.
+
+        Returns:
+            Tuple[Path, Path]: (module_file_path, module_directory_path)
+        """
+        notebook_dir = notebook_path.parent
+        module_path = notebook_dir / f"{module_name}.py"
+        module_dir = notebook_dir / module_name
+        return module_path, module_dir
+
     def _get_requirements(self) -> List[str]:
         """Extract pip libraries from the script
 
diff --git a/openfl/utilities/workspace.py b/openfl/utilities/workspace.py
index 15e7a3a339..6971bf09ba 100644
--- a/openfl/utilities/workspace.py
+++ b/openfl/utilities/workspace.py
@@ -105,7 +105,10 @@ def __enter__(self):
         os.chdir(self.experiment_work_dir)
 
         # This is needed for python module finder
-        sys.path.append(str(self.experiment_work_dir))
+        for path in [self.experiment_work_dir, self.experiment_work_dir / "src"]:
+            path_str = str(path)
+            if path_str not in sys.path:
+                sys.path.append(path_str)
 
     def __exit__(self, exc_type, exc_value, traceback):
         """Remove the workspace."""
@@ -113,6 +116,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         shutil.rmtree(self.experiment_work_dir, ignore_errors=True)
         if str(self.experiment_work_dir) in sys.path:
             sys.path.remove(str(self.experiment_work_dir))
+        if str(self.experiment_work_dir / "src") in sys.path:
+            sys.path.remove(str(self.experiment_work_dir / "src"))
 
         if self.remove_archive:
             logger.debug(