
ADD Sabac model #307

Open
michaelteboul wants to merge 3 commits into main from add_sabac

Conversation

@michaelteboul
Collaborator

Description of the goal of the PR

Changes this PR introduces:

  • Add the SABAC model (Self Attention for BAsket Completion).

@gemini-code-assist
Contributor

Summary of Changes

Hello @michaelteboul, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a sophisticated new model, SABAC, for basket completion. It utilizes a transformer-based architecture to effectively learn and represent basket embeddings, integrating both short-term and long-term user preferences. The model also accounts for various contextual factors such as item intercepts, price sensitivities, and store-specific effects, providing a comprehensive framework for predicting future purchases. A key aspect of this implementation is the custom handling of its internal transformer blocks for robust model saving and loading.

Highlights

  • New Model Introduction: Added the SABAC (Self Attention for BAsket Completion) model, a new attention-based model designed for basket completion tasks.
  • Transformer Architecture: Implemented a TransformerBlock class to handle multi-head self-attention and feed-forward networks, forming the core of the basket embedding mechanism.
  • Configurable Effects: The model supports configurable item intercepts, price effects, and store effects, allowing for more nuanced utility calculations.
  • Custom Serialization: Overrode save_model and load_model methods to correctly serialize and deserialize the custom TransformerBlock components, ensuring model persistence.
  • Vectorized Negative Sampling: Introduced a vectorized approach for generating negative samples, which can improve training efficiency.
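As a rough illustration of the last highlight, vectorized negative sampling draws all negatives for a batch in a single RNG call instead of a Python loop. This NumPy sketch uses hypothetical names and is not the PR's actual code:

```python
import numpy as np

# Hypothetical sketch: draw k negative item ids per positive example,
# for the whole batch at once, instead of looping over the batch.
rng = np.random.default_rng(seed=0)
n_items, batch_size, k = 100, 4, 3
negative_samples = rng.integers(0, n_items, size=(batch_size, k))  # shape (batch_size, k)
```

A real implementation would typically also mask out items already in the basket; that bookkeeping is omitted here.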


Changelog
  • choice_learn/basket_models/sabac.py
    • Added the SABAC class, inheriting from BaseBasketModel, which implements the Self Attention for BAsket Completion model.
    • Included the TransformerBlock class, a nested component responsible for the attention mechanism within SABAC.
    • Implemented methods for model instantiation, basket embedding, utility computation (short-term, long-term, and combined), and loss calculation.
    • Provided custom save_model and load_model methods to correctly handle the serialization and deserialization of TransformerBlock instances.
Activity
  • User michaelteboul created this pull request to introduce the new SABAC model.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature               Command               Description
Code Review           /gemini review        Performs a code review for the current pull request in its current state.
Pull Request Summary  /gemini summary       Provides a summary of the current pull request in its current state.
Comment               @gemini-code-assist   Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                  /gemini help          Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment
Code Review

The pull request introduces a new SABAC model, which is an attention-based model for basket completion. It also includes a TransformerBlock class to support the attention mechanism. The changes involve implementing the model's initialization, parameter instantiation, basket embedding logic, utility computations, and loss calculation. The save_model and load_model methods have been overridden to handle the serialization and deserialization of the TransformerBlock objects. Overall, the code is well-structured and follows a logical flow for a deep learning model. However, there are several areas for improvement regarding clarity, consistency, and potential bugs, particularly in the compute_batch_loss method and the TransformerBlock initialization.

Comment on lines 681 to 693
        epsilon = 0.0
        loglikelihood = tf.reduce_sum(
            tf.math.log(
                tf.sigmoid(
                    tf.tile(
                        positive_samples_utility,
                        [1, self.n_negative_samples],
                    )
                    - negative_samples_utility
                )
                + epsilon
            ),
        )  # Shape of loglikelihood: (1,)
high

The epsilon = 0.0 is added to the log argument of tf.math.log. If epsilon is 0, it won't prevent log(0) if the tf.sigmoid output is 0. It should be a small positive value (e.g., 1e-8) to ensure numerical stability.

        epsilon = 1e-8
        loglikelihood = tf.reduce_sum(
            tf.math.log(
                tf.sigmoid(
                    tf.tile(
                        positive_samples_utility,
                        [1, self.n_negative_samples],
                    )
                    - negative_samples_utility
                ) + epsilon
            )
        )
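A quick numerical check of why the epsilon matters, using NumPy as a stand-in for the TF expression above (for large negative utility gaps, TensorFlow's tf.math.log_sigmoid is another numerically stable option):

```python
import numpy as np

def log_sigmoid(u, epsilon=0.0):
    # Naive log(sigmoid(u) + epsilon), mirroring the reviewed expression
    with np.errstate(over="ignore", divide="ignore"):
        return np.log(1.0 / (1.0 + np.exp(-u)) + epsilon)

u = -1000.0                          # very negative utility gap
naive = log_sigmoid(u)               # sigmoid underflows to 0 -> log(0) = -inf
safe = log_sigmoid(u, epsilon=1e-8)  # bounded below by log(epsilon) ≈ -18.4
```

With epsilon = 0.0 the loss becomes -inf as soon as any sigmoid underflows, which poisons the gradient; a small positive epsilon caps the penalty instead.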

        """
        store_batch = tf.cast(store_batch, dtype=tf.int32)
        price_batch = tf.cast(price_batch, dtype=tf.float32)
        x_item = tf.gather(self.X, indices=item_batch)  # Shape: (batch_size, None, d)
medium

The comment Shape: (batch_size, None, d) is incorrect. x_item will have shape (batch_size, d) if item_batch is (batch_size,) or (batch_size, num_items) if item_batch is (batch_size, num_items). The None dimension is misleading.

        x_item = tf.gather(self.X, indices=item_batch)  # Shape: (batch_size, d) or (batch_size, num_items, d)
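A small NumPy analogue of tf.gather (np.take along axis 0 has the same shape semantics) confirms the two possible output shapes; the values here are illustrative only:

```python
import numpy as np

# X plays the role of the item embedding matrix, shape (n_items, d)
X = np.arange(12.0).reshape(4, 3)              # n_items=4, d=3

flat_batch = np.array([0, 2])                  # item_batch of shape (batch_size,)
per_item = np.take(X, flat_batch, axis=0)      # -> shape (batch_size, d)

nested_batch = np.array([[0, 1], [2, 3]])      # item_batch of shape (batch_size, num_items)
per_basket = np.take(X, nested_batch, axis=0)  # -> shape (batch_size, num_items, d)
```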

Weight decay (L2 regularization) factor. Default is None (no weight decay).
momentum : float
Momentum factor for optimizers that support it. Default is 0.0.
item_intercept: bool, optional
medium

The parameter item_intercept is mentioned in the docstring but the actual parameter name in the __init__ method is intercept. This should be consistent.

Suggested change
item_intercept: bool, optional
intercept: bool, optional

Comment on lines 107 to 115
        for val in latent_sizes.keys():
            if val not in ["short_term", "long_term", "price"]:
                raise ValueError(f"Unknown value for latent_sizes dict: {val}.")
        if "short_term" not in latent_sizes:
            latent_sizes["short_term"] = 10
        if "long_term" not in latent_sizes:
            latent_sizes["long_term"] = 10
        if "price" not in latent_sizes:
            latent_sizes["price"] = 4
medium

The logic for setting default latent_sizes values is a bit redundant. It would be cleaner to define the default dictionary directly in the __init__ signature with all expected keys, or to use dict.get with default values if the keys might be missing from the input latent_sizes.

Suggested change
        for val in latent_sizes.keys():
            if val not in ["short_term", "long_term", "price"]:
                raise ValueError(f"Unknown value for latent_sizes dict: {val}.")
        if "short_term" not in latent_sizes:
            latent_sizes["short_term"] = 10
        if "long_term" not in latent_sizes:
            latent_sizes["long_term"] = 10
        if "price" not in latent_sizes:
            latent_sizes["price"] = 4
        self.latent_sizes = {
            "short_term": latent_sizes.get("short_term", 10),
            "long_term": latent_sizes.get("long_term", 10),
            "price": latent_sizes.get("price", 4),
        }
        for val in latent_sizes.keys():
            if val not in self.latent_sizes:
                raise ValueError(f"Unknown value for latent_sizes dict: {val}.")
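The dict.get pattern behaves like this minimal, self-contained sketch (resolve_latent_sizes is an illustrative name, not a function in the PR):

```python
def resolve_latent_sizes(latent_sizes):
    defaults = {"short_term": 10, "long_term": 10, "price": 4}
    for key in latent_sizes:
        if key not in defaults:
            raise ValueError(f"Unknown value for latent_sizes dict: {key}.")
    # Unspecified keys fall back to their defaults in a single pass,
    # and the caller's dict is never mutated.
    return {name: latent_sizes.get(name, value) for name, value in defaults.items()}

resolve_latent_sizes({"price": 8})
```

A side benefit over the original code: it avoids mutating the caller-supplied dictionary, which also sidesteps the classic mutable-default-argument pitfall.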

Comment on lines 165 to 175
----------
X : tf.Variable
Item embedding matrix for short-term preferences, size (n_items, d).
V : tf.Variable
Item embedding matrix for long-term preferences, size (n_items, d_long).
U : tf.Variable
User embedding matrix for long-term preferences, size (n_users, d_long).
Wq : tf.Variable
Weight matrix for query transformation in attention mechanism, size (d, d).
Wk : tf.Variable
Weight matrix for key transformation in attention mechanism, size (d, d).
medium

The docstring for instantiate is missing n_stores in the Parameters section and W_Q, W_K, W_V, W_O, W1, W2, b1, b2, gamma1, beta1, gamma2, beta2, S (if attention pooling) in the Variables section. These are all trainable weights that are part of the model's architecture.

        n_stores : int
            Number of unique stores in the dataset.

        Variables
        ----------
        X : tf.Variable
            Item embedding matrix for short-term preferences, size (n_items, d).
        V : tf.Variable
            Item embedding matrix for long-term preferences, size (n_items, d_long).
        U : tf.Variable
            User embedding matrix for long-term preferences, size (n_users, d_long).
        theta : tf.Variable, optional
            Store effects embedding matrix, size (n_stores, d).
        beta : tf.Variable, optional
            Item price sensitivity embedding matrix, size (n_items, latent_sizes["price"]).
        delta : tf.Variable, optional
            Store price sensitivity embedding matrix, size (n_stores, latent_sizes["price"]).
        alpha : tf.Variable, optional
            Item intercept vector, size (n_items,).
        CLS_token : tf.Variable, optional
            CLS token embedding, size (1, d).
        W_Q, W_K, W_V, W_O, W1, W2, b1, b2, gamma1, beta1, gamma2, beta2, S : tf.Variable, optional
            Weights and biases for TransformerBlocks and attention pooling.

        # For attention pooling, we use cross-attention with 1 head instead of self-attention
        self.num_heads = 1
        self.head_dim = d_model
        self.S = add_var((1, d_model), "W_Q")
medium

If add_var is made a static method, the call to it needs to be updated to pass self.name as name_prefix and then append the returned variable to self._trainable_weights.

            self.S = self._add_var((1, d_model), "W_Q", name_prefix=self.name)
            self._trainable_weights.append(self.S)

Comment on lines +833 to +843
        self.W_Q = add_var((d_model, d_model), "W_Q")
        self.W1 = add_var((d_model, d_ffn), "ffn_W1")
        self.b1 = add_var((d_ffn,), "ffn_b1", zeros=True)
        self.gamma1 = add_var((d_model,), "ln1_gamma")
        self.beta1 = add_var((d_model,), "ln1_beta", zeros=True)
        self.gamma2 = add_var((d_model,), "ln2_gamma")
        self.beta2 = add_var((d_model,), "ln2_beta", zeros=True)
        self.W2 = add_var((d_ffn, d_model), "ffn_W2")
        self.b2 = add_var((d_model,), "ffn_b2", zeros=True)
        self.W_O = add_var((d_model, d_model), "W_O") if num_heads > 1 else None
        self.W_K = add_var((d_model, d_model), "W_K")
medium

If add_var is made a static method, the calls to it need to be updated to pass self.name as name_prefix and then append the returned variables to self._trainable_weights.

            self.W_Q = self._add_var((d_model, d_model), "W_Q", name_prefix=self.name)
            self._trainable_weights.append(self.W_Q)
            self.W1 = self._add_var((d_model, d_ffn), "ffn_W1", name_prefix=self.name)
            self._trainable_weights.append(self.W1)
            self.b1 = self._add_var((d_ffn,), "ffn_b1", zeros=True, name_prefix=self.name)
            self._trainable_weights.append(self.b1)
            self.gamma1 = self._add_var((d_model,), "ln1_gamma", name_prefix=self.name)
            self._trainable_weights.append(self.gamma1)
            self.beta1 = self._add_var((d_model,), "ln1_beta", zeros=True, name_prefix=self.name)
            self._trainable_weights.append(self.beta1)
            self.gamma2 = self._add_var((d_model,), "ln2_gamma", name_prefix=self.name)
            self._trainable_weights.append(self.gamma2)
            self.beta2 = self._add_var((d_model,), "ln2_beta", zeros=True, name_prefix=self.name)
            self._trainable_weights.append(self.beta2)
            self.W2 = self._add_var((d_ffn, d_model), "ffn_W2", name_prefix=self.name)
            self._trainable_weights.append(self.W2)
            self.b2 = self._add_var((d_model,), "ffn_b2", zeros=True, name_prefix=self.name)
            self._trainable_weights.append(self.b2)
            if num_heads > 1:
                self.W_O = self._add_var((d_model, d_model), "W_O", name_prefix=self.name)
                self._trainable_weights.append(self.W_O)
            else:
                self.W_O = None

Comment on lines +844 to +845
        self.W_V = add_var((d_model, d_model), "W_V") if use_value_matrix else None
medium

If add_var is made a static method, the calls to it need to be updated to pass self.name as name_prefix and then append the returned variables to self._trainable_weights.

        self.W_K = self._add_var((d_model, d_model), "W_K", name_prefix=self.name)
        self._trainable_weights.append(self.W_K)
        if use_value_matrix:
            self.W_V = self._add_var((d_model, d_model), "W_V", name_prefix=self.name)
            self._trainable_weights.append(self.W_V)
        else:
            self.W_V = None
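A TF-free sketch of the create-then-register pattern the review proposes; the dict stands in for tf.Variable, and only the bookkeeping is shown (names mirror the suggestion, but this is not the PR's code):

```python
class Block:
    """Minimal stand-in for TransformerBlock's weight bookkeeping."""

    def __init__(self, name, d_model, use_value_matrix=True):
        self.name = name
        self._trainable_weights = []
        self.W_K = self._add_var((d_model, d_model), "W_K", name_prefix=self.name)
        self._trainable_weights.append(self.W_K)
        if use_value_matrix:
            self.W_V = self._add_var((d_model, d_model), "W_V", name_prefix=self.name)
            self._trainable_weights.append(self.W_V)
        else:
            self.W_V = None

    @staticmethod
    def _add_var(shape, var_name, name_prefix, zeros=False):
        # A real implementation would return a tf.Variable; a plain dict
        # records the metadata here so the sketch stays runnable without TF.
        return {"name": f"{name_prefix}/{var_name}", "shape": shape, "zeros": zeros}

block = Block("tb0", d_model=4, use_value_matrix=False)
```

Making the factory static and registering weights explicitly keeps the trainable-weights list authoritative, which is what the custom save_model/load_model logic relies on.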

        q_len = tf.shape(q)[1]
        v = tf.matmul(x, self.W_V) if self.use_value_matrix else x

        def split_heads(tensor, seq_len=seq_len):
medium

The seq_len parameter in split_heads is redundant. It can be directly inferred from tf.shape(tensor)[1] within the function.

        def split_heads(tensor):
            seq_len_local = tf.shape(tensor)[1]
            tensor = tf.reshape(tensor, (batch_size, seq_len_local, self.num_heads, self.head_dim))
            return tf.transpose(tensor, perm=[0, 2, 1, 3])

            tensor = tf.reshape(tensor, (batch_size, seq_len, self.num_heads, self.head_dim))
            return tf.transpose(tensor, perm=[0, 2, 1, 3])

        q_h = split_heads(q, seq_len=q_len)
medium

If seq_len is removed from split_heads signature, this call needs to be updated.

        q_h = split_heads(q)
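The reshape/transpose in split_heads can be sanity-checked with a NumPy one-liner (shapes only, arbitrary values):

```python
import numpy as np

batch_size, seq_len, num_heads, head_dim = 2, 5, 4, 8
x = np.zeros((batch_size, seq_len, num_heads * head_dim))
# (batch, seq, d_model) -> (batch, heads, seq, head_dim), as split_heads does
heads = x.reshape(batch_size, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)
```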

@github-actions

github-actions bot commented Feb 19, 2026

Coverage

Coverage Report for Python 3.9

File                            Stmts  Miss  Cover  Missing
choice_learn
   __init__.py                      2     0  100%
   tf_ops.py                       62     1   98%   283
choice_learn/basket_models
   __init__.py                      4     0  100%
   alea_carta.py                  148    22   85%   86–90, 92–96, 98–102, 106, 109, 131, 159, 308, 431–455
   base_basket_model.py           235    27   89%   111–112, 123, 141, 185, 255, 377, 485, 585–587, 676, 762, 772, 822–830, 891–894, 934–935
   basic_attention_model.py        89     4   96%   424, 427, 433, 440
   sabac.py                       312   312    0%   3–967
   self_attention_model.py        133     9   93%   71, 73, 75, 450–454, 651
   shopper.py                     184     9   95%   130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py                      2     0  100%
   basket_dataset.py              190    30   84%   74–77, 295–297, 407, 540–576, 636, 658–661, 700–705, 790–801, 849
   preprocessing.py                94    78   17%   43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py                      3     0  100%
   bakery.py                       38     3   92%   47, 51, 61
   synthetic_dataset.py            81     6   93%   62, 194–199, 247
choice_learn/basket_models/utils
   __init__.py                      0     0  100%
   permutation.py                  22     1   95%   37
choice_learn/data
   __init__.py                      3     0  100%
   choice_dataset.py              649    33   95%   198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py                     241    23   90%   20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py                     161     6   96%   22, 33, 51, 56, 61, 71
   store.py                        72    72    0%   3–275
choice_learn/datasets
   __init__.py                      4     0  100%
   base.py                        400     5   99%   42–43, 153–154, 714
   expedia.py                     102    83   19%   37–301
   tafeng.py                       49     0  100%
choice_learn/datasets/data
   __init__.py                      0     0  100%
choice_learn/models
   __init__.py                     14     2   86%   15–16
   base_model.py                  335    35   90%   145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735, 935, 939–953
   baseline_models.py              49     0  100%
   conditional_logit.py           269    26   90%   49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py                    124     2   98%   186, 374
   latent_class_base_model.py     286    39   86%   55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py             62     6   90%   257–261, 296
   learning_mnl.py                 67     3   96%   157, 182, 188
   nested_logit.py                291    12   96%   55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py                    132     6   95%   285, 360, 369, 374, 382, 432
   rumnet.py                      236     3   99%   748–751, 982
   simple_mnl.py                  139     6   96%   167, 275, 347, 355, 357, 359
   tastenet.py                     94     3   97%   142, 180, 188
choice_learn/toolbox
   __init__.py                      0     0  100%
   assortment_optimizer.py         27     6   78%   28–30, 93–95, 160–162
   gurobi_opt.py                  236   236    0%   3–675
   or_tools_opt.py                230    11   95%   103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py                      85    43   49%   74, 126–130, 147–166, 176, 190–199, 211–232, 242
TOTAL                           5956  1163   80%

Tests: 222 · Skipped: 0 💤 · Failures: 0 ❌ · Errors: 0 🔥 · Time: 6m 16s ⏱️

@github-actions

github-actions bot commented Feb 19, 2026

Coverage

Coverage Report for Python 3.10

Identical per-file breakdown to the Python 3.9 report above, except:
   halo_mnl.py                    124    18   85%   186, 341, 360, 364–380
   gurobi_opt.py                  238   238    0%   3–675
TOTAL                           5958  1181   80%

Tests: 222 · Skipped: 0 💤 · Failures: 1 ❌ · Errors: 0 🔥 · Time: 6m 52s ⏱️

@github-actions

github-actions bot commented Feb 19, 2026

Coverage

Coverage Report for Python 3.11

Identical per-file breakdown to the Python 3.9 report above, except:
   gurobi_opt.py                  238   238    0%   3–675
TOTAL                           5958  1165   80%

Tests: 222 · Skipped: 0 💤 · Failures: 0 ❌ · Errors: 0 🔥 · Time: 7m 11s ⏱️

@github-actions

github-actions bot commented Feb 19, 2026

Coverage

Coverage Report for Python 3.12

Identical per-file breakdown to the Python 3.9 report above, except:
   bakery.py                       38     3   92%   47, 53, 61
   gurobi_opt.py                  238   238    0%   3–675
TOTAL                           5958  1165   80%

Tests: 222 · Skipped: 0 💤 · Failures: 0 ❌ · Errors: 0 🔥 · Time: 7m 42s ⏱️
