
WIP: Save/load models for CAAFE -- Client #184

Open
brendan-priorlabs wants to merge 1 commit into main from brendan/caafe-saved-models

Conversation

@brendan-priorlabs (Collaborator)

No description provided.

@brendan-priorlabs brendan-priorlabs requested a review from a team as a code owner November 19, 2025 09:18
@brendan-priorlabs brendan-priorlabs requested review from noahho and removed request for a team November 19, 2025 09:18
@gemini-code-assist (Contributor)

Summary of Changes

Hello @brendan-priorlabs, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the client-side model management by introducing a 'model_id' for uniquely identifying trained models. This allows the client to explicitly reference and utilize specific models for prediction, moving beyond just dataset UIDs. The 'fit' and 'predict' interfaces have been updated to support this new identifier, improving the overall model lifecycle management within the client application. This is a work-in-progress to enable save/load functionality for CAAFE models.

Highlights

  • Model ID Introduction: A new 'model_id' field has been added to the 'FitCompleteEvent' and a new 'FittedModel' data structure to track specific models after training, allowing for more granular model management.
  • Updated 'fit' Method: The 'fit' method now returns a 'FittedModel' object, which encapsulates both the 'train_set_uid' and the newly introduced 'model_id', providing a more comprehensive result.
  • Enhanced 'predict' Method: The 'predict' method now accepts an optional 'model_id' parameter, enabling predictions to be made using a specific saved model, improving model selection capabilities.
  • Client-side Model Tracking: The 'TabPFNClassifier' and 'TabPFNRegressor' classes have been updated to store the 'model_id' alongside the 'train_set_uid', facilitating better internal model management within the client.
  • Local Testing Configuration: The 'server_config.yaml' has been temporarily adjusted to prioritize local testing configurations, indicating ongoing development and testing efforts.
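Taken together, the highlights suggest an interface roughly like the sketch below. This is an illustration assembled from the summary, not the PR's actual code: the `FittedModel` fields `train_set_uid` and `model_id` come from the diff, while the function bodies and placeholder values are assumptions.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FittedModel:
    """Handle returned by fit(): identifies both the uploaded
    train set and the specific trained model."""
    train_set_uid: str
    model_id: Optional[str] = None

def fit(X, y) -> FittedModel:
    # ... upload data, wait for FitCompleteEvent from the server ...
    # (hard-coded placeholder ids for illustration)
    return FittedModel(train_set_uid="ts-123", model_id="m-456")

def predict(X, train_set_uid: str, model_id: Optional[str] = None):
    # When model_id is given, the server uses that saved model;
    # otherwise it falls back to the model tied to train_set_uid.
    ...
```

The key design point is that `fit()` now returns a structured handle rather than a bare `train_set_uid` string, so callers can thread the `model_id` through to `predict()`.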
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature               Command               Description
Code Review           /gemini review        Performs a code review for the current pull request in its current state.
Pull Request Summary  /gemini summary       Provides a summary of the current pull request in its current state.
Comment               @gemini-code-assist   Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                  /gemini help          Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces the concept of a model_id to uniquely identify fitted models, which is a significant enhancement for saving and loading models. The changes span across the client, estimator, and service wrapper to support this new identifier. My review has identified a few critical issues: a potential UnboundLocalError in the client, a logic error in the prediction retry mechanism, and a configuration change that points to a local server instead of production. I've also suggested improving test coverage for the new model_id functionality. Please address these points to ensure the stability and correctness of the implementation.

)
elif isinstance(event, FitCompleteEvent):
train_set_uid = event.train_set_uid
model_id = event.model_id


critical

The variable model_id is assigned here within an elif block. However, it is not initialized before the loop (lines 474-511). If a FitCompleteEvent is not received from the server, the subsequent reference to model_id on line 519 will raise an UnboundLocalError. To prevent this, please initialize model_id to None before the try block, similar to how train_set_uid is initialized.
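A minimal sketch of the suggested fix. The event type and field names come from the diff above; the surrounding function and the stand-in dataclass are assumptions for illustration:

```python
from dataclasses import dataclass

@dataclass
class FitCompleteEvent:  # stand-in for the client's event type
    train_set_uid: str
    model_id: str

def collect_fit_result(events):
    # Initialize both names up front: if the server never sends a
    # FitCompleteEvent, this returns (None, None) instead of raising
    # UnboundLocalError on the later reference to model_id.
    train_set_uid = None
    model_id = None
    for event in events:
        if isinstance(event, FitCompleteEvent):
            train_set_uid = event.train_set_uid
            model_id = event.model_id
    return train_set_uid, model_id
```

The same pattern the code already uses for `train_set_uid` simply needs to be extended to `model_id`.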

Comment on lines +1 to +9
# local testing
protocol: "http"
host: "localhost"
port: "8080"

## production
#protocol: "https"
#host: "api.priorlabs.ai"
#port: "443"


critical

This change switches the server configuration from production to a local testing environment. While this is useful for development, it should be reverted before merging so the client does not point to a local server in a production release. Please ensure the production configuration is active in the final version of this pull request.

Suggested change

# # local testing
# protocol: "http"
# host: "localhost"
# port: "8080"

# production
protocol: "https"
host: "api.priorlabs.ai"
port: "443"

Comment on lines +684 to 685
train_set_uid = fitted_model.train_set_uid
params["train_set_uid"] = train_set_uid


high

When retrying a prediction after a UID error, a new model is fitted, but only the train_set_uid is updated in the params for the next attempt. The new model_id from fitted_model is ignored. This could lead to using a stale model_id in the retried prediction request. You should also update the model_id in the params dictionary.

Suggested change
- train_set_uid = fitted_model.train_set_uid
- params["train_set_uid"] = train_set_uid
+ train_set_uid = fitted_model.train_set_uid
+ params["train_set_uid"] = train_set_uid
+ params["model_id"] = fitted_model.model_id

train_set_uid2 = fitted_model2.train_set_uid

# The train_set_uid should be the same due to caching
self.assertEqual(train_set_uid1, train_set_uid2)


medium

The existing tests are correctly updated to handle the new return type of fit(). However, the new model_id functionality is not covered by any tests. Please consider adding new tests to verify that:

  1. The model_id is correctly passed to the predict endpoint.
  2. Caching for predictions works correctly when different model_ids are used for the same test data.
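A sketch of what such tests could look like. The fake client below is an assumption standing in for the repository's actual test harness; only the general shape (recording `predict` params, keying the cache on `model_id`) reflects the review comment:

```python
class FakeServiceClient:
    """Stand-in for the real service client: records the params of the
    last predict() call and caches results per (test_set_hash, model_id)."""
    def __init__(self):
        self.last_params = None
        self._cache = {}

    def predict(self, test_set_hash, train_set_uid, model_id=None):
        self.last_params = {"train_set_uid": train_set_uid,
                            "model_id": model_id}
        # model_id must be part of the cache key, otherwise two different
        # saved models would wrongly share cached predictions.
        key = (test_set_hash, model_id)
        if key not in self._cache:
            self._cache[key] = f"prediction-{len(self._cache)}"
        return self._cache[key]

def test_model_id_is_forwarded():
    client = FakeServiceClient()
    client.predict("hash-1", "ts-1", model_id="model-a")
    assert client.last_params["model_id"] == "model-a"

def test_cache_distinguishes_model_ids():
    client = FakeServiceClient()
    p1 = client.predict("hash-1", "ts-1", model_id="model-a")
    p2 = client.predict("hash-1", "ts-1", model_id="model-b")
    assert p1 != p2
```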

@safaricd left a comment


From a high level, nothing major except for the two comments I left.


  cls.dataset_uid_cache_manager.add_dataset_uid(dataset_hash, train_set_uid)
- return train_set_uid
+ return FittedModel(train_set_uid=train_set_uid, model_id=model_id)

Won't this break backwards compatibility for anyone using the train_set_uid?
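One possible way to soften that break, sketched purely for illustration (it is not what the PR does; the PR returns a plain dataclass-style object), is to make `FittedModel` behave like the old `train_set_uid` string for existing callers while carrying the new `model_id` for code that opts in:

```python
class FittedModel(str):
    """Illustrative compatibility shim, not the PR's implementation:
    instances compare and pass through like the old train_set_uid
    string, while exposing model_id as an attribute."""
    def __new__(cls, train_set_uid, model_id=None):
        obj = super().__new__(cls, train_set_uid)
        obj.train_set_uid = train_set_uid
        obj.model_id = model_id
        return obj

fm = FittedModel("ts-123", model_id="m-456")
assert fm == "ts-123"          # old code comparing/passing the uid still works
assert fm.model_id == "m-456"  # new code can read the model id
```

Whether this kind of shim is worth the subtlety, versus documenting the return-type change as a breaking one, is exactly the trade-off the question above raises.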

"Automatically re-uploading the train set is not supported for thinking mode. Please call fit()."
)
train_set_uid = cls.fit(
fitted_model = cls.fit(

Could you please share more context on why we re-fit the model in predict?
