Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions hipify_torch/hipify_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ def asdict(self):
return {"hipified_path" : self.hipified_path, "status" : self.status}

HipifyFinalResult = Dict[str, HipifyResult]
HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
PYTHON_STYLE_BREADCRUMB = "# !!! This is a file automatically generated by hipify!!!\n"
C_STYLE_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
C_HEADER_EXTENSIONS=(".cuh", ".h", ".hpp")
C_FILE_EXTENSIONS=(".cu", ".c", ".cc", ".cpp", *C_HEADER_EXTENSIONS)
CYTHON_FILE_EXTENSIONS=(".pyx", ".pyd", ".pxi")
HIPIFY_BREADCRUMB_MAP = {"default": C_STYLE_BREADCRUMB}
HIPIFY_BREADCRUMB_MAP.update({ext: PYTHON_STYLE_BREADCRUMB for ext in CYTHON_FILE_EXTENSIONS})
HIPIFY_FINAL_RESULT: HipifyFinalResult = {}

# Hardcode the PyTorch template map
Expand Down Expand Up @@ -134,6 +140,9 @@ def __exit__(self, type, value, traceback):
for d in self.dirs_to_clean[::-1]:
os.rmdir(d)

def get_extension(filename: str) -> str:
return os.path.splitext(filename)[1]

def match_extensions(filename: str, extensions: Iterable) -> bool:
"""Helper method to see if filename ends with certain extension"""
return any(filename.endswith(e) for e in extensions)
Expand Down Expand Up @@ -803,8 +812,11 @@ def preprocessor(

rel_filepath = os.path.relpath(filepath, output_directory)

# find the breadcrumb matching file extension
hipify_breadcrumb = HIPIFY_BREADCRUMB_MAP.get(get_extension(fin_path), HIPIFY_BREADCRUMB_MAP["default"])
with open(fin_path, 'r', encoding='utf-8') as fin:
if fin.readline() == HIPIFY_C_BREADCRUMB:
# and check if the line contains the breadcrumb
if fin.readline() == hipify_breadcrumb:
hipify_result.hipified_path = None
hipify_result.status = "[ignored, input is hipified output]"
hipify_result.current_state = CurrentState.DONE
Expand Down Expand Up @@ -941,10 +953,10 @@ def repl(m):
hipify_result.current_state = CurrentState.DONE
return hipify_result

# Add hipify breadcrumb for C-style files to avoid re-hipification
if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
# Add hipify breadcrumb for supported file extensions to avoid re-hipification
if fin_path != fout_path and match_extensions(fin_path, (*C_FILE_EXTENSIONS, *CYTHON_FILE_EXTENSIONS)):
output_source_ascii=output_source.encode("ascii", "ignore").decode()
output_source = HIPIFY_C_BREADCRUMB + output_source_ascii
output_source = hipify_breadcrumb + output_source_ascii

do_write = True
if os.path.exists(fout_path):
Expand Down Expand Up @@ -1063,8 +1075,8 @@ def str2bool(v):
def hipify(
project_directory: str,
show_detailed: bool = False,
extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
header_extensions: Iterable = (".cuh", ".h", ".hpp"),
extensions: Iterable = (*C_FILE_EXTENSIONS, ".in", *CYTHON_FILE_EXTENSIONS),
header_extensions: Iterable = C_HEADER_EXTENSIONS,
extra_extensions: Iterable = (),
output_directory: str = "",
header_include_dirs: Iterable = (),
Expand Down