Conversation
Summary of Changes

Hello @michaelteboul, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request introduces a new model, SABAC, for basket completion. It uses a transformer-based architecture to learn basket embeddings, integrating both short-term and long-term user preferences. The model also accounts for contextual factors such as item intercepts, price sensitivities, and store-specific effects, providing a comprehensive framework for predicting future purchases. A key aspect of this implementation is the custom handling of its internal transformer blocks for robust model saving and loading.

Highlights
Activity
Code Review
The pull request introduces a new SABAC model, which is an attention-based model for basket completion. It also includes a TransformerBlock class to support the attention mechanism. The changes involve implementing the model's initialization, parameter instantiation, basket embedding logic, utility computations, and loss calculation. The save_model and load_model methods have been overridden to handle the serialization and deserialization of the TransformerBlock objects. Overall, the code is well-structured and follows a logical flow for a deep learning model. However, there are several areas for improvement regarding clarity, consistency, and potential bugs, particularly in the compute_batch_loss method and the TransformerBlock initialization.
choice_learn/basket_models/sabac.py
Outdated
```python
epsilon = 0.0
loglikelihood = tf.reduce_sum(
    tf.math.log(
        tf.sigmoid(
            tf.tile(
                positive_samples_utility,
                [1, self.n_negative_samples],
            )
            - negative_samples_utility
        )
        + epsilon
    ),
)  # Shape of loglikelihood: (1,)
```
The `epsilon = 0.0` added inside `tf.math.log` accomplishes nothing: with a zero epsilon, `log(0)` still occurs whenever the `tf.sigmoid` output underflows to 0. It should be a small positive value (e.g., `1e-8`) to ensure numerical stability.
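To see why a zero epsilon fails: in float32 (TensorFlow's default dtype), the sigmoid underflows to exactly 0.0 for strongly negative inputs, and `log(0.0)` is `-inf`. Below is a dependency-free NumPy sketch of the failure and of the standard stable reformulation `log(sigmoid(x)) = -log(1 + exp(-x))`, which is what `tf.math.log_sigmoid` computes:

```python
import numpy as np

x = np.float32(-200.0)  # a strongly negative utility gap

# Naive log(sigmoid(x)): exp(-x) overflows float32 to inf, so sigmoid underflows to 0.0
with np.errstate(over="ignore", divide="ignore"):
    naive = np.log(np.float32(1.0) / (np.float32(1.0) + np.exp(-x)))

# Stable equivalent: log(sigmoid(x)) = -log(1 + exp(-x)) = -logaddexp(0, -x)
stable = -np.logaddexp(np.float32(0.0), -x)

print(naive)   # -inf
print(stable)  # -200.0
```

Within TensorFlow itself, the loss could therefore also be written with `tf.math.log_sigmoid(...)` in place of `tf.math.log(tf.sigmoid(...))`, removing the need for any epsilon at all.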
Suggested change:

```python
epsilon = 1e-8
loglikelihood = tf.reduce_sum(
    tf.math.log(
        tf.sigmoid(
            tf.tile(
                positive_samples_utility,
                [1, self.n_negative_samples],
            )
            - negative_samples_utility
        )
        + epsilon
    )
)
```

choice_learn/basket_models/sabac.py

```python
"""
store_batch = tf.cast(store_batch, dtype=tf.int32)
price_batch = tf.cast(price_batch, dtype=tf.float32)
x_item = tf.gather(self.X, indices=item_batch)  # Shape: (batch_size, None, d)
```

The comment `Shape: (batch_size, None, d)` is incorrect. `x_item` will have shape `(batch_size, d)` if `item_batch` is `(batch_size,)`, or `(batch_size, num_items, d)` if `item_batch` is `(batch_size, num_items)`. The `None` dimension is misleading.

Suggested change:

```python
x_item = tf.gather(self.X, indices=item_batch)  # Shape: (batch_size, d) or (batch_size, num_items, d)
```

choice_learn/basket_models/sabac.py

```python
    Weight decay (L2 regularization) factor. Default is None (no weight decay).
momentum : float
    Momentum factor for optimizers that support it. Default is 0.0.
item_intercept : bool, optional
```
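The shape rule behind the corrected comment: gathering along axis 0 yields `indices.shape + params.shape[1:]`. NumPy integer-array indexing follows the same rule and is used here as a dependency-free stand-in for `tf.gather`:

```python
import numpy as np

X = np.arange(12.0).reshape(4, 3)   # embedding matrix, shape (n_items=4, d=3)

flat = X[np.array([0, 2])]          # 1-D indices (batch_size,)    -> (batch_size, d)
batched = X[np.array([[0, 1],
                      [2, 3]])]     # 2-D indices (batch, n_items) -> (batch, n_items, d)

print(flat.shape)     # (2, 3)
print(batched.shape)  # (2, 2, 3)
```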
choice_learn/basket_models/sabac.py
Outdated
```python
for val in latent_sizes.keys():
    if val not in ["short_term", "long_term", "price"]:
        raise ValueError(f"Unknown value for latent_sizes dict: {val}.")
if "short_term" not in latent_sizes:
    latent_sizes["short_term"] = 10
if "long_term" not in latent_sizes:
    latent_sizes["long_term"] = 10
if "price" not in latent_sizes:
    latent_sizes["price"] = 4
```
The logic for setting default latent_sizes values is a bit redundant. It would be cleaner to define the default dictionary directly in the __init__ signature with all expected keys, or to use dict.get with default values if the keys might be missing from the input latent_sizes.
Suggested change:

```python
self.latent_sizes = {
    "short_term": latent_sizes.get("short_term", 10),
    "long_term": latent_sizes.get("long_term", 10),
    "price": latent_sizes.get("price", 4),
}
for val in latent_sizes.keys():
    if val not in self.latent_sizes:
        raise ValueError(f"Unknown value for latent_sizes dict: {val}.")
```
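A runnable sketch of the suggested `dict.get` pattern; the helper name `resolve_latent_sizes` is hypothetical, used only to make the example self-contained:

```python
def resolve_latent_sizes(latent_sizes):
    # Defaults applied in one place via dict.get, then a single membership check
    resolved = {
        "short_term": latent_sizes.get("short_term", 10),
        "long_term": latent_sizes.get("long_term", 10),
        "price": latent_sizes.get("price", 4),
    }
    for key in latent_sizes:
        if key not in resolved:
            raise ValueError(f"Unknown value for latent_sizes dict: {key}.")
    return resolved

print(resolve_latent_sizes({"price": 8}))
# {'short_term': 10, 'long_term': 10, 'price': 8}
```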
choice_learn/basket_models/sabac.py
Outdated
```python
----------
X : tf.Variable
    Item embedding matrix for short-term preferences, size (n_items, d).
V : tf.Variable
    Item embedding matrix for long-term preferences, size (n_items, d_long).
U : tf.Variable
    User embedding matrix for long-term preferences, size (n_users, d_long).
Wq : tf.Variable
    Weight matrix for query transformation in attention mechanism, size (d, d).
Wk : tf.Variable
    Weight matrix for key transformation in attention mechanism, size (d, d).
```
The docstring for instantiate is missing n_stores in the Parameters section and W_Q, W_K, W_V, W_O, W1, W2, b1, b2, gamma1, beta1, gamma2, beta2, S (if attention pooling) in the Variables section. These are all trainable weights that are part of the model's architecture.
Suggested change:

```python
n_stores : int
    Number of unique stores in the dataset.

Variables
---------
X : tf.Variable
    Item embedding matrix for short-term preferences, size (n_items, d).
V : tf.Variable
    Item embedding matrix for long-term preferences, size (n_items, d_long).
U : tf.Variable
    User embedding matrix for long-term preferences, size (n_users, d_long).
theta : tf.Variable, optional
    Store effects embedding matrix, size (n_stores, d).
beta : tf.Variable, optional
    Item price sensitivity embedding matrix, size (n_items, latent_sizes["price"]).
delta : tf.Variable, optional
    Store price sensitivity embedding matrix, size (n_stores, latent_sizes["price"]).
alpha : tf.Variable, optional
    Item intercept vector, size (n_items,).
CLS_token : tf.Variable, optional
    CLS token embedding, size (1, d).
W_Q, W_K, W_V, W_O, W1, W2, b1, b2, gamma1, beta1, gamma2, beta2, S : tf.Variable, optional
    Weights and biases for TransformerBlocks and attention pooling.
```

choice_learn/basket_models/sabac.py

```python
# For attention pooling, we use cross-attention with 1 head instead of self-attention
self.num_heads = 1
self.head_dim = d_model
self.S = add_var((1, d_model), "W_Q")
```
```python
self.W_Q = add_var((d_model, d_model), "W_Q")
self.W1 = add_var((d_model, d_ffn), "ffn_W1")
self.b1 = add_var((d_ffn,), "ffn_b1", zeros=True)
self.gamma1 = add_var((d_model,), "ln1_gamma")
self.beta1 = add_var((d_model,), "ln1_beta", zeros=True)
self.gamma2 = add_var((d_model,), "ln2_gamma")
self.beta2 = add_var((d_model,), "ln2_beta", zeros=True)
self.W2 = add_var((d_ffn, d_model), "ffn_W2")
self.b2 = add_var((d_model,), "ffn_b2", zeros=True)
self.W_O = add_var((d_model, d_model), "W_O") if num_heads > 1 else None
self.W_K = add_var((d_model, d_model), "W_K")
```
If add_var is made a static method, the calls to it need to be updated to pass self.name as name_prefix and then append the returned variables to self._trainable_weights.
Suggested change:

```python
self.W_Q = self._add_var((d_model, d_model), "W_Q", name_prefix=self.name)
self._trainable_weights.append(self.W_Q)
self.W1 = self._add_var((d_model, d_ffn), "ffn_W1", name_prefix=self.name)
self._trainable_weights.append(self.W1)
self.b1 = self._add_var((d_ffn,), "ffn_b1", zeros=True, name_prefix=self.name)
self._trainable_weights.append(self.b1)
self.gamma1 = self._add_var((d_model,), "ln1_gamma", name_prefix=self.name)
self._trainable_weights.append(self.gamma1)
self.beta1 = self._add_var((d_model,), "ln1_beta", zeros=True, name_prefix=self.name)
self._trainable_weights.append(self.beta1)
self.gamma2 = self._add_var((d_model,), "ln2_gamma", name_prefix=self.name)
self._trainable_weights.append(self.gamma2)
self.beta2 = self._add_var((d_model,), "ln2_beta", zeros=True, name_prefix=self.name)
self._trainable_weights.append(self.beta2)
self.W2 = self._add_var((d_ffn, d_model), "ffn_W2", name_prefix=self.name)
self._trainable_weights.append(self.W2)
self.b2 = self._add_var((d_model,), "ffn_b2", zeros=True, name_prefix=self.name)
self._trainable_weights.append(self.b2)
if num_heads > 1:
    self.W_O = self._add_var((d_model, d_model), "W_O", name_prefix=self.name)
    self._trainable_weights.append(self.W_O)
else:
    self.W_O = None
```

choice_learn/basket_models/sabac.py

```python
self.W_V = add_var((d_model, d_model), "W_V") if use_value_matrix else None
```
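A minimal sketch of how such a static `_add_var` helper could look. Everything here is hypothetical: the real `add_var` is not shown in this diff, and plain NumPy arrays stand in for `tf.Variable` so the example carries no TensorFlow dependency:

```python
import numpy as np

class TransformerBlock:
    def __init__(self, d_model, name="transformer_block"):
        self.name = name
        self._trainable_weights = []
        # With a static helper, the caller passes the prefix and tracks the weight itself
        self.W_Q = self._add_var((d_model, d_model), "W_Q", name_prefix=self.name)
        self._trainable_weights.append(self.W_Q)

    @staticmethod
    def _add_var(shape, name, zeros=False, name_prefix=""):
        # A static method has no access to self, hence the explicit name_prefix argument;
        # in the real model this would wrap tf.Variable(..., name=f"{name_prefix}/{name}")
        rng = np.random.default_rng(0)
        return np.zeros(shape) if zeros else rng.normal(size=shape)

block = TransformerBlock(4)
print(block.W_Q.shape, len(block._trainable_weights))  # (4, 4) 1
```

The explicit `self._trainable_weights.append(...)` after each call is the price of making the helper static; the benefit is that variable creation no longer depends on instance state.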
If add_var is made a static method, the calls to it need to be updated to pass self.name as name_prefix and then append the returned variables to self._trainable_weights.
Suggested change:

```python
self.W_K = self._add_var((d_model, d_model), "W_K", name_prefix=self.name)
self._trainable_weights.append(self.W_K)
if use_value_matrix:
    self.W_V = self._add_var((d_model, d_model), "W_V", name_prefix=self.name)
    self._trainable_weights.append(self.W_V)
else:
    self.W_V = None
```

choice_learn/basket_models/sabac.py

```python
q_len = tf.shape(q)[1]
v = tf.matmul(x, self.W_V) if self.use_value_matrix else x

def split_heads(tensor, seq_len=seq_len):
```

The `seq_len` parameter in `split_heads` is redundant. It can be directly inferred from `tf.shape(tensor)[1]` within the function.

Suggested change:

```python
def split_heads(tensor):
    seq_len_local = tf.shape(tensor)[1]
    tensor = tf.reshape(tensor, (batch_size, seq_len_local, self.num_heads, self.head_dim))
    return tf.transpose(tensor, perm=[0, 2, 1, 3])
```

choice_learn/basket_models/sabac.py

```python
    tensor = tf.reshape(tensor, (batch_size, seq_len, self.num_heads, self.head_dim))
    return tf.transpose(tensor, perm=[0, 2, 1, 3])

q_h = split_heads(q, seq_len=q_len)
```
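The shape mechanics behind the suggestion, with NumPy standing in for TensorFlow: the sequence length is recoverable from the tensor itself, so the extra parameter buys nothing, and the same function then works for queries and keys/values of different lengths (`num_heads` and `head_dim` below are illustrative values):

```python
import numpy as np

batch_size, num_heads, head_dim = 2, 4, 8
d_model = num_heads * head_dim  # 32

def split_heads(tensor):
    # Sequence length inferred from the tensor itself, as the review suggests
    b, seq_len_local, _ = tensor.shape
    tensor = tensor.reshape(b, seq_len_local, num_heads, head_dim)
    return tensor.transpose(0, 2, 1, 3)  # (batch, heads, seq, head_dim)

q = np.zeros((batch_size, 5, d_model))  # queries, seq_len = 5
x = np.zeros((batch_size, 9, d_model))  # keys/values, seq_len = 9

print(split_heads(q).shape)  # (2, 4, 5, 8)
print(split_heads(x).shape)  # (2, 4, 9, 8)
```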
Coverage Report for Python 3.10
Coverage Report for Python 3.11
Coverage Report for Python 3.12
Description of the goal of the PR
Description:
Changes this PR introduces (fill it before implementation)