@@ -105,6 +105,11 @@ class SWIGResources implements AutoCloseable {
105105 */
106106 private Integer boosterNumFeatures ;
107107
108+ /**
109+ * Number of classes in the trained LGBM.
110+ */
111+ private Integer boosterNumClasses = null ;
112+
108113 /**
109114 * Names of features in the trained LightGBM boosting model.
110115 * Whilst not a swig resource, it is automatically retrieved during model loading,
@@ -252,19 +257,21 @@ private void initBoosterFastContributionsHandle(final String LightGBMParameters)
252257 * Assumes the model was already loaded from file.
253258 * Initializes the remaining SWIG resources needed to use the model.
254259 *
260+ * The size of {@link #swigOutContributionsPtr} is computed accoring to
261+ * https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRow
262+ *
255263 * @throws LightGBMException in case there's an error in the C++ core library.
256264 */
257265 private void initAuxiliaryModelResources () throws LightGBMException {
258-
259- this .boosterNumFeatures = computeBoosterNumFeaturesFromModel ();
266+ computeBoosterNumFeaturesFromModel ();
260267 logger .debug ("Loaded LightGBM Model has {} features." , this .boosterNumFeatures );
261268
262269 this .boosterFeatureNames = computeBoosterFeatureNamesFromModel ();
263270
264271 this .swigOutLengthInt64Ptr = lightgbmlibJNI .new_int64_tp ();
265272 this .swigInstancePtr = lightgbmlibJNI .new_doubleArray (getBoosterNumFeatures ());
266- this .swigOutScoresPtr = lightgbmlibJNI .new_doubleArray (BINARY_LGBM_NUM_CLASSES );
267- this .swigOutContributionsPtr = lightgbmlibJNI .new_doubleArray (this .boosterNumFeatures + 1 );
273+ this .swigOutScoresPtr = lightgbmlibJNI .new_doubleArray (this . boosterNumClasses );
274+ this .swigOutContributionsPtr = lightgbmlibJNI .new_doubleArray (( long ) this .boosterNumClasses * ( this . boosterNumFeatures + 1 ) );
268275 }
269276
270277 /**
@@ -302,6 +309,14 @@ private void releaseInitializedSWIGResources() throws LightGBMException {
302309 lightgbmlibJNI .delete_intp (this .swigOutIntPtr );
303310 this .swigOutIntPtr = null ;
304311 }
312+ if (this .boosterNumFeatures != null ) {
313+ lightgbmlibJNI .delete_intp (this .boosterNumFeatures );
314+ this .boosterNumFeatures = null ;
315+ }
316+ if (this .boosterNumClasses != null ) {
317+ lightgbmlibJNI .delete_intp (this .boosterNumClasses );
318+ this .boosterNumClasses = null ;
319+ }
305320 if (this .swigOutContributionsPtr != null ) {
306321 lightgbmlibJNI .delete_doubleArray (this .swigOutContributionsPtr );
307322 this .swigOutContributionsPtr = null ;
@@ -373,17 +388,31 @@ public String[] getBoosterFeatureNames() {
373388 * Computes the number of features in the model and returns it.
374389 *
375390 * @throws LightGBMException when there is a LightGBM C++ error.
376- * @returns int with the number of Booster features.
377391 */
378- private Integer computeBoosterNumFeaturesFromModel () throws LightGBMException {
379-
380- final int returnCodeLGBM = lightgbmlibJNI .LGBM_BoosterGetNumFeature (
392+ private void computeBoosterNumFeaturesFromModel () throws LightGBMException {
393+ final int returnCodeNumFeatsLGBM = lightgbmlibJNI .LGBM_BoosterGetNumFeature (
381394 this .swigBoosterHandle ,
382395 this .swigOutIntPtr );
383- if (returnCodeLGBM == -1 )
396+ if (returnCodeNumFeatsLGBM == -1 )
384397 throw new LightGBMException ();
385398
386- return lightgbmlibJNI .intp_value (this .swigOutIntPtr );
399+
400+ if (this .boosterNumFeatures != null ) {
401+ lightgbmlibJNI .delete_intp (this .boosterNumFeatures );
402+ this .boosterNumFeatures = null ;
403+ }
404+ this .boosterNumFeatures = lightgbmlibJNI .intp_value (this .swigOutIntPtr );
405+
406+ final int returnCodeNumClassesLGBM = lightgbmlibJNI .LGBM_BoosterGetNumClasses (
407+ this .swigBoosterHandle ,
408+ this .swigOutIntPtr );
409+ if (returnCodeNumClassesLGBM == -1 )
410+ throw new LightGBMException ();
411+ if (this .boosterNumClasses != null ) {
412+ lightgbmlibJNI .delete_intp (this .boosterNumClasses );
413+ this .boosterNumClasses = null ;
414+ }
415+ this .boosterNumClasses = lightgbmlibJNI .intp_value (this .swigOutIntPtr );
387416 }
388417
389418 /**
0 commit comments