Skip to content

Conversation

@zewenli98
Copy link
Collaborator

@zewenli98 zewenli98 commented Nov 24, 2025

Description

As I requested, TensorRT 10.14 added an argument trt.SerializationFlag.INCLUDE_REFIT to allow refitted engines to keep refittable. That means engines can be refitted multiple times. Based on the capability, this PR enhances the existing engine caching and refitting features as follows:

  1. To save hard disk space, engine caching will only save weight-stripped engines on disk regardless of compilation_settings.strip_engine_weights. Then, when users pull out the cached engine, it will be automatically refitted and kept refittable.
  2. Compiled TRT modules can be refitted multiple times with refit_module_weights(). e.g.:
for _ in range(3):
    trt_gm = refit_module_weights(trt_gm, exp_program)
  1. Due to some changes, the insertion and pulling of cached engines are located in different places, which causes 🐛 [Bug] Engine cache failed on torch.compile backend=tensorrt #3909. This PR unified the insertion and pulling in _conversion.py.

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@zewenli98 zewenli98 self-assigned this Nov 24, 2025
@meta-cla meta-cla bot added the cla signed label Nov 24, 2025
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: torch_compile labels Nov 24, 2025
@narendasan
Copy link
Collaborator

@cehongwang please take a pass so we have multiple eyes on this PR

@zewenli98 zewenli98 force-pushed the improve_engine_caching branch from a54907e to ea81677 Compare December 4, 2025 18:38
logger.warning(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
)
if settings.strip_engine_weights:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would a torch.compile use try to use strip weights?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added the warning back. Not sure why strip_engine_weights arg doesn't work for torch.compile()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just doenst make sense. Torch compile is not serializable. So why would you ever want a callable that doesnt have the weights in it

logger.info(f"The engine already exists in cache for hash: {hash_val}")
return False

if not settings.strip_engine_weights:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like strip weights should only apply to the returned engine and not to the cache directly. So a returned cache engine with strip weights == True wont be refit. but you always only save stripped engine.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current design is to save stripped engine for using less hard disk. A returned cache engine with strip weights == True wont be refit as well. Only strip weights == False will be refit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah but what i mean is every engine we save in the cache should be a stripped weights engine right? We arent doing weight matching in the hash function for pulling from the cache so we will need to refit anyway. So the serialization config should always have serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) regardless if the user tells us settings.strip_engine_weights == True or settings.strip_engine_weights == False

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User Setting Engine Saved to Cache Engine Returned to User
settings.strip_engine_weights == True Weightless Weightless
settings.strip_engine_weights == False Weightless Refit Weights

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or is it that we dont need to explictly set the setting if the weight was built with stripped weights? Is there any harm in doing so? Then there is only one code path

@narendasan narendasan linked an issue Dec 9, 2025 that may be closed by this pull request
), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}"

logger.info(
"Found the cached engine that corresponds to this graph. It is directly loaded."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Print the hash

and engine_cache is not None
and not settings.immutable_weights
):
if settings.cache_built_engines or settings.reuse_cached_engines:
Copy link
Collaborator

@narendasan narendasan Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this code is unclear. Would recommend something like this

hash_val = engine_cache.get_hash(module, inputs, settings) if (settings.cache_built_engines or settings.reuse_cached_engines) else None

if settings.reuse_cached_engines:
    serialized_interpreter_result = pull_cached_engine(
        hash_val, module, engine_cache, settings, inputs
    )
    if serialized_interpreter_result is not None:  # hit the cache
        return serialized_interpreter_result

...

if (
    ENABLED_FEATURES.refit
    and not settings.immutable_weights
    and settings.cache_built_engines
    and engine_cache is not None
):
    _ = insert_engine_to_cache(
        hash_val, interpreter_result, engine_cache, settings, inputs
    )

    serialized_engine = interpreter_result.engine.serialize()
    

if (
ENABLED_FEATURES.refit
and not settings.immutable_weights
and settings.cache_built_engines
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably throw a warning or something if engine_cache is None and settings.cache_built_engines == True

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests component: torch_compile

Projects

None yet

Development

Successfully merging this pull request may close these issues.

📖 [Story] Weightless Engine Building

3 participants